feature: add instance level ref cache #63

Merged
HideyoshiNakazone merged 12 commits from feature/add-instance-level-ref-cache into main 2025-11-25 00:07:55 +00:00
3 changed files with 41 additions and 15 deletions
Showing only changes of commit abc8bc2e40 - Show all commits

View File

@@ -1,3 +1,4 @@
from jambo.exceptions import InternalAssertionException
from jambo.parser._type_parser import GenericTypeParser from jambo.parser._type_parser import GenericTypeParser
from jambo.types.json_schema_type import JSONSchema from jambo.types.json_schema_type import JSONSchema
from jambo.types.type_parser_options import TypeParserOptions from jambo.types.type_parser_options import TypeParserOptions
@@ -6,6 +7,8 @@ from pydantic import BaseModel, ConfigDict, Field, create_model
from pydantic.fields import FieldInfo from pydantic.fields import FieldInfo
from typing_extensions import Unpack from typing_extensions import Unpack
import warnings
class ObjectTypeParser(GenericTypeParser): class ObjectTypeParser(GenericTypeParser):
mapped_type = object mapped_type = object
@@ -15,6 +18,12 @@ class ObjectTypeParser(GenericTypeParser):
def from_properties_impl( def from_properties_impl(
self, name: str, properties: JSONSchema, **kwargs: Unpack[TypeParserOptions] self, name: str, properties: JSONSchema, **kwargs: Unpack[TypeParserOptions]
) -> tuple[type[BaseModel], dict]: ) -> tuple[type[BaseModel], dict]:
ref_cache = kwargs.get("ref_cache")
if ref_cache is None:
raise InternalAssertionException(
"`ref_cache` must be provided in kwargs for RefTypeParser"
)
type_parsing = self.to_model( type_parsing = self.to_model(
name, name,
properties.get("properties", {}), properties.get("properties", {}),
@@ -37,6 +46,13 @@ class ObjectTypeParser(GenericTypeParser):
type_parsing.model_validate(example) for example in example_values type_parsing.model_validate(example) for example in example_values
] ]
if name in ref_cache:
warnings.warn(
f"Type '{name}' is already in the ref_cache and will be overwritten.",
UserWarning,
)
ref_cache[name] = type_parsing
return type_parsing, type_properties return type_parsing, type_properties
@classmethod @classmethod

View File

