diff --git a/jambo/parser/_type_parser.py b/jambo/parser/_type_parser.py index e64ae77..024fed3 100644 --- a/jambo/parser/_type_parser.py +++ b/jambo/parser/_type_parser.py @@ -1,32 +1,36 @@ -from pydantic import Field -from typing_extensions import Self +from pydantic import Field, TypeAdapter +from typing_extensions import Annotated, Self from abc import ABC, abstractmethod -from typing import Generic, TypeVar +from typing import Generic, Type, TypeVar T = TypeVar("T") class GenericTypeParser(ABC, Generic[T]): - @property - @abstractmethod - def mapped_type(self) -> type[T]: ... + mapped_type: Type[T] = None - @property - @abstractmethod - def json_schema_type(self) -> str: ... + json_schema_type: str = None @staticmethod @abstractmethod def from_properties( - name: str, properties: dict[str, any] - ) -> tuple[type[T], Field]: ... + name: str, properties: dict[str, any], required: bool = False + ) -> tuple[T, dict]: ... @classmethod def get_impl(cls, type_name: str) -> Self: for subcls in cls.__subclasses__(): + if subcls.json_schema_type is None: + raise RuntimeError(f"Unknown type: {type_name}") + if subcls.json_schema_type == type_name: - return subcls + return subcls() raise ValueError(f"Unknown type: {type_name}") + + @staticmethod + def validate_default(field_type: type, field_prop: dict, value): + field = Annotated[field_type, Field(**field_prop)] + TypeAdapter(field).validate_python(value) diff --git a/jambo/parser/allof_type_parser.py b/jambo/parser/allof_type_parser.py index 6f23236..7a65b49 100644 --- a/jambo/parser/allof_type_parser.py +++ b/jambo/parser/allof_type_parser.py @@ -7,7 +7,7 @@ class AllOfTypeParser(GenericTypeParser): json_schema_type = "allOf" @staticmethod - def from_properties(name, properties): + def from_properties(name, properties, required=False): subProperties = properties.get("allOf") if not subProperties: raise ValueError("Invalid JSON Schema: 'allOf' is not specified.") diff --git a/jambo/parser/anyof_type_parser.py b/jambo/parser/anyof_type_parser.py index 8ee5543..49408be 100644 --- a/jambo/parser/anyof_type_parser.py +++ b/jambo/parser/anyof_type_parser.py @@ -1,5 +1,8 @@ from jambo.parser._type_parser import GenericTypeParser +from pydantic import Field +from typing_extensions import Annotated + from typing import Union @@ -9,20 +12,47 @@ class AnyOfTypeParser(GenericTypeParser): json_schema_type = "anyOf" @staticmethod - def from_properties(name, properties): + def from_properties(name, properties, required=False): if "anyOf" not in properties: raise ValueError(f"Invalid JSON Schema: {properties}") if not isinstance(properties["anyOf"], list): raise ValueError(f"Invalid JSON Schema: {properties['anyOf']}") + mapped_properties = dict() + subProperties = properties["anyOf"] - types = [ + sub_types = [ GenericTypeParser.get_impl(subProperty["type"]).from_properties( name, subProperty ) for subProperty in subProperties ] - return Union[*(t for t, v in types)], {} + default_value = properties.get("default") + if default_value is not None: + for sub_type, sub_property in sub_types: + try: + GenericTypeParser.validate_default( + sub_type, sub_property, default_value + ) + break + except ValueError: + continue + else: + raise ValueError( + f"Invalid default value {default_value} for anyOf types: {sub_types}" + ) + + mapped_properties["default"] = default_value + + if not required: + mapped_properties["default"] = mapped_properties.get("default") + + # By defining the type as Union, we can use the Field validator to enforce + # the constraints on the union type. + # We use Annotated to attach the Field validators to the type. + field_types = [Annotated[t, Field(**v)] if v else t for t, v in sub_types] + + return Union[(*field_types,)], mapped_properties diff --git a/jambo/parser/array_type_parser.py b/jambo/parser/array_type_parser.py index 20a6125..d2d6e30 100644 --- a/jambo/parser/array_type_parser.py +++ b/jambo/parser/array_type_parser.py @@ -15,11 +15,11 @@ class ArrayTypeParser(GenericTypeParser): json_schema_type = "array" - @classmethod - def from_properties(cls, name, properties): + @staticmethod + def from_properties(name, properties, required=False): _item_type, _item_args = GenericTypeParser.get_impl( properties["items"]["type"] - ).from_properties(name, properties["items"]) + ).from_properties(name, properties["items"], required=True) _mappings = { "maxItems": "max_length", @@ -29,11 +29,14 @@ class ArrayTypeParser(GenericTypeParser): wrapper_type = set if properties.get("uniqueItems", False) else list mapped_properties = mappings_properties_builder( - properties, _mappings, {"description": "description"} + properties, + _mappings, + required=required, + default_mappings={"description": "description"}, ) - if "default" in properties: - default_list = properties["default"] + default_list = properties.get("default") + if default_list is not None: if not isinstance(default_list, list): raise ValueError( f"Default value must be a list, got {type(default_list).__name__}" @@ -63,4 +66,7 @@ class ArrayTypeParser(GenericTypeParser): default_list ) + if "default_factory" in mapped_properties and "default" in mapped_properties: + del mapped_properties["default"] + return wrapper_type[_item_type], mapped_properties diff --git a/jambo/parser/boolean_type_parser.py b/jambo/parser/boolean_type_parser.py index 1dec65d..4b21cf8 100644 --- a/jambo/parser/boolean_type_parser.py +++ b/jambo/parser/boolean_type_parser.py @@ -10,7 +10,7 @@ class BooleanTypeParser(GenericTypeParser): json_schema_type = "boolean" @staticmethod - def from_properties(name, properties): + def from_properties(name, properties, required=False): _mappings = { "default": "default", } diff --git a/jambo/parser/float_type_parser.py b/jambo/parser/float_type_parser.py index a6dcdd5..5326c08 100644 --- a/jambo/parser/float_type_parser.py +++ b/jambo/parser/float_type_parser.py @@ -10,5 +10,5 @@ class FloatTypeParser(GenericTypeParser): json_schema_type = "number" @staticmethod - def from_properties(name, properties): - return float, numeric_properties_builder(properties) + def from_properties(name, properties, required=False): + return float, numeric_properties_builder(properties, required) diff --git a/jambo/parser/int_type_parser.py b/jambo/parser/int_type_parser.py index 1ef907b..82bbfb9 100644 --- a/jambo/parser/int_type_parser.py +++ b/jambo/parser/int_type_parser.py @@ -10,5 +10,5 @@ class IntTypeParser(GenericTypeParser): json_schema_type = "integer" @staticmethod - def from_properties(name, properties): - return int, numeric_properties_builder(properties) + def from_properties(name, properties, required=False): + return int, numeric_properties_builder(properties, required) diff --git a/jambo/parser/object_type_parser.py b/jambo/parser/object_type_parser.py index 7c0c363..f7c9f6a 100644 --- a/jambo/parser/object_type_parser.py +++ b/jambo/parser/object_type_parser.py @@ -7,7 +7,7 @@ class ObjectTypeParser(GenericTypeParser): json_schema_type = "object" @staticmethod - def from_properties(name, properties): + def from_properties(name, properties, required=False): from jambo.schema_converter import SchemaConverter type_parsing = SchemaConverter.build_object(name, properties) diff --git a/jambo/parser/string_type_parser.py b/jambo/parser/string_type_parser.py index dc44ca4..7b77c75 100644 --- a/jambo/parser/string_type_parser.py +++ b/jambo/parser/string_type_parser.py @@ -10,31 +10,17 @@ class StringTypeParser(GenericTypeParser): json_schema_type = "string" @staticmethod - def from_properties(name, properties): + def from_properties(name, properties, required=False): _mappings = { "maxLength": "max_length", "minLength": "min_length", "pattern": "pattern", } - mapped_properties = mappings_properties_builder(properties, _mappings) + mapped_properties = mappings_properties_builder(properties, _mappings, required) - if "default" in properties: - default_value = properties["default"] - if not isinstance(default_value, str): - raise ValueError( - f"Default value for {name} must be a string, " - f"but got <{type(properties['default']).__name__}>." - ) - - if len(default_value) > properties.get("maxLength", float("inf")): - raise ValueError( - f"Default value for {name} exceeds maxLength limit of {properties.get('maxLength')}" - ) - - if len(default_value) < properties.get("minLength", 0): - raise ValueError( - f"Default value for {name} is below minLength limit of {properties.get('minLength')}" - ) + default_value = properties.get("default") + if default_value is not None: + StringTypeParser.validate_default(str, mapped_properties, default_value) return str, mapped_properties diff --git a/jambo/schema_converter.py b/jambo/schema_converter.py index 8dadb7e..39a5848 100644 --- a/jambo/schema_converter.py +++ b/jambo/schema_converter.py @@ -71,14 +71,13 @@ class SchemaConverter: fields = {} for name, prop in properties.items(): - fields[name] = SchemaConverter._build_field(name, prop, required_keys) + is_required = name in required_keys + fields[name] = SchemaConverter._build_field(name, prop, is_required) return fields @staticmethod - def _build_field( - name, properties: dict, required_keys: list[str] - ) -> tuple[type, dict]: + def _build_field(name, properties: dict, required=False) -> tuple[type, dict]: match properties: case {"anyOf": _}: _field_type = "anyOf" @@ -91,17 +90,6 @@ class SchemaConverter: _field_type, _field_args = GenericTypeParser.get_impl( _field_type - ).from_properties(name, properties) - - _field_args = _field_args or {} - - if description := properties.get("description"): - _field_args["description"] = description - - if name not in required_keys: - _field_args["default"] = properties.get("default", None) - - if "default_factory" in _field_args and "default" in _field_args: - del _field_args["default"] + ).from_properties(name, properties, required) return _field_type, Field(**_field_args) diff --git a/jambo/utils/properties_builder/mappings_properties_builder.py b/jambo/utils/properties_builder/mappings_properties_builder.py index f743891..e78fc8e 100644 --- a/jambo/utils/properties_builder/mappings_properties_builder.py +++ b/jambo/utils/properties_builder/mappings_properties_builder.py @@ -1,4 +1,9 @@ -def mappings_properties_builder(properties, mappings, default_mappings=None): +def mappings_properties_builder( + properties, mappings, required=False, default_mappings=None +): + if not required: + properties["default"] = properties.get("default", None) + default_mappings = default_mappings or { "default": "default", "description": "description", diff --git a/jambo/utils/properties_builder/numeric_properties_builder.py b/jambo/utils/properties_builder/numeric_properties_builder.py index f38dea1..343ef4c 100644 --- a/jambo/utils/properties_builder/numeric_properties_builder.py +++ b/jambo/utils/properties_builder/numeric_properties_builder.py @@ -3,7 +3,7 @@ from jambo.utils.properties_builder.mappings_properties_builder import ( ) -def numeric_properties_builder(properties): +def numeric_properties_builder(properties, required=False): _mappings = { "minimum": "ge", "exclusiveMinimum": "gt", @@ -13,9 +13,10 @@ def numeric_properties_builder(properties): "default": "default", } - mapped_properties = mappings_properties_builder(properties, _mappings) + mapped_properties = mappings_properties_builder(properties, _mappings, required) - if "default" in properties: + default_value = properties.get("default") + if default_value is not None: default_value = properties["default"] if not isinstance(default_value, (int, float)): raise ValueError( diff --git a/tests/parser/test_anyof_type_parser.py b/tests/parser/test_anyof_type_parser.py index cb97e1b..896b394 100644 --- a/tests/parser/test_anyof_type_parser.py +++ b/tests/parser/test_anyof_type_parser.py @@ -1,5 +1,7 @@ from jambo.parser.anyof_type_parser import AnyOfTypeParser +from typing_extensions import Annotated + from typing import Union, get_args, get_origin from unittest import TestCase @@ -21,5 +23,41 @@ class TestAnyOfTypeParser(TestCase): # check union type has string and int self.assertEqual(get_origin(type_parsing), Union) - self.assertIn(str, get_args(type_parsing)) - self.assertIn(int, get_args(type_parsing)) + + type_1, type_2 = get_args(type_parsing) + + self.assertEqual(get_origin(type_1), Annotated) + self.assertIn(str, get_args(type_1)) + + self.assertEqual(get_origin(type_2), Annotated) + self.assertIn(int, get_args(type_2)) + + def test_any_of_string_or_int_with_default(self): + """ + Tests the AnyOfTypeParser with a string or int type and a default value. + """ + + properties = { + "anyOf": [ + {"type": "string"}, + {"type": "integer"}, + ], + "default": 42, + } + + type_parsing, type_validator = AnyOfTypeParser.from_properties( + "placeholder", properties + ) + + # check union type has string and int + self.assertEqual(get_origin(type_parsing), Union) + + type_1, type_2 = get_args(type_parsing) + + self.assertEqual(get_origin(type_1), Annotated) + self.assertIn(str, get_args(type_1)) + + self.assertEqual(get_origin(type_2), Annotated) + self.assertIn(int, get_args(type_2)) + + self.assertEqual(type_validator["default"], 42) diff --git a/tests/parser/test_bool_type_parser.py b/tests/parser/test_bool_type_parser.py index 1ba25aa..92c1513 100644 --- a/tests/parser/test_bool_type_parser.py +++ b/tests/parser/test_bool_type_parser.py @@ -12,7 +12,7 @@ class TestBoolTypeParser(TestCase): type_parsing, type_validator = parser.from_properties("placeholder", properties) self.assertEqual(type_parsing, bool) - self.assertEqual(type_validator, {}) + self.assertEqual(type_validator, {"default": None}) def test_bool_parser_with_default(self): parser = BooleanTypeParser() diff --git a/tests/parser/test_float_type_parser.py b/tests/parser/test_float_type_parser.py index c25ab49..66a29f0 100644 --- a/tests/parser/test_float_type_parser.py +++ b/tests/parser/test_float_type_parser.py @@ -12,7 +12,7 @@ class TestFloatTypeParser(TestCase): type_parsing, type_validator = parser.from_properties("placeholder", properties) self.assertEqual(type_parsing, float) - self.assertEqual(type_validator, {}) + self.assertEqual(type_validator, {"default": None}) def test_float_parser_with_options(self): parser = FloatTypeParser() diff --git a/tests/parser/test_int_type_parser.py b/tests/parser/test_int_type_parser.py index 64da2bd..84c9e17 100644 --- a/tests/parser/test_int_type_parser.py +++ b/tests/parser/test_int_type_parser.py @@ -12,7 +12,7 @@ class TestIntTypeParser(TestCase): type_parsing, type_validator = parser.from_properties("placeholder", properties) self.assertEqual(type_parsing, int) - self.assertEqual(type_validator, {}) + self.assertEqual(type_validator, {"default": None}) def test_int_parser_with_options(self): parser = IntTypeParser() diff --git a/tests/parser/test_string_type_parser.py b/tests/parser/test_string_type_parser.py index 9cdf901..92161d0 100644 --- a/tests/parser/test_string_type_parser.py +++ b/tests/parser/test_string_type_parser.py @@ -57,14 +57,9 @@ class TestStringTypeParser(TestCase): "minLength": 5, } - with self.assertRaises(ValueError) as context: + with self.assertRaises(ValueError): parser.from_properties("placeholder", properties) - self.assertEqual( - str(context.exception), - "Default value for placeholder must be a string, but got .", - ) - def test_string_parser_with_default_invalid_maxlength(self): parser = StringTypeParser() @@ -75,14 +70,9 @@ class TestStringTypeParser(TestCase): "minLength": 1, } - with self.assertRaises(ValueError) as context: + with self.assertRaises(ValueError): parser.from_properties("placeholder", properties) - self.assertEqual( - str(context.exception), - "Default value for placeholder exceeds maxLength limit of 2", - ) - def test_string_parser_with_default_invalid_minlength(self): parser = StringTypeParser() @@ -93,10 +83,5 @@ class TestStringTypeParser(TestCase): "minLength": 2, } - with self.assertRaises(ValueError) as context: + with self.assertRaises(ValueError): parser.from_properties("placeholder", properties) - - self.assertEqual( - str(context.exception), - "Default value for placeholder is below minLength limit of 2", - )