Implements: allOf, anyOf #11

Merged
HideyoshiNakazone merged 15 commits from any-all-ref-implementation into main 2025-04-19 20:32:58 +00:00
17 changed files with 133 additions and 90 deletions
Showing only changes of commit 5c3d3a39ba - Show all commits

View File

@@ -1,32 +1,36 @@
from pydantic import Field
from typing_extensions import Self
from pydantic import Field, TypeAdapter
from typing_extensions import Annotated, Self
from abc import ABC, abstractmethod
from typing import Generic, TypeVar
from typing import Generic, Type, TypeVar
T = TypeVar("T")
class GenericTypeParser(ABC, Generic[T]):
@property
@abstractmethod
def mapped_type(self) -> type[T]: ...
mapped_type: Type[T] = None
@property
@abstractmethod
def json_schema_type(self) -> str: ...
json_schema_type: str = None
@staticmethod
@abstractmethod
def from_properties(
name: str, properties: dict[str, any]
) -> tuple[type[T], Field]: ...
name: str, properties: dict[str, any], required: bool = False
) -> tuple[T, dict]: ...
@classmethod
def get_impl(cls, type_name: str) -> Self:
for subcls in cls.__subclasses__():
if subcls.json_schema_type is None:
raise RuntimeError(f"Unknown type: {type_name}")
if subcls.json_schema_type == type_name:
return subcls
return subcls()
raise ValueError(f"Unknown type: {type_name}")
@staticmethod
def validate_default(field_type: type, field_prop: dict, value):
field = Annotated[field_type, Field(**field_prop)]
TypeAdapter(field).validate_python(value)

View File

@@ -7,7 +7,7 @@ class AllOfTypeParser(GenericTypeParser):
json_schema_type = "allOf"
@staticmethod
def from_properties(name, properties):
def from_properties(name, properties, required=False):
subProperties = properties.get("allOf")
if not subProperties:
raise ValueError("Invalid JSON Schema: 'allOf' is not specified.")

View File

