diff --git a/jambo/parser/const_type_parser.py b/jambo/parser/const_type_parser.py index a408264..b5c846f 100644 --- a/jambo/parser/const_type_parser.py +++ b/jambo/parser/const_type_parser.py @@ -9,6 +9,11 @@ from typing_extensions import Annotated, Any, Unpack class ConstTypeParser(GenericTypeParser): json_schema_type = "const" + default_mappings = { + "const": "default", + "description": "description", + } + def from_properties_impl( self, name, properties, **kwargs: Unpack[TypeParserOptions] ): @@ -21,15 +26,12 @@ class ConstTypeParser(GenericTypeParser): raise ValueError( f"Const type {name} must have 'const' value of allowed types: {JSONSchemaNativeTypes}." ) - + const_type = self._build_const_type(const_value) - parsed_properties = { - "default": const_value, - "description": properties.get("description"), - } - + parsed_properties = self.mappings_properties_builder(properties, **kwargs) + return const_type, parsed_properties - + def _build_const_type(self, const_value): def _validate_const_value(value: Any) -> Any: if value != const_value: @@ -37,10 +39,5 @@ class ConstTypeParser(GenericTypeParser): f"Value must be equal to the constant value: {const_value}" ) return value - - return Annotated[ - type(const_value), - AfterValidator( - _validate_const_value - ) - ] \ No newline at end of file + + return Annotated[type(const_value), AfterValidator(_validate_const_value)] diff --git a/jambo/parser/object_type_parser.py b/jambo/parser/object_type_parser.py index 6833d40..8deb5ac 100644 --- a/jambo/parser/object_type_parser.py +++ b/jambo/parser/object_type_parser.py @@ -1,7 +1,7 @@ from jambo.parser._type_parser import GenericTypeParser from jambo.types.type_parser_options import TypeParserOptions -from pydantic import BaseModel, Field, create_model +from pydantic import BaseModel, ConfigDict, Field, create_model from typing_extensions import Any, Unpack @@ -43,8 +43,10 @@ class ObjectTypeParser(GenericTypeParser): :param required_keys: List of required keys in the schema. :return: A Pydantic model class. """ + model_config = ConfigDict(validate_assignment=True) fields = cls._parse_properties(schema, required_keys, **kwargs) - return create_model(name, **fields) + + return create_model(name, __config__=model_config, **fields) @classmethod def _parse_properties( diff --git a/tests/parser/test_const_type_parser.py b/tests/parser/test_const_type_parser.py index d37fa57..ca92bb0 100644 --- a/tests/parser/test_const_type_parser.py +++ b/tests/parser/test_const_type_parser.py @@ -1,23 +1,49 @@ -from typing_extensions import Annotated, get_args, get_origin -from webbrowser import get from jambo.parser import ConstTypeParser +from typing_extensions import Annotated, get_args, get_origin + from unittest import TestCase class TestConstTypeParser(TestCase): - def test_parse_const_type(self): + def test_const_type_parser(self): parser = ConstTypeParser() expected_const_value = "United States of America" - properties = { - "const": expected_const_value - } + properties = {"const": expected_const_value} - parsed_type, parsed_properties = parser.from_properties( + parsed_type, parsed_properties = parser.from_properties_impl( "country", properties ) self.assertEqual(get_origin(parsed_type), Annotated) + self.assertIn(str, get_args(parsed_type)) - self.assertIn(str, get_args(parsed_type)) \ No newline at end of file + self.assertEqual(parsed_properties["default"], expected_const_value) + + def test_const_type_parser_invalid_properties(self): + parser = ConstTypeParser() + + expected_const_value = "United States of America" + properties = {"notConst": expected_const_value} + + with self.assertRaises(ValueError) as context: + parser.from_properties_impl("invalid_country", properties) + + self.assertIn( + "Const type invalid_country must have 'const' property defined", + str(context.exception), + ) + + def test_const_type_parser_invalid_const_value(self): + parser = ConstTypeParser() + + properties = {"const": {}} + + with self.assertRaises(ValueError) as context: + parser.from_properties_impl("invalid_country", properties) + + self.assertIn( + "Const type invalid_country must have 'const' value of allowed types", + str(context.exception), + ) diff --git a/tests/test_schema_converter.py b/tests/test_schema_converter.py index a980e3c..fbba3c9 100644 --- a/tests/test_schema_converter.py +++ b/tests/test_schema_converter.py @@ -634,3 +634,26 @@ class TestSchemaConverter(TestCase): obj = Model() self.assertEqual(obj.status.value, "active") + + def test_const_type_parser(self): + schema = { + "title": "Country", + "type": "object", + "properties": { + "name": { + "const": "United States of America", + } + }, + "required": ["name"], + } + + Model = SchemaConverter.build(schema) + + obj = Model() + self.assertEqual(obj.name, "United States of America") + + with self.assertRaises(ValueError): + obj.name = "Canada" + + with self.assertRaises(ValueError): + Model(name="Canada")