Implements Feature Complete AnyOf Keyword
This commit is contained in:
@@ -1,32 +1,36 @@
|
|||||||
from pydantic import Field
|
from pydantic import Field, TypeAdapter
|
||||||
from typing_extensions import Self
|
from typing_extensions import Annotated, Self
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Generic, TypeVar
|
from typing import Generic, Type, TypeVar
|
||||||
|
|
||||||
|
|
||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
class GenericTypeParser(ABC, Generic[T]):
|
class GenericTypeParser(ABC, Generic[T]):
|
||||||
@property
|
mapped_type: Type[T] = None
|
||||||
@abstractmethod
|
|
||||||
def mapped_type(self) -> type[T]: ...
|
|
||||||
|
|
||||||
@property
|
json_schema_type: str = None
|
||||||
@abstractmethod
|
|
||||||
def json_schema_type(self) -> str: ...
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def from_properties(
|
def from_properties(
|
||||||
name: str, properties: dict[str, any]
|
name: str, properties: dict[str, any], required: bool = False
|
||||||
) -> tuple[type[T], Field]: ...
|
) -> tuple[T, dict]: ...
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_impl(cls, type_name: str) -> Self:
|
def get_impl(cls, type_name: str) -> Self:
|
||||||
for subcls in cls.__subclasses__():
|
for subcls in cls.__subclasses__():
|
||||||
|
if subcls.json_schema_type is None:
|
||||||
|
raise RuntimeError(f"Unknown type: {type_name}")
|
||||||
|
|
||||||
if subcls.json_schema_type == type_name:
|
if subcls.json_schema_type == type_name:
|
||||||
return subcls
|
return subcls()
|
||||||
|
|
||||||
raise ValueError(f"Unknown type: {type_name}")
|
raise ValueError(f"Unknown type: {type_name}")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def validate_default(field_type: type, field_prop: dict, value):
|
||||||
|
field = Annotated[field_type, Field(**field_prop)]
|
||||||
|
TypeAdapter(field).validate_python(value)
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ class AllOfTypeParser(GenericTypeParser):
|
|||||||
json_schema_type = "allOf"
|
json_schema_type = "allOf"
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_properties(name, properties):
|
def from_properties(name, properties, required=False):
|
||||||
subProperties = properties.get("allOf")
|
subProperties = properties.get("allOf")
|
||||||
if not subProperties:
|
if not subProperties:
|
||||||
raise ValueError("Invalid JSON Schema: 'allOf' is not specified.")
|
raise ValueError("Invalid JSON Schema: 'allOf' is not specified.")
|
||||||
|
|||||||
@@ -1,5 +1,8 @@
|
|||||||
from jambo.parser._type_parser import GenericTypeParser
|
from jambo.parser._type_parser import GenericTypeParser
|
||||||
|
|
||||||
|
from pydantic import Field
|
||||||
|
from typing_extensions import Annotated
|
||||||
|
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
|
|
||||||
@@ -9,20 +12,47 @@ class AnyOfTypeParser(GenericTypeParser):
|
|||||||
json_schema_type = "anyOf"
|
json_schema_type = "anyOf"
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_properties(name, properties):
|
def from_properties(name, properties, required=False):
|
||||||
if "anyOf" not in properties:
|
if "anyOf" not in properties:
|
||||||
raise ValueError(f"Invalid JSON Schema: {properties}")
|
raise ValueError(f"Invalid JSON Schema: {properties}")
|
||||||
|
|
||||||
if not isinstance(properties["anyOf"], list):
|
if not isinstance(properties["anyOf"], list):
|
||||||
raise ValueError(f"Invalid JSON Schema: {properties['anyOf']}")
|
raise ValueError(f"Invalid JSON Schema: {properties['anyOf']}")
|
||||||
|
|
||||||
|
mapped_properties = dict()
|
||||||
|
|
||||||
subProperties = properties["anyOf"]
|
subProperties = properties["anyOf"]
|
||||||
|
|
||||||
types = [
|
sub_types = [
|
||||||
GenericTypeParser.get_impl(subProperty["type"]).from_properties(
|
GenericTypeParser.get_impl(subProperty["type"]).from_properties(
|
||||||
name, subProperty
|
name, subProperty
|
||||||
)
|
)
|
||||||
for subProperty in subProperties
|
for subProperty in subProperties
|
||||||
]
|
]
|
||||||
|
|
||||||
return Union[*(t for t, v in types)], {}
|
default_value = properties.get("default")
|
||||||
|
if default_value is not None:
|
||||||
|
for sub_type, sub_property in sub_types:
|
||||||
|
try:
|
||||||
|
GenericTypeParser.validate_default(
|
||||||
|
sub_type, sub_property, default_value
|
||||||
|
)
|
||||||
|
break
|
||||||
|
except ValueError:
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Invalid default value {default_value} for anyOf types: {sub_types}"
|
||||||
|
)
|
||||||
|
|
||||||
|
mapped_properties["default"] = default_value
|
||||||
|
|
||||||
|
if not required:
|
||||||
|
mapped_properties["default"] = mapped_properties.get("default")
|
||||||
|
|
||||||
|
# By defining the type as Union, we can use the Field validator to enforce
|
||||||
|
# the constraints on the union type.
|
||||||
|
# We use Annotated to attach the Field validators to the type.
|
||||||
|
field_types = [Annotated[t, Field(**v)] if v else t for t, v in sub_types]
|
||||||
|
|
||||||
|
return Union[(*field_types,)], mapped_properties
|
||||||
|
|||||||
@@ -15,11 +15,11 @@ class ArrayTypeParser(GenericTypeParser):
|
|||||||
|
|
||||||
json_schema_type = "array"
|
json_schema_type = "array"
|
||||||
|
|
||||||
@classmethod
|
@staticmethod
|
||||||
def from_properties(cls, name, properties):
|
def from_properties(name, properties, required=False):
|
||||||
_item_type, _item_args = GenericTypeParser.get_impl(
|
_item_type, _item_args = GenericTypeParser.get_impl(
|
||||||
properties["items"]["type"]
|
properties["items"]["type"]
|
||||||
).from_properties(name, properties["items"])
|
).from_properties(name, properties["items"], required=True)
|
||||||
|
|
||||||
_mappings = {
|
_mappings = {
|
||||||
"maxItems": "max_length",
|
"maxItems": "max_length",
|
||||||
@@ -29,11 +29,14 @@ class ArrayTypeParser(GenericTypeParser):
|
|||||||
wrapper_type = set if properties.get("uniqueItems", False) else list
|
wrapper_type = set if properties.get("uniqueItems", False) else list
|
||||||
|
|
||||||
mapped_properties = mappings_properties_builder(
|
mapped_properties = mappings_properties_builder(
|
||||||
properties, _mappings, {"description": "description"}
|
properties,
|
||||||
|
_mappings,
|
||||||
|
required=required,
|
||||||
|
default_mappings={"description": "description"},
|
||||||
)
|
)
|
||||||
|
|
||||||
if "default" in properties:
|
default_list = properties.get("default")
|
||||||
default_list = properties["default"]
|
if default_list is not None:
|
||||||
if not isinstance(default_list, list):
|
if not isinstance(default_list, list):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Default value must be a list, got {type(default_list).__name__}"
|
f"Default value must be a list, got {type(default_list).__name__}"
|
||||||
@@ -63,4 +66,7 @@ class ArrayTypeParser(GenericTypeParser):
|
|||||||
default_list
|
default_list
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if "default_factory" in mapped_properties and "default" in mapped_properties:
|
||||||
|
del mapped_properties["default"]
|
||||||
|
|
||||||
return wrapper_type[_item_type], mapped_properties
|
return wrapper_type[_item_type], mapped_properties
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ class BooleanTypeParser(GenericTypeParser):
|
|||||||
json_schema_type = "boolean"
|
json_schema_type = "boolean"
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_properties(name, properties):
|
def from_properties(name, properties, required=False):
|
||||||
_mappings = {
|
_mappings = {
|
||||||
"default": "default",
|
"default": "default",
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,5 +10,5 @@ class FloatTypeParser(GenericTypeParser):
|
|||||||
json_schema_type = "number"
|
json_schema_type = "number"
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_properties(name, properties):
|
def from_properties(name, properties, required=False):
|
||||||
return float, numeric_properties_builder(properties)
|
return float, numeric_properties_builder(properties, required)
|
||||||
|
|||||||
@@ -10,5 +10,5 @@ class IntTypeParser(GenericTypeParser):
|
|||||||
json_schema_type = "integer"
|
json_schema_type = "integer"
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_properties(name, properties):
|
def from_properties(name, properties, required=False):
|
||||||
return int, numeric_properties_builder(properties)
|
return int, numeric_properties_builder(properties, required)
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ class ObjectTypeParser(GenericTypeParser):
|
|||||||
json_schema_type = "object"
|
json_schema_type = "object"
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_properties(name, properties):
|
def from_properties(name, properties, required=False):
|
||||||
from jambo.schema_converter import SchemaConverter
|
from jambo.schema_converter import SchemaConverter
|
||||||
|
|
||||||
type_parsing = SchemaConverter.build_object(name, properties)
|
type_parsing = SchemaConverter.build_object(name, properties)
|
||||||
|
|||||||
@@ -10,31 +10,17 @@ class StringTypeParser(GenericTypeParser):
|
|||||||
json_schema_type = "string"
|
json_schema_type = "string"
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_properties(name, properties):
|
def from_properties(name, properties, required=False):
|
||||||
_mappings = {
|
_mappings = {
|
||||||
"maxLength": "max_length",
|
"maxLength": "max_length",
|
||||||
"minLength": "min_length",
|
"minLength": "min_length",
|
||||||
"pattern": "pattern",
|
"pattern": "pattern",
|
||||||
}
|
}
|
||||||
|
|
||||||
mapped_properties = mappings_properties_builder(properties, _mappings)
|
mapped_properties = mappings_properties_builder(properties, _mappings, required)
|
||||||
|
|
||||||
if "default" in properties:
|
default_value = properties.get("default")
|
||||||
default_value = properties["default"]
|
if default_value is not None:
|
||||||
if not isinstance(default_value, str):
|
StringTypeParser.validate_default(str, mapped_properties, default_value)
|
||||||
raise ValueError(
|
|
||||||
f"Default value for {name} must be a string, "
|
|
||||||
f"but got <{type(properties['default']).__name__}>."
|
|
||||||
)
|
|
||||||
|
|
||||||
if len(default_value) > properties.get("maxLength", float("inf")):
|
|
||||||
raise ValueError(
|
|
||||||
f"Default value for {name} exceeds maxLength limit of {properties.get('maxLength')}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if len(default_value) < properties.get("minLength", 0):
|
|
||||||
raise ValueError(
|
|
||||||
f"Default value for {name} is below minLength limit of {properties.get('minLength')}"
|
|
||||||
)
|
|
||||||
|
|
||||||
return str, mapped_properties
|
return str, mapped_properties
|
||||||
|
|||||||
@@ -71,14 +71,13 @@ class SchemaConverter:
|
|||||||
|
|
||||||
fields = {}
|
fields = {}
|
||||||
for name, prop in properties.items():
|
for name, prop in properties.items():
|
||||||
fields[name] = SchemaConverter._build_field(name, prop, required_keys)
|
is_required = name in required_keys
|
||||||
|
fields[name] = SchemaConverter._build_field(name, prop, is_required)
|
||||||
|
|
||||||
return fields
|
return fields
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _build_field(
|
def _build_field(name, properties: dict, required=False) -> tuple[type, dict]:
|
||||||
name, properties: dict, required_keys: list[str]
|
|
||||||
) -> tuple[type, dict]:
|
|
||||||
match properties:
|
match properties:
|
||||||
case {"anyOf": _}:
|
case {"anyOf": _}:
|
||||||
_field_type = "anyOf"
|
_field_type = "anyOf"
|
||||||
@@ -91,17 +90,6 @@ class SchemaConverter:
|
|||||||
|
|
||||||
_field_type, _field_args = GenericTypeParser.get_impl(
|
_field_type, _field_args = GenericTypeParser.get_impl(
|
||||||
_field_type
|
_field_type
|
||||||
).from_properties(name, properties)
|
).from_properties(name, properties, required)
|
||||||
|
|
||||||
_field_args = _field_args or {}
|
|
||||||
|
|
||||||
if description := properties.get("description"):
|
|
||||||
_field_args["description"] = description
|
|
||||||
|
|
||||||
if name not in required_keys:
|
|
||||||
_field_args["default"] = properties.get("default", None)
|
|
||||||
|
|
||||||
if "default_factory" in _field_args and "default" in _field_args:
|
|
||||||
del _field_args["default"]
|
|
||||||
|
|
||||||
return _field_type, Field(**_field_args)
|
return _field_type, Field(**_field_args)
|
||||||
|
|||||||
@@ -1,4 +1,9 @@
|
|||||||
def mappings_properties_builder(properties, mappings, default_mappings=None):
|
def mappings_properties_builder(
|
||||||
|
properties, mappings, required=False, default_mappings=None
|
||||||
|
):
|
||||||
|
if not required:
|
||||||
|
properties["default"] = properties.get("default", None)
|
||||||
|
|
||||||
default_mappings = default_mappings or {
|
default_mappings = default_mappings or {
|
||||||
"default": "default",
|
"default": "default",
|
||||||
"description": "description",
|
"description": "description",
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ from jambo.utils.properties_builder.mappings_properties_builder import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def numeric_properties_builder(properties):
|
def numeric_properties_builder(properties, required=False):
|
||||||
_mappings = {
|
_mappings = {
|
||||||
"minimum": "ge",
|
"minimum": "ge",
|
||||||
"exclusiveMinimum": "gt",
|
"exclusiveMinimum": "gt",
|
||||||
@@ -13,9 +13,10 @@ def numeric_properties_builder(properties):
|
|||||||
"default": "default",
|
"default": "default",
|
||||||
}
|
}
|
||||||
|
|
||||||
mapped_properties = mappings_properties_builder(properties, _mappings)
|
mapped_properties = mappings_properties_builder(properties, _mappings, required)
|
||||||
|
|
||||||
if "default" in properties:
|
default_value = properties.get("default")
|
||||||
|
if default_value is not None:
|
||||||
default_value = properties["default"]
|
default_value = properties["default"]
|
||||||
if not isinstance(default_value, (int, float)):
|
if not isinstance(default_value, (int, float)):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
|||||||
@@ -1,5 +1,7 @@
|
|||||||
from jambo.parser.anyof_type_parser import AnyOfTypeParser
|
from jambo.parser.anyof_type_parser import AnyOfTypeParser
|
||||||
|
|
||||||
|
from typing_extensions import Annotated
|
||||||
|
|
||||||
from typing import Union, get_args, get_origin
|
from typing import Union, get_args, get_origin
|
||||||
from unittest import TestCase
|
from unittest import TestCase
|
||||||
|
|
||||||
@@ -21,5 +23,41 @@ class TestAnyOfTypeParser(TestCase):
|
|||||||
|
|
||||||
# check union type has string and int
|
# check union type has string and int
|
||||||
self.assertEqual(get_origin(type_parsing), Union)
|
self.assertEqual(get_origin(type_parsing), Union)
|
||||||
self.assertIn(str, get_args(type_parsing))
|
|
||||||
self.assertIn(int, get_args(type_parsing))
|
type_1, type_2 = get_args(type_parsing)
|
||||||
|
|
||||||
|
self.assertEqual(get_origin(type_1), Annotated)
|
||||||
|
self.assertIn(str, get_args(type_1))
|
||||||
|
|
||||||
|
self.assertEqual(get_origin(type_2), Annotated)
|
||||||
|
self.assertIn(int, get_args(type_2))
|
||||||
|
|
||||||
|
def test_any_of_string_or_int_with_default(self):
|
||||||
|
"""
|
||||||
|
Tests the AnyOfTypeParser with a string or int type and a default value.
|
||||||
|
"""
|
||||||
|
|
||||||
|
properties = {
|
||||||
|
"anyOf": [
|
||||||
|
{"type": "string"},
|
||||||
|
{"type": "integer"},
|
||||||
|
],
|
||||||
|
"default": 42,
|
||||||
|
}
|
||||||
|
|
||||||
|
type_parsing, type_validator = AnyOfTypeParser.from_properties(
|
||||||
|
"placeholder", properties
|
||||||
|
)
|
||||||
|
|
||||||
|
# check union type has string and int
|
||||||
|
self.assertEqual(get_origin(type_parsing), Union)
|
||||||
|
|
||||||
|
type_1, type_2 = get_args(type_parsing)
|
||||||
|
|
||||||
|
self.assertEqual(get_origin(type_1), Annotated)
|
||||||
|
self.assertIn(str, get_args(type_1))
|
||||||
|
|
||||||
|
self.assertEqual(get_origin(type_2), Annotated)
|
||||||
|
self.assertIn(int, get_args(type_2))
|
||||||
|
|
||||||
|
self.assertEqual(type_validator["default"], 42)
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ class TestBoolTypeParser(TestCase):
|
|||||||
type_parsing, type_validator = parser.from_properties("placeholder", properties)
|
type_parsing, type_validator = parser.from_properties("placeholder", properties)
|
||||||
|
|
||||||
self.assertEqual(type_parsing, bool)
|
self.assertEqual(type_parsing, bool)
|
||||||
self.assertEqual(type_validator, {})
|
self.assertEqual(type_validator, {"default": None})
|
||||||
|
|
||||||
def test_bool_parser_with_default(self):
|
def test_bool_parser_with_default(self):
|
||||||
parser = BooleanTypeParser()
|
parser = BooleanTypeParser()
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ class TestFloatTypeParser(TestCase):
|
|||||||
type_parsing, type_validator = parser.from_properties("placeholder", properties)
|
type_parsing, type_validator = parser.from_properties("placeholder", properties)
|
||||||
|
|
||||||
self.assertEqual(type_parsing, float)
|
self.assertEqual(type_parsing, float)
|
||||||
self.assertEqual(type_validator, {})
|
self.assertEqual(type_validator, {"default": None})
|
||||||
|
|
||||||
def test_float_parser_with_options(self):
|
def test_float_parser_with_options(self):
|
||||||
parser = FloatTypeParser()
|
parser = FloatTypeParser()
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ class TestIntTypeParser(TestCase):
|
|||||||
type_parsing, type_validator = parser.from_properties("placeholder", properties)
|
type_parsing, type_validator = parser.from_properties("placeholder", properties)
|
||||||
|
|
||||||
self.assertEqual(type_parsing, int)
|
self.assertEqual(type_parsing, int)
|
||||||
self.assertEqual(type_validator, {})
|
self.assertEqual(type_validator, {"default": None})
|
||||||
|
|
||||||
def test_int_parser_with_options(self):
|
def test_int_parser_with_options(self):
|
||||||
parser = IntTypeParser()
|
parser = IntTypeParser()
|
||||||
|
|||||||
@@ -57,14 +57,9 @@ class TestStringTypeParser(TestCase):
|
|||||||
"minLength": 5,
|
"minLength": 5,
|
||||||
}
|
}
|
||||||
|
|
||||||
with self.assertRaises(ValueError) as context:
|
with self.assertRaises(ValueError):
|
||||||
parser.from_properties("placeholder", properties)
|
parser.from_properties("placeholder", properties)
|
||||||
|
|
||||||
self.assertEqual(
|
|
||||||
str(context.exception),
|
|
||||||
"Default value for placeholder must be a string, but got <int>.",
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_string_parser_with_default_invalid_maxlength(self):
|
def test_string_parser_with_default_invalid_maxlength(self):
|
||||||
parser = StringTypeParser()
|
parser = StringTypeParser()
|
||||||
|
|
||||||
@@ -75,14 +70,9 @@ class TestStringTypeParser(TestCase):
|
|||||||
"minLength": 1,
|
"minLength": 1,
|
||||||
}
|
}
|
||||||
|
|
||||||
with self.assertRaises(ValueError) as context:
|
with self.assertRaises(ValueError):
|
||||||
parser.from_properties("placeholder", properties)
|
parser.from_properties("placeholder", properties)
|
||||||
|
|
||||||
self.assertEqual(
|
|
||||||
str(context.exception),
|
|
||||||
"Default value for placeholder exceeds maxLength limit of 2",
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_string_parser_with_default_invalid_minlength(self):
|
def test_string_parser_with_default_invalid_minlength(self):
|
||||||
parser = StringTypeParser()
|
parser = StringTypeParser()
|
||||||
|
|
||||||
@@ -93,10 +83,5 @@ class TestStringTypeParser(TestCase):
|
|||||||
"minLength": 2,
|
"minLength": 2,
|
||||||
}
|
}
|
||||||
|
|
||||||
with self.assertRaises(ValueError) as context:
|
with self.assertRaises(ValueError):
|
||||||
parser.from_properties("placeholder", properties)
|
parser.from_properties("placeholder", properties)
|
||||||
|
|
||||||
self.assertEqual(
|
|
||||||
str(context.exception),
|
|
||||||
"Default value for placeholder is below minLength limit of 2",
|
|
||||||
)
|
|
||||||
|
|||||||
Reference in New Issue
Block a user