@@ -1,5 +1,8 @@
from jambo.parser._type_parser import GenericTypeParser
from pydantic import Field
from typing_extensions import Annotated
from typing import Union
@@ -9,20 +12,47 @@ class AnyOfTypeParser(GenericTypeParser):
json_schema_type = "anyOf"
@staticmethod
def from_properties(name, properties):
def from_properties(name, properties, required=False):
if "anyOf" not in properties:
raise ValueError(f"Invalid JSON Schema: {properties}")
if not isinstance(properties["anyOf"], list):
raise ValueError(f"Invalid JSON Schema: {properties['anyOf']}")
mapped_properties = dict()
subProperties = properties["anyOf"]
types = [
sub_types = [
GenericTypeParser.get_impl(subProperty["type"]).from_properties(
name, subProperty
)
for subProperty in subProperties
]
return Union[*(t for t, v in types)], {}
default_value = properties.get("default")
if default_value is not None:
for sub_type, sub_property in sub_types:
try:
GenericTypeParser.validate_default(
sub_type, sub_property, default_value
)
break
except ValueError:
continue
else:
raise ValueError(
f"Invalid default value {default_value} for anyOf types: {sub_types}"
)
mapped_properties["default"] = default_value
if not required:
mapped_properties["default"] = mapped_properties.get("default")
# By defining the type as Union, we can use the Field validator to enforce
# the constraints on the union type.
# We use Annotated to attach the Field validators to the type.
field_types = [Annotated[t, Field(**v)] if v else t for t, v in sub_types]
return Union[(*field_types,)], mapped_properties

View File

@@ -15,11 +15,11 @@ class ArrayTypeParser(GenericTypeParser):
json_schema_type = "array"
@classmethod
def from_properties(cls, name, properties):
@staticmethod
def from_properties(name, properties, required=False):
_item_type, _item_args = GenericTypeParser.get_impl(
properties["items"]["type"]
).from_properties(name, properties["items"])
).from_properties(name, properties["items"], required=True)
_mappings = {
"maxItems": "max_length",
@@ -29,11 +29,14 @@ class ArrayTypeParser(GenericTypeParser):
wrapper_type = set if properties.get("uniqueItems", False) else list
mapped_properties = mappings_properties_builder(
properties, _mappings, {"description": "description"}
properties,
_mappings,
required=required,
default_mappings={"description": "description"},
)
if "default" in properties:
default_list = properties["default"]
default_list = properties.get("default")
if default_list is not None:
if not isinstance(default_list, list):
raise ValueError(
f"Default value must be a list, got {type(default_list).__name__}"
@@ -63,4 +66,7 @@ class ArrayTypeParser(GenericTypeParser):
default_list
)
if "default_factory" in mapped_properties and "default" in mapped_properties:
del mapped_properties["default"]
return wrapper_type[_item_type], mapped_properties

View File

@@ -10,7 +10,7 @@ class BooleanTypeParser(GenericTypeParser):
json_schema_type = "boolean"
@staticmethod
def from_properties(name, properties):
def from_properties(name, properties, required=False):
_mappings = {
"default": "default",
}

View File

@@ -10,5 +10,5 @@ class FloatTypeParser(GenericTypeParser):
json_schema_type = "number"
@staticmethod
def from_properties(name, properties):
return float, numeric_properties_builder(properties)
def from_properties(name, properties, required=False):
return float, numeric_properties_builder(properties, required)

View File

@@ -10,5 +10,5 @@ class IntTypeParser(GenericTypeParser):
json_schema_type = "integer"
@staticmethod
def from_properties(name, properties):
return int, numeric_properties_builder(properties)
def from_properties(name, properties, required=False):
return int, numeric_properties_builder(properties, required)

View File

@@ -7,7 +7,7 @@ class ObjectTypeParser(GenericTypeParser):
json_schema_type = "object"
@staticmethod
def from_properties(name, properties):
def from_properties(name, properties, required=False):
from jambo.schema_converter import SchemaConverter
type_parsing = SchemaConverter.build_object(name, properties)

View File

@@ -10,31 +10,17 @@ class StringTypeParser(GenericTypeParser):
json_schema_type = "string"
@staticmethod
def from_properties(name, properties):
def from_properties(name, properties, required=False):
_mappings = {
"maxLength": "max_length",
"minLength": "min_length",
"pattern": "pattern",
}
mapped_properties = mappings_properties_builder(properties, _mappings)
mapped_properties = mappings_properties_builder(properties, _mappings, required)
if "default" in properties:
default_value = properties["default"]
if not isinstance(default_value, str):
raise ValueError(
f"Default value for {name} must be a string, "
f"but got <{type(properties['default']).__name__}>."
)
if len(default_value) > properties.get("maxLength", float("inf")):
raise ValueError(
f"Default value for {name} exceeds maxLength limit of {properties.get('maxLength')}"
)
if len(default_value) < properties.get("minLength", 0):
raise ValueError(
f"Default value for {name} is below minLength limit of {properties.get('minLength')}"
)
default_value = properties.get("default")
if default_value is not None:
StringTypeParser.validate_default(str, mapped_properties, default_value)
return str, mapped_properties

View File

@@ -71,14 +71,13 @@ class SchemaConverter:
fields = {}
for name, prop in properties.items():
fields[name] = SchemaConverter._build_field(name, prop, required_keys)
is_required = name in required_keys
fields[name] = SchemaConverter._build_field(name, prop, is_required)
return fields
@staticmethod
def _build_field(
name, properties: dict, required_keys: list[str]
) -> tuple[type, dict]:
def _build_field(name, properties: dict, required=False) -> tuple[type, dict]:
match properties:
case {"anyOf": _}:
_field_type = "anyOf"
@@ -91,17 +90,6 @@ class SchemaConverter:
_field_type, _field_args = GenericTypeParser.get_impl(
_field_type
).from_properties(name, properties)
_field_args = _field_args or {}
if description := properties.get("description"):
_field_args["description"] = description
if name not in required_keys:
_field_args["default"] = properties.get("default", None)
if "default_factory" in _field_args and "default" in _field_args:
del _field_args["default"]
).from_properties(name, properties, required)
return _field_type, Field(**_field_args)

View File

