(feature): Fix OneOf behavior on invalid discriminator

According to the OpenAPI spec, `propertyName` is required when a discriminator is used. If it is missing, the schema is invalid and building the model should raise an error.
This commit is contained in:
2025-08-19 20:08:34 -03:00
parent b386d4954e
commit 86894fa918
2 changed files with 99 additions and 128 deletions

View File

@@ -31,33 +31,50 @@ class OneOfTypeParser(GenericTypeParser):
if not kwargs.get("required", False): if not kwargs.get("required", False):
mapped_properties["default"] = mapped_properties.get("default") mapped_properties["default"] = mapped_properties.get("default")
field_types = [ subfield_types = [Annotated[t, Field(**v)] for t, v in sub_types]
Annotated[t, Field(**v)] if self._has_meaningful_constraints(v) else t
for t, v in sub_types
]
union_type = Union[(*field_types,)]
# Note: discriminators are not part of the JSON Schema spec;
# they were added by OpenAPI, so not all implementations may support them,
# and they do not always map one-to-one onto the Pydantic model.
# TL;DR: discriminators were added by OpenAPI and are not an official JSON Schema feature.
discriminator = properties.get("discriminator") discriminator = properties.get("discriminator")
if discriminator and isinstance(discriminator, dict): if discriminator is not None:
property_name = discriminator.get("propertyName") validated_type = self._build_type_one_of_with_discriminator(
if property_name: subfield_types, discriminator
validated_type = Annotated[ )
union_type, Field(discriminator=property_name) else:
] validated_type = self._build_type_one_of_with_func(subfield_types)
return validated_type, mapped_properties
return validated_type, mapped_properties
@staticmethod
def _build_type_one_of_with_discriminator(
subfield_types: list[Annotated], discriminator_prop: dict
) -> Annotated:
if not isinstance(discriminator_prop, dict):
raise ValueError("Discriminator must be a dictionary")
property_name = discriminator_prop.get("propertyName")
if property_name is None or not isinstance(property_name, str):
raise ValueError("Discriminator must have a 'propertyName' key")
return Annotated[Union[(*subfield_types,)], Field(discriminator=property_name)]
@staticmethod
def _build_type_one_of_with_func(subfield_types: list[Annotated]) -> Annotated:
"""
Build a validation function for the oneOf constraint.
This function will validate that the value matches exactly one of the schemas.
"""
def validate_one_of(value: Any) -> Any: def validate_one_of(value: Any) -> Any:
matched_count = 0 matched_count = 0
validation_errors = []
for field_type in field_types: for field_type in subfield_types:
try: try:
adapter = TypeAdapter(field_type) TypeAdapter(field_type).validate_python(value)
adapter.validate_python(value)
matched_count += 1 matched_count += 1
except ValidationError as e: except ValidationError:
validation_errors.append(str(e))
continue continue
if matched_count == 0: if matched_count == 0:
@@ -69,8 +86,7 @@ class OneOfTypeParser(GenericTypeParser):
return value return value
validated_type = Annotated[union_type, BeforeValidator(validate_one_of)] return Annotated[Union[(*subfield_types,)], BeforeValidator(validate_one_of)]
return validated_type, mapped_properties
@staticmethod @staticmethod
def _has_meaningful_constraints(field_props): def _has_meaningful_constraints(field_props):

View File

