[FEATURE] Implements OneOf #37

Merged
HideyoshiNakazone merged 7 commits from feature/implements-one-of into main 2025-08-19 23:45:30 +00:00
3 changed files with 96 additions and 12 deletions
Showing only changes of commit b386d4954e - Show all commits

View File

@@ -3,7 +3,7 @@ from jambo.types.json_schema_type import JSONSchemaNativeTypes
from jambo.types.type_parser_options import TypeParserOptions
from pydantic import AfterValidator
from typing_extensions import Annotated, Any, Unpack
from typing_extensions import Annotated, Any, Literal, Unpack
class ConstTypeParser(GenericTypeParser):
@@ -33,11 +33,19 @@ class ConstTypeParser(GenericTypeParser):
return const_type, parsed_properties
def _build_const_type(self, const_value):
def _validate_const_value(value: Any) -> Any:
if value != const_value:
raise ValueError(
f"Value must be equal to the constant value: {const_value}"
)
return value
# Try to use Literal for hashable types (required for discriminated unions)
# Fall back to validator approach for non-hashable types
try:
# Test if the value is hashable (can be used in Literal)
hash(const_value)
return Literal[const_value]
except TypeError:
# Non-hashable type (like list, dict), use validator approach
def _validate_const_value(value: Any) -> Any:
if value != const_value:
raise ValueError(
f"Value must be equal to the constant value: {const_value}"
)
return value
return Annotated[type(const_value), AfterValidator(_validate_const_value)]
return Annotated[type(const_value), AfterValidator(_validate_const_value)]

View File

@@ -1,12 +1,13 @@
from jambo.parser import ConstTypeParser
from typing_extensions import Annotated, get_args, get_origin
from typing_extensions import Annotated, Literal, get_args, get_origin
from unittest import TestCase
class TestConstTypeParser(TestCase):
def test_const_type_parser(self):
def test_const_type_parser_hashable_value(self):
"""Test const parser with hashable values (uses Literal)"""
parser = ConstTypeParser()
expected_const_value = "United States of America"
@@ -16,8 +17,60 @@ class TestConstTypeParser(TestCase):
"country", properties
)
# Check that we get a Literal type for hashable values
self.assertEqual(get_origin(parsed_type), Literal)
self.assertEqual(get_args(parsed_type), (expected_const_value,))
self.assertEqual(parsed_properties["default"], expected_const_value)
def test_const_type_parser_non_hashable_value(self):
"""Test const parser with non-hashable values (uses Annotated with validator)"""
parser = ConstTypeParser()
expected_const_value = [1, 2, 3] # Lists are not hashable
properties = {"const": expected_const_value}
parsed_type, parsed_properties = parser.from_properties_impl(
"list_const", properties
)
# Check that we get an Annotated type for non-hashable values
self.assertEqual(get_origin(parsed_type), Annotated)
self.assertIn(str, get_args(parsed_type))
self.assertIn(list, get_args(parsed_type))
self.assertEqual(parsed_properties["default"], expected_const_value)
def test_const_type_parser_integer_value(self):
"""Test const parser with integer values (uses Literal)"""
parser = ConstTypeParser()
expected_const_value = 42
properties = {"const": expected_const_value}
parsed_type, parsed_properties = parser.from_properties_impl(
"int_const", properties
)
# Check that we get a Literal type for hashable values
self.assertEqual(get_origin(parsed_type), Literal)
self.assertEqual(get_args(parsed_type), (expected_const_value,))
self.assertEqual(parsed_properties["default"], expected_const_value)
def test_const_type_parser_boolean_value(self):
"""Test const parser with boolean values (uses Literal)"""
parser = ConstTypeParser()
expected_const_value = True
properties = {"const": expected_const_value}
parsed_type, parsed_properties = parser.from_properties_impl(
"bool_const", properties
)
# Check that we get a Literal type for hashable values
self.assertEqual(get_origin(parsed_type), Literal)
self.assertEqual(get_args(parsed_type), (expected_const_value,))
self.assertEqual(parsed_properties["default"], expected_const_value)

View File

@@ -701,6 +701,29 @@ class TestSchemaConverter(TestCase):
with self.assertRaises(ValueError):
Model(name="Canada")
def test_const_type_parser_with_non_hashable_value(self):
schema = {
"title": "Country",
"type": "object",
"properties": {
"name": {
"const": ["Brazil"],
}
},
"required": ["name"],
}
Model = SchemaConverter.build(schema)
obj = Model()
self.assertEqual(obj.name, ["Brazil"])
with self.assertRaises(ValueError):
obj.name = ["Argentina"]
with self.assertRaises(ValueError):
Model(name=["Argentina"])
def test_null_type_parser(self):
schema = {
"title": "Test",