diff --git a/jambo/parser/const_type_parser.py b/jambo/parser/const_type_parser.py index b5c846f..1e4ce84 100644 --- a/jambo/parser/const_type_parser.py +++ b/jambo/parser/const_type_parser.py @@ -3,7 +3,7 @@ from jambo.types.json_schema_type import JSONSchemaNativeTypes from jambo.types.type_parser_options import TypeParserOptions from pydantic import AfterValidator -from typing_extensions import Annotated, Any, Unpack +from typing_extensions import Annotated, Any, Literal, Unpack class ConstTypeParser(GenericTypeParser): @@ -33,11 +33,19 @@ class ConstTypeParser(GenericTypeParser): return const_type, parsed_properties def _build_const_type(self, const_value): - def _validate_const_value(value: Any) -> Any: - if value != const_value: - raise ValueError( - f"Value must be equal to the constant value: {const_value}" - ) - return value + # Try to use Literal for hashable types (required for discriminated unions) + # Fall back to validator approach for non-hashable types + try: + # Test if the value is hashable (can be used in Literal) + hash(const_value) + return Literal[const_value] + except TypeError: + # Non-hashable type (like list, dict), use validator approach + def _validate_const_value(value: Any) -> Any: + if value != const_value: + raise ValueError( + f"Value must be equal to the constant value: {const_value}" + ) + return value - return Annotated[type(const_value), AfterValidator(_validate_const_value)] + return Annotated[type(const_value), AfterValidator(_validate_const_value)] \ No newline at end of file diff --git a/tests/parser/test_const_type_parser.py b/tests/parser/test_const_type_parser.py index ca92bb0..5a8c9c1 100644 --- a/tests/parser/test_const_type_parser.py +++ b/tests/parser/test_const_type_parser.py @@ -1,12 +1,13 @@ from jambo.parser import ConstTypeParser -from typing_extensions import Annotated, get_args, get_origin +from typing_extensions import Annotated, Literal, get_args, get_origin from unittest import TestCase class TestConstTypeParser(TestCase): - def test_const_type_parser(self): + def test_const_type_parser_hashable_value(self): + """Test const parser with hashable values (uses Literal)""" parser = ConstTypeParser() expected_const_value = "United States of America" @@ -16,8 +17,60 @@ class TestConstTypeParser(TestCase): "country", properties ) + # Check that we get a Literal type for hashable values + self.assertEqual(get_origin(parsed_type), Literal) + self.assertEqual(get_args(parsed_type), (expected_const_value,)) + + self.assertEqual(parsed_properties["default"], expected_const_value) + + def test_const_type_parser_non_hashable_value(self): + """Test const parser with non-hashable values (uses Annotated with validator)""" + parser = ConstTypeParser() + + expected_const_value = [1, 2, 3] # Lists are not hashable + properties = {"const": expected_const_value} + + parsed_type, parsed_properties = parser.from_properties_impl( + "list_const", properties + ) + + # Check that we get an Annotated type for non-hashable values self.assertEqual(get_origin(parsed_type), Annotated) - self.assertIn(str, get_args(parsed_type)) + self.assertIn(list, get_args(parsed_type)) + + self.assertEqual(parsed_properties["default"], expected_const_value) + + def test_const_type_parser_integer_value(self): + """Test const parser with integer values (uses Literal)""" + parser = ConstTypeParser() + + expected_const_value = 42 + properties = {"const": expected_const_value} + + parsed_type, parsed_properties = parser.from_properties_impl( + "int_const", properties + ) + + # Check that we get a Literal type for hashable values + self.assertEqual(get_origin(parsed_type), Literal) + self.assertEqual(get_args(parsed_type), (expected_const_value,)) + + self.assertEqual(parsed_properties["default"], expected_const_value) + + def test_const_type_parser_boolean_value(self): + """Test const parser with boolean values (uses Literal)""" + parser = ConstTypeParser() + + expected_const_value = True + properties = {"const": expected_const_value} + + parsed_type, parsed_properties = parser.from_properties_impl( + "bool_const", properties + ) + + # Check that we get a Literal type for hashable values + self.assertEqual(get_origin(parsed_type), Literal) + self.assertEqual(get_args(parsed_type), (expected_const_value,)) self.assertEqual(parsed_properties["default"], expected_const_value) @@ -46,4 +99,4 @@ class TestConstTypeParser(TestCase): self.assertIn( "Const type invalid_country must have 'const' value of allowed types", str(context.exception), - ) + ) \ No newline at end of file diff --git a/tests/test_schema_converter.py b/tests/test_schema_converter.py index c5a7c3a..9a756c7 100644 --- a/tests/test_schema_converter.py +++ b/tests/test_schema_converter.py @@ -701,6 +701,29 @@ class TestSchemaConverter(TestCase): with self.assertRaises(ValueError): Model(name="Canada") + def test_const_type_parser_with_non_hashable_value(self): + schema = { + "title": "Country", + "type": "object", + "properties": { + "name": { + "const": ["Brazil"], + } + }, + "required": ["name"], + } + + Model = SchemaConverter.build(schema) + + obj = Model() + self.assertEqual(obj.name, ["Brazil"]) + + with self.assertRaises(ValueError): + obj.name = ["Argentina"] + + with self.assertRaises(ValueError): + Model(name=["Argentina"]) + def test_null_type_parser(self): schema = { "title": "Test",