@@ -42,7 +42,7 @@ class TestAllOfTypeParser(TestCase):
} }
type_parsing, type_validator = AllOfTypeParser().from_properties( type_parsing, type_validator = AllOfTypeParser().from_properties(
"placeholder", properties "placeholder", properties, ref_cache={}
) )
with self.assertRaises(ValidationError): with self.assertRaises(ValidationError):
@@ -87,7 +87,7 @@ class TestAllOfTypeParser(TestCase):
} }
type_parsing, type_validator = AllOfTypeParser().from_properties( type_parsing, type_validator = AllOfTypeParser().from_properties(
"placeholder", properties "placeholder", properties, ref_cache={}
) )
with self.assertRaises(ValidationError): with self.assertRaises(ValidationError):
@@ -116,7 +116,7 @@ class TestAllOfTypeParser(TestCase):
} }
type_parsing, type_validator = AllOfTypeParser().from_properties( type_parsing, type_validator = AllOfTypeParser().from_properties(
"placeholder", properties "placeholder", properties, ref_cache={}
) )
self.assertEqual(type_parsing, str) self.assertEqual(type_parsing, str)
@@ -137,7 +137,7 @@ class TestAllOfTypeParser(TestCase):
} }
type_parsing, type_validator = AllOfTypeParser().from_properties( type_parsing, type_validator = AllOfTypeParser().from_properties(
"placeholder", properties "placeholder", properties, ref_cache={}
) )
self.assertEqual(type_parsing, str) self.assertEqual(type_parsing, str)
@@ -158,7 +158,7 @@ class TestAllOfTypeParser(TestCase):
} }
with self.assertRaises(InvalidSchemaException): with self.assertRaises(InvalidSchemaException):
AllOfTypeParser().from_properties("placeholder", properties) AllOfTypeParser().from_properties("placeholder", properties, ref_cache={})
def test_all_of_invalid_type_not_present(self): def test_all_of_invalid_type_not_present(self):
properties = { properties = {
@@ -171,7 +171,7 @@ class TestAllOfTypeParser(TestCase):
} }
with self.assertRaises(InvalidSchemaException): with self.assertRaises(InvalidSchemaException):
AllOfTypeParser().from_properties("placeholder", properties) AllOfTypeParser().from_properties("placeholder", properties, ref_cache={})
def test_all_of_invalid_type_in_fields(self): def test_all_of_invalid_type_in_fields(self):
properties = { properties = {
@@ -184,7 +184,7 @@ class TestAllOfTypeParser(TestCase):
} }
with self.assertRaises(InvalidSchemaException): with self.assertRaises(InvalidSchemaException):
AllOfTypeParser().from_properties("placeholder", properties) AllOfTypeParser().from_properties("placeholder", properties, ref_cache={})
def test_all_of_invalid_type_not_all_equal(self): def test_all_of_invalid_type_not_all_equal(self):
""" """
@@ -200,7 +200,7 @@ class TestAllOfTypeParser(TestCase):
} }
with self.assertRaises(InvalidSchemaException): with self.assertRaises(InvalidSchemaException):
AllOfTypeParser().from_properties("placeholder", properties) AllOfTypeParser().from_properties("placeholder", properties, ref_cache={})
def test_all_of_description_field(self): def test_all_of_description_field(self):
""" """
@@ -237,7 +237,9 @@ class TestAllOfTypeParser(TestCase):
], ],
} }
type_parsing, _ = AllOfTypeParser().from_properties("placeholder", properties) type_parsing, _ = AllOfTypeParser().from_properties(
"placeholder", properties, ref_cache={}
)
self.assertEqual( self.assertEqual(
type_parsing.model_json_schema()["properties"]["name"]["description"], type_parsing.model_json_schema()["properties"]["name"]["description"],
@@ -275,7 +277,9 @@ class TestAllOfTypeParser(TestCase):
], ],
} }
type_parsing, _ = AllOfTypeParser().from_properties("placeholder", properties) type_parsing, _ = AllOfTypeParser().from_properties(
"placeholder", properties, ref_cache={}
)
obj = type_parsing() obj = type_parsing()
self.assertEqual(obj.name, "John") self.assertEqual(obj.name, "John")
self.assertEqual(obj.age, 30) self.assertEqual(obj.age, 30)
@@ -308,7 +312,7 @@ class TestAllOfTypeParser(TestCase):
} }
with self.assertRaises(InvalidSchemaException): with self.assertRaises(InvalidSchemaException):
AllOfTypeParser().from_properties("placeholder", properties) AllOfTypeParser().from_properties("placeholder", properties, ref_cache={})
def test_all_of_with_root_examples(self): def test_all_of_with_root_examples(self):
""" """
@@ -344,7 +348,7 @@ class TestAllOfTypeParser(TestCase):
} }
type_parsed, type_properties = AllOfTypeParser().from_properties( type_parsed, type_properties = AllOfTypeParser().from_properties(
"placeholder", properties "placeholder", properties, ref_cache={}
) )
self.assertEqual( self.assertEqual(

View File

@@ -15,7 +15,9 @@ class TestObjectTypeParser(TestCase):
}, },
} }
Model, _args = parser.from_properties_impl("placeholder", properties) Model, _args = parser.from_properties_impl(
"placeholder", properties, ref_cache={}
)
obj = Model(name="name", age=10) obj = Model(name="name", age=10)
@@ -39,7 +41,9 @@ class TestObjectTypeParser(TestCase):
], ],
} }
_, type_validator = parser.from_properties_impl("placeholder", properties) _, type_validator = parser.from_properties_impl(
"placeholder", properties, ref_cache={}
)
test_example = type_validator["examples"][0] test_example = type_validator["examples"][0]
@@ -61,7 +65,9 @@ class TestObjectTypeParser(TestCase):
}, },
} }
_, type_validator = parser.from_properties_impl("placeholder", properties) _, type_validator = parser.from_properties_impl(
"placeholder", properties, ref_cache={}
)
# Check default value # Check default value
default_obj = type_validator["default_factory"]() default_obj = type_validator["default_factory"]()