@@ -1,4 +1,9 @@
def mappings_properties_builder(properties, mappings, default_mappings=None):
def mappings_properties_builder(
properties, mappings, required=False, default_mappings=None
):
if not required:
properties["default"] = properties.get("default", None)
default_mappings = default_mappings or {
"default": "default",
"description": "description",

View File

@@ -3,7 +3,7 @@ from jambo.utils.properties_builder.mappings_properties_builder import (
)
def numeric_properties_builder(properties):
def numeric_properties_builder(properties, required=False):
_mappings = {
"minimum": "ge",
"exclusiveMinimum": "gt",
@@ -13,9 +13,10 @@ def numeric_properties_builder(properties):
"default": "default",
}
mapped_properties = mappings_properties_builder(properties, _mappings)
mapped_properties = mappings_properties_builder(properties, _mappings, required)
if "default" in properties:
default_value = properties.get("default")
if default_value is not None:
default_value = properties["default"]
if not isinstance(default_value, (int, float)):
raise ValueError(

View File

@@ -1,5 +1,7 @@
from jambo.parser.anyof_type_parser import AnyOfTypeParser
from typing_extensions import Annotated
from typing import Union, get_args, get_origin
from unittest import TestCase
@@ -21,5 +23,41 @@ class TestAnyOfTypeParser(TestCase):
# check union type has string and int
self.assertEqual(get_origin(type_parsing), Union)
self.assertIn(str, get_args(type_parsing))
self.assertIn(int, get_args(type_parsing))
type_1, type_2 = get_args(type_parsing)
self.assertEqual(get_origin(type_1), Annotated)
self.assertIn(str, get_args(type_1))
self.assertEqual(get_origin(type_2), Annotated)
self.assertIn(int, get_args(type_2))
def test_any_of_string_or_int_with_default(self):
"""
Tests the AnyOfTypeParser with a string or int type and a default value.
"""
properties = {
"anyOf": [
{"type": "string"},
{"type": "integer"},
],
"default": 42,
}
type_parsing, type_validator = AnyOfTypeParser.from_properties(
"placeholder", properties
)
# check union type has string and int
self.assertEqual(get_origin(type_parsing), Union)
type_1, type_2 = get_args(type_parsing)
self.assertEqual(get_origin(type_1), Annotated)
self.assertIn(str, get_args(type_1))
self.assertEqual(get_origin(type_2), Annotated)
self.assertIn(int, get_args(type_2))
self.assertEqual(type_validator["default"], 42)

View File

@@ -12,7 +12,7 @@ class TestBoolTypeParser(TestCase):
type_parsing, type_validator = parser.from_properties("placeholder", properties)
self.assertEqual(type_parsing, bool)
self.assertEqual(type_validator, {})
self.assertEqual(type_validator, {"default": None})
def test_bool_parser_with_default(self):
parser = BooleanTypeParser()

View File

@@ -12,7 +12,7 @@ class TestFloatTypeParser(TestCase):
type_parsing, type_validator = parser.from_properties("placeholder", properties)
self.assertEqual(type_parsing, float)
self.assertEqual(type_validator, {})
self.assertEqual(type_validator, {"default": None})
def test_float_parser_with_options(self):
parser = FloatTypeParser()

View File

@@ -12,7 +12,7 @@ class TestIntTypeParser(TestCase):
type_parsing, type_validator = parser.from_properties("placeholder", properties)
self.assertEqual(type_parsing, int)
self.assertEqual(type_validator, {})
self.assertEqual(type_validator, {"default": None})
def test_int_parser_with_options(self):
parser = IntTypeParser()

View File

@@ -57,14 +57,9 @@ class TestStringTypeParser(TestCase):
"minLength": 5,
}
with self.assertRaises(ValueError) as context:
with self.assertRaises(ValueError):
parser.from_properties("placeholder", properties)
self.assertEqual(
str(context.exception),
"Default value for placeholder must be a string, but got <int>.",
)
def test_string_parser_with_default_invalid_maxlength(self):
parser = StringTypeParser()
@@ -75,14 +70,9 @@ class TestStringTypeParser(TestCase):
"minLength": 1,
}
with self.assertRaises(ValueError) as context:
with self.assertRaises(ValueError):
parser.from_properties("placeholder", properties)
self.assertEqual(
str(context.exception),
"Default value for placeholder exceeds maxLength limit of 2",
)
def test_string_parser_with_default_invalid_minlength(self):
parser = StringTypeParser()
@@ -93,10 +83,5 @@ class TestStringTypeParser(TestCase):
"minLength": 2,
}
with self.assertRaises(ValueError) as context:
with self.assertRaises(ValueError):
parser.from_properties("placeholder", properties)
self.assertEqual(
str(context.exception),
"Default value for placeholder is below minLength limit of 2",
)