Implements Feature Complete AnyOf Keyword
This commit is contained in:
@@ -1,32 +1,36 @@
|
||||
from pydantic import Field
|
||||
from typing_extensions import Self
|
||||
from pydantic import Field, TypeAdapter
|
||||
from typing_extensions import Annotated, Self
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Generic, TypeVar
|
||||
from typing import Generic, Type, TypeVar
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class GenericTypeParser(ABC, Generic[T]):
|
||||
@property
|
||||
@abstractmethod
|
||||
def mapped_type(self) -> type[T]: ...
|
||||
mapped_type: Type[T] = None
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def json_schema_type(self) -> str: ...
|
||||
json_schema_type: str = None
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def from_properties(
|
||||
name: str, properties: dict[str, any]
|
||||
) -> tuple[type[T], Field]: ...
|
||||
name: str, properties: dict[str, any], required: bool = False
|
||||
) -> tuple[T, dict]: ...
|
||||
|
||||
@classmethod
|
||||
def get_impl(cls, type_name: str) -> Self:
|
||||
for subcls in cls.__subclasses__():
|
||||
if subcls.json_schema_type is None:
|
||||
raise RuntimeError(f"Unknown type: {type_name}")
|
||||
|
||||
if subcls.json_schema_type == type_name:
|
||||
return subcls
|
||||
return subcls()
|
||||
|
||||
raise ValueError(f"Unknown type: {type_name}")
|
||||
|
||||
@staticmethod
|
||||
def validate_default(field_type: type, field_prop: dict, value):
|
||||
field = Annotated[field_type, Field(**field_prop)]
|
||||
TypeAdapter(field).validate_python(value)
|
||||
|
||||
@@ -7,7 +7,7 @@ class AllOfTypeParser(GenericTypeParser):
|
||||
json_schema_type = "allOf"
|
||||
|
||||
@staticmethod
|
||||
def from_properties(name, properties):
|
||||
def from_properties(name, properties, required=False):
|
||||
subProperties = properties.get("allOf")
|
||||
if not subProperties:
|
||||
raise ValueError("Invalid JSON Schema: 'allOf' is not specified.")
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
from jambo.parser._type_parser import GenericTypeParser
|
||||
|
||||
from pydantic import Field
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from typing import Union
|
||||
|
||||
|
||||
@@ -9,20 +12,47 @@ class AnyOfTypeParser(GenericTypeParser):
|
||||
json_schema_type = "anyOf"
|
||||
|
||||
@staticmethod
|
||||
def from_properties(name, properties):
|
||||
def from_properties(name, properties, required=False):
|
||||
if "anyOf" not in properties:
|
||||
raise ValueError(f"Invalid JSON Schema: {properties}")
|
||||
|
||||
if not isinstance(properties["anyOf"], list):
|
||||
raise ValueError(f"Invalid JSON Schema: {properties['anyOf']}")
|
||||
|
||||
mapped_properties = dict()
|
||||
|
||||
subProperties = properties["anyOf"]
|
||||
|
||||
types = [
|
||||
sub_types = [
|
||||
GenericTypeParser.get_impl(subProperty["type"]).from_properties(
|
||||
name, subProperty
|
||||
)
|
||||
for subProperty in subProperties
|
||||
]
|
||||
|
||||
return Union[*(t for t, v in types)], {}
|
||||
default_value = properties.get("default")
|
||||
if default_value is not None:
|
||||
for sub_type, sub_property in sub_types:
|
||||
try:
|
||||
GenericTypeParser.validate_default(
|
||||
sub_type, sub_property, default_value
|
||||
)
|
||||
break
|
||||
except ValueError:
|
||||
continue
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid default value {default_value} for anyOf types: {sub_types}"
|
||||
)
|
||||
|
||||
mapped_properties["default"] = default_value
|
||||
|
||||
if not required:
|
||||
mapped_properties["default"] = mapped_properties.get("default")
|
||||
|
||||
# By defining the type as Union, we can use the Field validator to enforce
|
||||
# the constraints on the union type.
|
||||
# We use Annotated to attach the Field validators to the type.
|
||||
field_types = [Annotated[t, Field(**v)] if v else t for t, v in sub_types]
|
||||
|
||||
return Union[(*field_types,)], mapped_properties
|
||||
|
||||
@@ -15,11 +15,11 @@ class ArrayTypeParser(GenericTypeParser):
|
||||
|
||||
json_schema_type = "array"
|
||||
|
||||
@classmethod
|
||||
def from_properties(cls, name, properties):
|
||||
@staticmethod
|
||||
def from_properties(name, properties, required=False):
|
||||
_item_type, _item_args = GenericTypeParser.get_impl(
|
||||
properties["items"]["type"]
|
||||
).from_properties(name, properties["items"])
|
||||
).from_properties(name, properties["items"], required=True)
|
||||
|
||||
_mappings = {
|
||||
"maxItems": "max_length",
|
||||
@@ -29,11 +29,14 @@ class ArrayTypeParser(GenericTypeParser):
|
||||
wrapper_type = set if properties.get("uniqueItems", False) else list
|
||||
|
||||
mapped_properties = mappings_properties_builder(
|
||||
properties, _mappings, {"description": "description"}
|
||||
properties,
|
||||
_mappings,
|
||||
required=required,
|
||||
default_mappings={"description": "description"},
|
||||
)
|
||||
|
||||
if "default" in properties:
|
||||
default_list = properties["default"]
|
||||
default_list = properties.get("default")
|
||||
if default_list is not None:
|
||||
if not isinstance(default_list, list):
|
||||
raise ValueError(
|
||||
f"Default value must be a list, got {type(default_list).__name__}"
|
||||
@@ -63,4 +66,7 @@ class ArrayTypeParser(GenericTypeParser):
|
||||
default_list
|
||||
)
|
||||
|
||||
if "default_factory" in mapped_properties and "default" in mapped_properties:
|
||||
del mapped_properties["default"]
|
||||
|
||||
return wrapper_type[_item_type], mapped_properties
|
||||
|
||||
@@ -10,7 +10,7 @@ class BooleanTypeParser(GenericTypeParser):
|
||||
json_schema_type = "boolean"
|
||||
|
||||
@staticmethod
|
||||
def from_properties(name, properties):
|
||||
def from_properties(name, properties, required=False):
|
||||
_mappings = {
|
||||
"default": "default",
|
||||
}
|
||||
|
||||
@@ -10,5 +10,5 @@ class FloatTypeParser(GenericTypeParser):
|
||||
json_schema_type = "number"
|
||||
|
||||
@staticmethod
|
||||
def from_properties(name, properties):
|
||||
return float, numeric_properties_builder(properties)
|
||||
def from_properties(name, properties, required=False):
|
||||
return float, numeric_properties_builder(properties, required)
|
||||
|
||||
@@ -10,5 +10,5 @@ class IntTypeParser(GenericTypeParser):
|
||||
json_schema_type = "integer"
|
||||
|
||||
@staticmethod
|
||||
def from_properties(name, properties):
|
||||
return int, numeric_properties_builder(properties)
|
||||
def from_properties(name, properties, required=False):
|
||||
return int, numeric_properties_builder(properties, required)
|
||||
|
||||
@@ -7,7 +7,7 @@ class ObjectTypeParser(GenericTypeParser):
|
||||
json_schema_type = "object"
|
||||
|
||||
@staticmethod
|
||||
def from_properties(name, properties):
|
||||
def from_properties(name, properties, required=False):
|
||||
from jambo.schema_converter import SchemaConverter
|
||||
|
||||
type_parsing = SchemaConverter.build_object(name, properties)
|
||||
|
||||
@@ -10,31 +10,17 @@ class StringTypeParser(GenericTypeParser):
|
||||
json_schema_type = "string"
|
||||
|
||||
@staticmethod
|
||||
def from_properties(name, properties):
|
||||
def from_properties(name, properties, required=False):
|
||||
_mappings = {
|
||||
"maxLength": "max_length",
|
||||
"minLength": "min_length",
|
||||
"pattern": "pattern",
|
||||
}
|
||||
|
||||
mapped_properties = mappings_properties_builder(properties, _mappings)
|
||||
mapped_properties = mappings_properties_builder(properties, _mappings, required)
|
||||
|
||||
if "default" in properties:
|
||||
default_value = properties["default"]
|
||||
if not isinstance(default_value, str):
|
||||
raise ValueError(
|
||||
f"Default value for {name} must be a string, "
|
||||
f"but got <{type(properties['default']).__name__}>."
|
||||
)
|
||||
|
||||
if len(default_value) > properties.get("maxLength", float("inf")):
|
||||
raise ValueError(
|
||||
f"Default value for {name} exceeds maxLength limit of {properties.get('maxLength')}"
|
||||
)
|
||||
|
||||
if len(default_value) < properties.get("minLength", 0):
|
||||
raise ValueError(
|
||||
f"Default value for {name} is below minLength limit of {properties.get('minLength')}"
|
||||
)
|
||||
default_value = properties.get("default")
|
||||
if default_value is not None:
|
||||
StringTypeParser.validate_default(str, mapped_properties, default_value)
|
||||
|
||||
return str, mapped_properties
|
||||
|
||||
@@ -71,14 +71,13 @@ class SchemaConverter:
|
||||
|
||||
fields = {}
|
||||
for name, prop in properties.items():
|
||||
fields[name] = SchemaConverter._build_field(name, prop, required_keys)
|
||||
is_required = name in required_keys
|
||||
fields[name] = SchemaConverter._build_field(name, prop, is_required)
|
||||
|
||||
return fields
|
||||
|
||||
@staticmethod
|
||||
def _build_field(
|
||||
name, properties: dict, required_keys: list[str]
|
||||
) -> tuple[type, dict]:
|
||||
def _build_field(name, properties: dict, required=False) -> tuple[type, dict]:
|
||||
match properties:
|
||||
case {"anyOf": _}:
|
||||
_field_type = "anyOf"
|
||||
@@ -91,17 +90,6 @@ class SchemaConverter:
|
||||
|
||||
_field_type, _field_args = GenericTypeParser.get_impl(
|
||||
_field_type
|
||||
).from_properties(name, properties)
|
||||
|
||||
_field_args = _field_args or {}
|
||||
|
||||
if description := properties.get("description"):
|
||||
_field_args["description"] = description
|
||||
|
||||
if name not in required_keys:
|
||||
_field_args["default"] = properties.get("default", None)
|
||||
|
||||
if "default_factory" in _field_args and "default" in _field_args:
|
||||
del _field_args["default"]
|
||||
).from_properties(name, properties, required)
|
||||
|
||||
return _field_type, Field(**_field_args)
|
||||
|
||||
@@ -1,4 +1,9 @@
|
||||
def mappings_properties_builder(properties, mappings, default_mappings=None):
|
||||
def mappings_properties_builder(
|
||||
properties, mappings, required=False, default_mappings=None
|
||||
):
|
||||
if not required:
|
||||
properties["default"] = properties.get("default", None)
|
||||
|
||||
default_mappings = default_mappings or {
|
||||
"default": "default",
|
||||
"description": "description",
|
||||
|
||||
@@ -3,7 +3,7 @@ from jambo.utils.properties_builder.mappings_properties_builder import (
|
||||
)
|
||||
|
||||
|
||||
def numeric_properties_builder(properties):
|
||||
def numeric_properties_builder(properties, required=False):
|
||||
_mappings = {
|
||||
"minimum": "ge",
|
||||
"exclusiveMinimum": "gt",
|
||||
@@ -13,9 +13,10 @@ def numeric_properties_builder(properties):
|
||||
"default": "default",
|
||||
}
|
||||
|
||||
mapped_properties = mappings_properties_builder(properties, _mappings)
|
||||
mapped_properties = mappings_properties_builder(properties, _mappings, required)
|
||||
|
||||
if "default" in properties:
|
||||
default_value = properties.get("default")
|
||||
if default_value is not None:
|
||||
default_value = properties["default"]
|
||||
if not isinstance(default_value, (int, float)):
|
||||
raise ValueError(
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
from jambo.parser.anyof_type_parser import AnyOfTypeParser
|
||||
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from typing import Union, get_args, get_origin
|
||||
from unittest import TestCase
|
||||
|
||||
@@ -21,5 +23,41 @@ class TestAnyOfTypeParser(TestCase):
|
||||
|
||||
# check union type has string and int
|
||||
self.assertEqual(get_origin(type_parsing), Union)
|
||||
self.assertIn(str, get_args(type_parsing))
|
||||
self.assertIn(int, get_args(type_parsing))
|
||||
|
||||
type_1, type_2 = get_args(type_parsing)
|
||||
|
||||
self.assertEqual(get_origin(type_1), Annotated)
|
||||
self.assertIn(str, get_args(type_1))
|
||||
|
||||
self.assertEqual(get_origin(type_2), Annotated)
|
||||
self.assertIn(int, get_args(type_2))
|
||||
|
||||
def test_any_of_string_or_int_with_default(self):
|
||||
"""
|
||||
Tests the AnyOfTypeParser with a string or int type and a default value.
|
||||
"""
|
||||
|
||||
properties = {
|
||||
"anyOf": [
|
||||
{"type": "string"},
|
||||
{"type": "integer"},
|
||||
],
|
||||
"default": 42,
|
||||
}
|
||||
|
||||
type_parsing, type_validator = AnyOfTypeParser.from_properties(
|
||||
"placeholder", properties
|
||||
)
|
||||
|
||||
# check union type has string and int
|
||||
self.assertEqual(get_origin(type_parsing), Union)
|
||||
|
||||
type_1, type_2 = get_args(type_parsing)
|
||||
|
||||
self.assertEqual(get_origin(type_1), Annotated)
|
||||
self.assertIn(str, get_args(type_1))
|
||||
|
||||
self.assertEqual(get_origin(type_2), Annotated)
|
||||
self.assertIn(int, get_args(type_2))
|
||||
|
||||
self.assertEqual(type_validator["default"], 42)
|
||||
|
||||
@@ -12,7 +12,7 @@ class TestBoolTypeParser(TestCase):
|
||||
type_parsing, type_validator = parser.from_properties("placeholder", properties)
|
||||
|
||||
self.assertEqual(type_parsing, bool)
|
||||
self.assertEqual(type_validator, {})
|
||||
self.assertEqual(type_validator, {"default": None})
|
||||
|
||||
def test_bool_parser_with_default(self):
|
||||
parser = BooleanTypeParser()
|
||||
|
||||
@@ -12,7 +12,7 @@ class TestFloatTypeParser(TestCase):
|
||||
type_parsing, type_validator = parser.from_properties("placeholder", properties)
|
||||
|
||||
self.assertEqual(type_parsing, float)
|
||||
self.assertEqual(type_validator, {})
|
||||
self.assertEqual(type_validator, {"default": None})
|
||||
|
||||
def test_float_parser_with_options(self):
|
||||
parser = FloatTypeParser()
|
||||
|
||||
@@ -12,7 +12,7 @@ class TestIntTypeParser(TestCase):
|
||||
type_parsing, type_validator = parser.from_properties("placeholder", properties)
|
||||
|
||||
self.assertEqual(type_parsing, int)
|
||||
self.assertEqual(type_validator, {})
|
||||
self.assertEqual(type_validator, {"default": None})
|
||||
|
||||
def test_int_parser_with_options(self):
|
||||
parser = IntTypeParser()
|
||||
|
||||
@@ -57,14 +57,9 @@ class TestStringTypeParser(TestCase):
|
||||
"minLength": 5,
|
||||
}
|
||||
|
||||
with self.assertRaises(ValueError) as context:
|
||||
with self.assertRaises(ValueError):
|
||||
parser.from_properties("placeholder", properties)
|
||||
|
||||
self.assertEqual(
|
||||
str(context.exception),
|
||||
"Default value for placeholder must be a string, but got <int>.",
|
||||
)
|
||||
|
||||
def test_string_parser_with_default_invalid_maxlength(self):
|
||||
parser = StringTypeParser()
|
||||
|
||||
@@ -75,14 +70,9 @@ class TestStringTypeParser(TestCase):
|
||||
"minLength": 1,
|
||||
}
|
||||
|
||||
with self.assertRaises(ValueError) as context:
|
||||
with self.assertRaises(ValueError):
|
||||
parser.from_properties("placeholder", properties)
|
||||
|
||||
self.assertEqual(
|
||||
str(context.exception),
|
||||
"Default value for placeholder exceeds maxLength limit of 2",
|
||||
)
|
||||
|
||||
def test_string_parser_with_default_invalid_minlength(self):
|
||||
parser = StringTypeParser()
|
||||
|
||||
@@ -93,10 +83,5 @@ class TestStringTypeParser(TestCase):
|
||||
"minLength": 2,
|
||||
}
|
||||
|
||||
with self.assertRaises(ValueError) as context:
|
||||
with self.assertRaises(ValueError):
|
||||
parser.from_properties("placeholder", properties)
|
||||
|
||||
self.assertEqual(
|
||||
str(context.exception),
|
||||
"Default value for placeholder is below minLength limit of 2",
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user