@@ -354,131 +354,86 @@ class TestOneOfTypeParser(TestCase):
}, },
} }
Model = SchemaConverter.build(schema) # Should throw because the spec determines propertyName is required for discriminator
with self.assertRaises(ValueError):
SchemaConverter.build(schema)
# Should succeed because input matches exactly one schema (the first one) def test_oneof_overlapping_strings_from_docs(self):
# The first schema matches: type="a" matches const("a"), value="test" is a string """Test the overlapping strings example from documentation"""
# The second schema doesn't match: type="a" does not match const("b")
obj = Model(value={"type": "a", "value": "test", "extra": "invalid"})
self.assertEqual(obj.value.type, "a")
self.assertEqual(obj.value.value, "test")
# Test with input that matches the second schema
obj2 = Model(value={"type": "b", "value": 42})
self.assertEqual(obj2.value.type, "b")
self.assertEqual(obj2.value.value, 42)
# Test with input that matches neither schema (should fail)
with self.assertRaises(ValueError) as cm:
Model(value={"type": "c", "value": "test"})
self.assertIn("does not match any of the oneOf schemas", str(cm.exception))
def test_oneof_multiple_matches_without_discriminator(self):
"""Test case where input genuinely matches multiple oneOf schemas"""
schema = { schema = {
"title": "Test", "title": "SimpleExample",
"type": "object", "type": "object",
"properties": { "properties": {
"value": { "value": {
"oneOf": [ "oneOf": [
{"type": "object", "properties": {"data": {"type": "string"}}}, {"type": "string", "maxLength": 6},
{ {"type": "string", "minLength": 4},
"type": "object", ]
"properties": {
"data": {"type": "string"},
"optional": {"type": "string"},
},
},
],
"discriminator": {}, # discriminator without propertyName
} }
}, },
"required": ["value"],
} }
Model = SchemaConverter.build(schema) Model = SchemaConverter.build(schema)
# This input matches both schemas since both accept data as string # Valid: Short string (matches first schema only)
# and neither requires specific additional properties obj1 = Model(value="hi")
self.assertEqual(obj1.value, "hi")
# Valid: Long string (matches second schema only)
obj2 = Model(value="very long string")
self.assertEqual(obj2.value, "very long string")
# Invalid: Medium string (matches BOTH schemas - violates oneOf)
with self.assertRaises(ValueError) as cm: with self.assertRaises(ValueError) as cm:
Model(value={"data": "test"}) Model(value="hello") # 5 chars: matches maxLength=6 AND minLength=4
self.assertIn("matches multiple oneOf schemas", str(cm.exception)) self.assertIn("matches multiple oneOf schemas", str(cm.exception))
def test_oneof_overlapping_strings_from_docs(self): def test_oneof_shapes_discriminator_from_docs(self):
"""Test the overlapping strings example from documentation""" """Test the shapes discriminator example from documentation"""
schema = { schema = {
"title": "SimpleExample", "title": "Shape",
"type": "object", "type": "object",
"properties": { "properties": {
"value": { "shape": {
"oneOf": [ "oneOf": [
{"type": "string", "maxLength": 6}, {
{"type": "string", "minLength": 4}, "type": "object",
] "properties": {
} "type": {"const": "circle"},
}, "radius": {"type": "number", "minimum": 0},
"required": ["value"],
}
Model = SchemaConverter.build(schema)
# Valid: Short string (matches first schema only)
obj1 = Model(value="hi")
self.assertEqual(obj1.value, "hi")
# Valid: Long string (matches second schema only)
obj2 = Model(value="very long string")
self.assertEqual(obj2.value, "very long string")
# Invalid: Medium string (matches BOTH schemas - violates oneOf)
with self.assertRaises(ValueError) as cm:
Model(value="hello") # 5 chars: matches maxLength=6 AND minLength=4
self.assertIn("matches multiple oneOf schemas", str(cm.exception))
def test_oneof_shapes_discriminator_from_docs(self):
"""Test the shapes discriminator example from documentation"""
schema = {
"title": "Shape",
"type": "object",
"properties": {
"shape": {
"oneOf": [
{
"type": "object",
"properties": {
"type": {"const": "circle"},
"radius": {"type": "number", "minimum": 0},
},
"required": ["type", "radius"],
}, },
{ "required": ["type", "radius"],
"type": "object", },
"properties": { {
"type": {"const": "rectangle"}, "type": "object",
"width": {"type": "number", "minimum": 0}, "properties": {
"height": {"type": "number", "minimum": 0}, "type": {"const": "rectangle"},
}, "width": {"type": "number", "minimum": 0},
"required": ["type", "width", "height"], "height": {"type": "number", "minimum": 0},
}, },
], "required": ["type", "width", "height"],
"discriminator": {"propertyName": "type"}, },
} ],
}, "discriminator": {"propertyName": "type"},
"required": ["shape"], }
} },
"required": ["shape"],
}
Model = SchemaConverter.build(schema) Model = SchemaConverter.build(schema)
# Valid: Circle # Valid: Circle
circle = Model(shape={"type": "circle", "radius": 5.0}) circle = Model(shape={"type": "circle", "radius": 5.0})
self.assertEqual(circle.shape.type, "circle") self.assertEqual(circle.shape.type, "circle")
self.assertEqual(circle.shape.radius, 5.0) self.assertEqual(circle.shape.radius, 5.0)
# Valid: Rectangle # Valid: Rectangle
rectangle = Model(shape={"type": "rectangle", "width": 10, "height": 20}) rectangle = Model(shape={"type": "rectangle", "width": 10, "height": 20})
self.assertEqual(rectangle.shape.type, "rectangle") self.assertEqual(rectangle.shape.type, "rectangle")
self.assertEqual(rectangle.shape.width, 10) self.assertEqual(rectangle.shape.width, 10)
self.assertEqual(rectangle.shape.height, 20) self.assertEqual(rectangle.shape.height, 20)
# Invalid: Wrong properties for the type # Invalid: Wrong properties for the type
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
Model(shape={"type": "circle", "width": 10}) Model(shape={"type": "circle", "width": 10})