diff --git a/jambo/parser/ref_type_parser.py b/jambo/parser/ref_type_parser.py index 5229f0e..57abeac 100644 --- a/jambo/parser/ref_type_parser.py +++ b/jambo/parser/ref_type_parser.py @@ -6,6 +6,8 @@ from typing_extensions import Any, ForwardRef, Literal, TypeVar, Union, Unpack RefType = TypeVar("RefType", bound=Union[type, ForwardRef]) +RefStrategy = Literal["forward_ref", "def_ref"] + class RefTypeParser(GenericTypeParser): json_schema_type = "$ref" @@ -16,38 +18,42 @@ class RefTypeParser(GenericTypeParser): if "$ref" not in properties: raise ValueError(f"RefTypeParser: Missing $ref in properties for {name}") - context = kwargs["context"] - ref_cache = kwargs["ref_cache"] - - mapped_type = None - mapped_properties = self.mappings_properties_builder(properties, **kwargs) - + context = kwargs.get("context") if context is None: raise RuntimeError( - f"RefTypeParser: Missing $content in properties for {name}" + f"RefTypeParser: Missing `content` in properties for {name}" ) - if not properties["$ref"].startswith("#"): - raise ValueError( - "At the moment, only local references are supported. " - "Look into $defs and # for recursive references." + ref_cache = kwargs.get("ref_cache") + if ref_cache is None: + raise RuntimeError( + f"RefTypeParser: Missing `ref_cache` in properties for {name}" ) + mapped_properties = self.mappings_properties_builder(properties, **kwargs) + ref_strategy, ref_name, ref_property = self._examine_ref_strategy( name, properties, **kwargs ) - # In this code ellipsis is used to indicate that the reference is still being processed, - # If the reference is already in the cache, return it. - ref_state = ref_cache.setdefault(ref_name) - - if ref_state is Ellipsis: - return ForwardRef(ref_name), mapped_properties - elif ref_state is not None: + ref_state = self._get_ref_from_cache(ref_name, ref_cache) + if ref_state is not None: + # If the reference is either processing or already cached return ref_state, mapped_properties - else: - ref_cache[ref_name] = Ellipsis + ref_cache[ref_name] = self._parse_from_strategy( + ref_strategy, ref_name, ref_property, **kwargs + ) + + return ref_cache[ref_name], mapped_properties + + def _parse_from_strategy( + self, + ref_strategy: RefStrategy, + ref_name: str, + ref_property: dict[str, Any], + **kwargs: Unpack[TypeParserOptions], + ): match ref_strategy: case "forward_ref": mapped_type = ForwardRef(ref_name) @@ -57,22 +63,35 @@ class RefTypeParser(GenericTypeParser): ) case _: raise ValueError( - f"RefTypeParser: Unsupported $ref {properties['$ref']}" + f"RefTypeParser: Unsupported $ref {ref_property['$ref']}" ) - # Sets cached reference to the mapped type - ref_cache[ref_name] = mapped_type + return mapped_type - return mapped_type, mapped_properties + def _get_ref_from_cache( + self, ref_name: str, ref_cache: dict[str, type] + ) -> RefType | type | None: + try: + ref_state = ref_cache[ref_name] + + if ref_state is None: + # If the reference is being processed, we return a ForwardRef + return ForwardRef(ref_name) + + # If the reference is already cached, we return it + return ref_state + except KeyError: + # If the reference is not in the cache, we will set it to None + ref_cache[ref_name] = None def _examine_ref_strategy( self, name: str, properties: dict[str, Any], **kwargs: Unpack[TypeParserOptions] - ) -> tuple[Literal["forward_ref", "def_ref"], str, dict]: + ) -> tuple[RefStrategy, str, dict] | None: if properties["$ref"] == "#": ref_name = kwargs["context"].get("title") if ref_name is None: raise ValueError( - f"RefTypeParser: Missing title in properties for $ref {properties['$ref']}" + "RefTypeParser: Missing title in properties for $ref of Root Reference" ) return "forward_ref", ref_name, {} @@ -82,7 +101,9 @@ class RefTypeParser(GenericTypeParser): ) return "def_ref", target_name, target_property - raise ValueError(f"RefTypeParser: Unsupported $ref {properties['$ref']}") + raise ValueError( + "RefTypeParser: Only Root and $defs references are supported at the moment" + ) def _extract_target_ref( self, name: str, properties: dict[str, Any], **kwargs: Unpack[TypeParserOptions] diff --git a/jambo/schema_converter.py b/jambo/schema_converter.py index 3bca9c5..952cdc2 100644 --- a/jambo/schema_converter.py +++ b/jambo/schema_converter.py @@ -34,10 +34,16 @@ class SchemaConverter: schema_type = SchemaConverter._get_schema_type(schema) - parsed_model = None match schema_type: case "object": - parsed_model = SchemaConverter._from_object(schema) + return ObjectTypeParser.to_model( + schema["title"], + schema["properties"], + schema.get("required", []), + context=schema, + ref_cache=dict(), + ) + case "$ref": parsed_model, _ = RefTypeParser().from_properties( schema["title"], @@ -46,35 +52,10 @@ class SchemaConverter: ref_cache=dict(), required=True, ) + return parsed_model case _: raise TypeError(f"Unsupported schema type: {schema_type}") - if not issubclass(parsed_model, BaseModel): - raise TypeError( - f"Parsed model {parsed_model.__name__} is not a subclass of BaseModel." - ) - - return parsed_model - - @staticmethod - def _from_object(schema: JSONSchema) -> type[BaseModel]: - """ - Converts a JSON Schema object to a Pydantic model. - :param schema: The JSON Schema object to convert. - :return: A Pydantic model class. - """ - - if "properties" not in schema: - raise ValueError("JSON Schema object must have properties defined.") - - return ObjectTypeParser.to_model( - schema["title"], - schema["properties"], - schema.get("required", []), - context=schema, - ref_cache=dict(), - ) - @staticmethod def _get_schema_type(schema: JSONSchema) -> str: """ diff --git a/tests/parser/test_ref_type_parser.py b/tests/parser/test_ref_type_parser.py index c6ad1c3..3e08ff4 100644 --- a/tests/parser/test_ref_type_parser.py +++ b/tests/parser/test_ref_type_parser.py @@ -1,9 +1,156 @@ from jambo.parser import ObjectTypeParser, RefTypeParser +from typing import ForwardRef from unittest import TestCase class TestRefTypeParser(TestCase): + def test_ref_type_parser_throws_without_ref(self): + properties = { + "title": "person", + "type": "object", + "properties": { + "name": {"type": "string"}, + "age": {"type": "integer"}, + }, + "required": ["name", "age"], + } + + with self.assertRaises(ValueError): + RefTypeParser().from_properties( + "person", + properties, + context=properties, + ref_cache={}, + required=True, + ) + + def test_ref_type_parser_throws_without_context(self): + properties = { + "title": "person", + "$ref": "#/$defs/person", + "$defs": { + "person": { + "type": "object", + "properties": { + "name": {"type": "string"}, + "age": {"type": "integer"}, + }, + } + }, + } + + with self.assertRaises(RuntimeError): + RefTypeParser().from_properties( + "person", + properties, + ref_cache={}, + required=True, + ) + + def test_ref_type_parser_throws_without_ref_cache(self): + properties = { + "title": "person", + "$ref": "#/$defs/person", + "$defs": { + "person": { + "type": "object", + "properties": { + "name": {"type": "string"}, + "age": {"type": "integer"}, + }, + } + }, + } + + with self.assertRaises(RuntimeError): + RefTypeParser().from_properties( + "person", + properties, + context=properties, + required=True, + ) + + def test_ref_type_parser_throws_if_network_ref_type(self): + properties = { + "title": "person", + "$ref": "https://example.com/schemas/person.json", + } + + with self.assertRaises(ValueError): + RefTypeParser().from_properties( + "person", + properties, + context=properties, + ref_cache={}, + required=True, + ) + + def test_ref_type_parser_throws_if_non_root_or_def_ref(self): + # This is invalid because object3 is referencing object2, + # but object2 is not defined in $defs or as a root reference. + properties = { + "title": "object1", + "type": "object", + "properties": { + "object2": { + "type": "object", + "properties": { + "attr1": { + "type": "string", + }, + "attr2": { + "type": "integer", + }, + }, + }, + "object3": { + "$ref": "#/$defs/object2", + }, + }, + } + + with self.assertRaises(ValueError): + ObjectTypeParser().from_properties( + "person", + properties, + context=properties, + ref_cache={}, + required=True, + ) + + def test_ref_type_parser_throws_if_def_doesnt_exists(self): + properties = { + "title": "person", + "$ref": "#/$defs/employee", + "$defs": {}, + } + + with self.assertRaises(ValueError): + RefTypeParser().from_properties( + "person", + properties, + context=properties, + ref_cache={}, + required=True, + ) + + def test_ref_type_parser_throws_if_ref_property_doesnt_exists(self): + properties = { + "title": "person", + "$ref": "#/$defs/person", + "$defs": {"person": None}, + } + + with self.assertRaises(ValueError): + RefTypeParser().from_properties( + "person", + properties, + context=properties, + ref_cache={}, + required=True, + ) + def test_ref_type_parser_with_def(self): properties = { "title": "person", @@ -71,6 +218,29 @@ class TestRefTypeParser(TestCase): self.assertEqual(obj.emergency_contact.name, "Jane") self.assertEqual(obj.emergency_contact.age, 28) + def test_ref_type_parser_invalid_forward_ref(self): + properties = { + # Doesn't have a title, which is required for forward references + "type": "object", + "properties": { + "name": {"type": "string"}, + "age": {"type": "integer"}, + "emergency_contact": { + "$ref": "#", + }, + }, + "required": ["name", "age"], + } + + with self.assertRaises(ValueError): + ObjectTypeParser().from_properties( + "person", + properties, + context=properties, + ref_cache={}, + required=True, + ) + def test_ref_type_parser_forward_ref_can_checks_validation(self): properties = { "title": "person", @@ -143,3 +313,172 @@ class TestRefTypeParser(TestCase): self.assertIsInstance(obj.emergency_contact, model) self.assertEqual(obj.emergency_contact.name, "Jane") self.assertEqual(obj.emergency_contact.age, 28) + + def test_ref_type_parser_with_repeated_ref(self): + properties = { + "title": "person", + "$ref": "#/$defs/person", + "$defs": { + "person": { + "type": "object", + "properties": { + "name": {"type": "string"}, + "age": {"type": "integer"}, + "emergency_contact": { + "$ref": "#/$defs/person", + }, + "friends": { + "type": "array", + "items": { + "$ref": "#/$defs/person", + }, + }, + }, + } + }, + } + + model, type_validator = RefTypeParser().from_properties( + "person", + properties, + context=properties, + ref_cache={}, + required=True, + ) + + obj = model( + name="John", + age=30, + emergency_contact=model( + name="Jane", + age=28, + ), + friends=[ + model(name="Alice", age=25), + model(name="Bob", age=26), + ], + ) + + self.assertEqual( + type(obj.emergency_contact), + type(obj.friends[0]), + "Emergency contact and friends should be of the same type", + ) + + def test_ref_type_parser_pre_computed_ref_cache(self): + ref_cache = {} + + parent_properties = { + "$defs": { + "person": { + "type": "object", + "properties": { + "name": {"type": "string"}, + "age": {"type": "integer"}, + }, + } + }, + } + + properties1 = { + "title": "person1", + "$ref": "#/$defs/person", + } + model1, _ = RefTypeParser().from_properties( + "person", + properties1, + context=parent_properties, + ref_cache=ref_cache, + required=True, + ) + + properties2 = { + "title": "person2", + "$ref": "#/$defs/person", + } + model2, _ = RefTypeParser().from_properties( + "person", + properties2, + context=parent_properties, + ref_cache=ref_cache, + required=True, + ) + + self.assertIs(model1, model2, "Models should be the same instance") + + def test_parse_from_strategy_invalid_ref_strategy(self): + properties = { + "title": "person", + "$ref": "#/$defs/person", + "$defs": { + "person": { + "type": "object", + "properties": { + "name": {"type": "string"}, + "age": {"type": "integer"}, + }, + } + }, + } + + with self.assertRaises(ValueError): + ref_strategy, ref_name, ref_property = RefTypeParser()._parse_from_strategy( + "invalid_strategy", + "person", + properties, + ) + + def test_parse_from_strategy_forward_ref(self): + properties = { + "title": "person", + "$ref": "#/$defs/person", + "$defs": { + "person": { + "type": "object", + "properties": { + "name": {"type": "string"}, + "age": {"type": "integer"}, + }, + } + }, + } + + parsed_type = RefTypeParser()._parse_from_strategy( + "forward_ref", + "person", + properties, + ) + + self.assertIsInstance(parsed_type, ForwardRef) + + def test_parse_from_strategy_def_ref(self): + properties = { + "title": "person", + "$ref": "#/$defs/person", + "$defs": { + "person": { + "type": "object", + "properties": { + "name": {"type": "string"}, + "age": {"type": "integer"}, + }, + } + }, + } + + parsed_type = RefTypeParser()._parse_from_strategy( + "def_ref", + "person", + properties, + context=properties, + ref_cache={}, + required=True, + ) + + obj = parsed_type( + name="John", + age=30, + ) + + self.assertEqual(obj.name, "John") + self.assertEqual(obj.age, 30)