[FEATURE] Adds Better Const Typing #38
@@ -3,7 +3,7 @@ from jambo.types.json_schema_type import JSONSchemaNativeTypes
|
||||
from jambo.types.type_parser_options import TypeParserOptions
|
||||
|
||||
from pydantic import AfterValidator
|
||||
from typing_extensions import Annotated, Any, Unpack
|
||||
from typing_extensions import Annotated, Any, Literal, Unpack
|
||||
|
||||
|
||||
class ConstTypeParser(GenericTypeParser):
|
||||
@@ -33,11 +33,19 @@ class ConstTypeParser(GenericTypeParser):
|
||||
return const_type, parsed_properties
|
||||
|
||||
def _build_const_type(self, const_value):
|
||||
def _validate_const_value(value: Any) -> Any:
|
||||
if value != const_value:
|
||||
raise ValueError(
|
||||
f"Value must be equal to the constant value: {const_value}"
|
||||
)
|
||||
return value
|
||||
# Try to use Literal for hashable types (required for discriminated unions)
|
||||
# Fall back to validator approach for non-hashable types
|
||||
try:
|
||||
# Test if the value is hashable (can be used in Literal)
|
||||
hash(const_value)
|
||||
return Literal[const_value]
|
||||
except TypeError:
|
||||
# Non-hashable type (like list, dict), use validator approach
|
||||
def _validate_const_value(value: Any) -> Any:
|
||||
if value != const_value:
|
||||
raise ValueError(
|
||||
f"Value must be equal to the constant value: {const_value}"
|
||||
)
|
||||
return value
|
||||
|
||||
return Annotated[type(const_value), AfterValidator(_validate_const_value)]
|
||||
return Annotated[type(const_value), AfterValidator(_validate_const_value)]
|
||||
@@ -38,9 +38,7 @@ class StringTypeParser(GenericTypeParser):
|
||||
def from_properties_impl(
|
||||
self, name, properties, **kwargs: Unpack[TypeParserOptions]
|
||||
):
|
||||
mapped_properties = self.mappings_properties_builder(
|
||||
properties, **kwargs
|
||||
)
|
||||
mapped_properties = self.mappings_properties_builder(properties, **kwargs)
|
||||
|
||||
format_type = properties.get("format")
|
||||
if not format_type:
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
from jambo.parser import ConstTypeParser
|
||||
|
||||
from typing_extensions import Annotated, get_args, get_origin
|
||||
from typing_extensions import Annotated, Literal, get_args, get_origin
|
||||
|
||||
from unittest import TestCase
|
||||
|
||||
|
||||
class TestConstTypeParser(TestCase):
|
||||
def test_const_type_parser(self):
|
||||
def test_const_type_parser_hashable_value(self):
|
||||
"""Test const parser with hashable values (uses Literal)"""
|
||||
parser = ConstTypeParser()
|
||||
|
||||
expected_const_value = "United States of America"
|
||||
@@ -16,8 +17,60 @@ class TestConstTypeParser(TestCase):
|
||||
"country", properties
|
||||
)
|
||||
|
||||
# Check that we get a Literal type for hashable values
|
||||
self.assertEqual(get_origin(parsed_type), Literal)
|
||||
self.assertEqual(get_args(parsed_type), (expected_const_value,))
|
||||
|
||||
self.assertEqual(parsed_properties["default"], expected_const_value)
|
||||
|
||||
def test_const_type_parser_non_hashable_value(self):
|
||||
"""Test const parser with non-hashable values (uses Annotated with validator)"""
|
||||
parser = ConstTypeParser()
|
||||
|
||||
expected_const_value = [1, 2, 3] # Lists are not hashable
|
||||
properties = {"const": expected_const_value}
|
||||
|
||||
parsed_type, parsed_properties = parser.from_properties_impl(
|
||||
"list_const", properties
|
||||
)
|
||||
|
||||
# Check that we get an Annotated type for non-hashable values
|
||||
self.assertEqual(get_origin(parsed_type), Annotated)
|
||||
self.assertIn(str, get_args(parsed_type))
|
||||
self.assertIn(list, get_args(parsed_type))
|
||||
|
||||
self.assertEqual(parsed_properties["default"], expected_const_value)
|
||||
|
||||
def test_const_type_parser_integer_value(self):
|
||||
"""Test const parser with integer values (uses Literal)"""
|
||||
parser = ConstTypeParser()
|
||||
|
||||
expected_const_value = 42
|
||||
properties = {"const": expected_const_value}
|
||||
|
||||
parsed_type, parsed_properties = parser.from_properties_impl(
|
||||
"int_const", properties
|
||||
)
|
||||
|
||||
# Check that we get a Literal type for hashable values
|
||||
self.assertEqual(get_origin(parsed_type), Literal)
|
||||
self.assertEqual(get_args(parsed_type), (expected_const_value,))
|
||||
|
||||
self.assertEqual(parsed_properties["default"], expected_const_value)
|
||||
|
||||
def test_const_type_parser_boolean_value(self):
|
||||
"""Test const parser with boolean values (uses Literal)"""
|
||||
parser = ConstTypeParser()
|
||||
|
||||
expected_const_value = True
|
||||
properties = {"const": expected_const_value}
|
||||
|
||||
parsed_type, parsed_properties = parser.from_properties_impl(
|
||||
"bool_const", properties
|
||||
)
|
||||
|
||||
# Check that we get a Literal type for hashable values
|
||||
self.assertEqual(get_origin(parsed_type), Literal)
|
||||
self.assertEqual(get_args(parsed_type), (expected_const_value,))
|
||||
|
||||
self.assertEqual(parsed_properties["default"], expected_const_value)
|
||||
|
||||
@@ -46,4 +99,4 @@ class TestConstTypeParser(TestCase):
|
||||
self.assertIn(
|
||||
"Const type invalid_country must have 'const' value of allowed types",
|
||||
str(context.exception),
|
||||
)
|
||||
)
|
||||
@@ -701,6 +701,29 @@ class TestSchemaConverter(TestCase):
|
||||
with self.assertRaises(ValueError):
|
||||
Model(name="Canada")
|
||||
|
||||
def test_const_type_parser_with_non_hashable_value(self):
|
||||
schema = {
|
||||
"title": "Country",
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {
|
||||
"const": ["Brazil"],
|
||||
}
|
||||
},
|
||||
"required": ["name"],
|
||||
}
|
||||
|
||||
Model = SchemaConverter.build(schema)
|
||||
|
||||
obj = Model()
|
||||
self.assertEqual(obj.name, ["Brazil"])
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
obj.name = ["Argentina"]
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
Model(name=["Argentina"])
|
||||
|
||||
def test_null_type_parser(self):
|
||||
schema = {
|
||||
"title": "Test",
|
||||
|
||||
Reference in New Issue
Block a user