from __future__ import annotations as _annotations
import re
from abc import ABC, abstractmethod
from copy import deepcopy
from dataclasses import dataclass
from typing import Any, Literal
from .exceptions import UserError
# Alias for a JSON schema held as a plain dict: string keys mapping to arbitrary JSON-compatible values.
JsonSchema = dict[str, Any]
@dataclass(init=False)
class JsonSchemaTransformer(ABC):
    """Walks a JSON schema, applying transformations to it at each level.

    The transformer is called during a model's prepare_request() step to build the JSON schema
    before it is sent to the model provider.

    Subclasses implement `transform`; `walk` applies it to every sub-schema it visits, handling the
    children of a schema (properties, items, union members) before the schema that contains them.

    Note: We may eventually want to rework tools to build the JSON schema from the type directly, using a subclass of
    pydantic.json_schema.GenerateJsonSchema, rather than making use of this machinery.
    """

    def __init__(
        self,
        schema: JsonSchema,
        *,
        strict: bool | None = None,
        prefer_inlined_defs: bool = False,
        simplify_nullable_unions: bool = False,  # TODO (v2): Remove this, no longer used
    ):
        """Store the schema to walk and the options controlling the walk.

        Args:
            schema: The JSON schema to transform. Its `$defs` mapping (if any) is referenced, not copied.
            strict: Whether strict mode was requested; stored for use by subclasses.
            prefer_inlined_defs: If true, `$ref`s are replaced by the definition they point to wherever
                the reference is not recursive.
            simplify_nullable_unions: If true, a two-member union containing `{'type': 'null'}` is collapsed
                to the other member with `nullable: True` added.
        """
        self.schema = schema

        self.strict = strict
        """The `strict` parameter forces the conversion of the original JSON schema (`self.schema`) of a `ToolDefinition` or `OutputObjectDefinition` to a format supported by the model provider.

        The "strict mode" offered by model providers ensures that the model's output adheres closely to the defined schema. However, not all model providers offer it, and their support for various schema features may differ. For example, a model provider's required schema may not support certain validation constraints like `minLength` or `pattern`.
        """
        self.is_strict_compatible = True
        """Whether the schema is compatible with strict mode.

        This value is used to set `ToolDefinition.strict` or `OutputObjectDefinition.strict` when their values are `None`.
        """

        self.prefer_inlined_defs = prefer_inlined_defs
        self.simplify_nullable_unions = simplify_nullable_unions

        # These are the very same dicts as in `self.schema['$defs']`; object/array handling updates them in place.
        self.defs: dict[str, JsonSchema] = self.schema.get('$defs', {})
        # Keys of the definitions currently being inlined (innermost last); used to detect recursion.
        self.refs_stack: list[str] = []
        # Keys of definitions found to reference themselves, which therefore cannot be inlined.
        self.recursive_refs = set[str]()

    @abstractmethod
    def transform(self, schema: JsonSchema) -> JsonSchema:
        """Make changes to the schema."""
        return schema

    def walk(self) -> JsonSchema:
        """Transform the whole schema and return the result.

        The root of `self.schema` is deep-copied (minus `$defs`) before being handled. The result is one of:

        - the handled root plus a handled `$defs` mapping, when definitions are not being inlined;
        - a `{'$defs': ..., '$ref': ...}` wrapper, when inlining was requested but recursive references were found;
        - the handled root alone otherwise.
        """
        schema = deepcopy(self.schema)

        # First, handle everything but $defs:
        schema.pop('$defs', None)
        handled = self._handle(schema)

        if not self.prefer_inlined_defs and self.defs:
            handled['$defs'] = {k: self._handle(v) for k, v in self.defs.items()}

        elif self.recursive_refs:
            # If we are preferring inlined defs and there are recursive refs, we _have_ to use a $defs+$ref structure
            # We try to use whatever the original root key was, but if it is already in use,
            # we modify it to avoid collisions.
            defs = {key: self.defs[key] for key in self.recursive_refs}
            root_ref = self.schema.get('$ref')
            root_key = None if root_ref is None else re.sub(r'^#/\$defs/', '', root_ref)
            if root_key is None:  # pragma: no cover
                root_key = self.schema.get('title', 'root')
            while root_key in defs:
                # Modify the root key until it is not already in use
                root_key = f'{root_key}_root'

            defs[root_key] = handled
            return {'$defs': defs, '$ref': f'#/$defs/{root_key}'}

        return handled

    def _handle(self, schema: JsonSchema) -> JsonSchema:
        """Handle one (sub-)schema: resolve inlinable refs, recurse into children, then apply `transform`.

        Raises:
            UserError: If a `$ref` being inlined points to a key missing from `self.defs`.
        """
        nested_refs = 0
        if self.prefer_inlined_defs:
            # Follow the chain of refs until we reach a non-ref schema or a recursive reference.
            while ref := schema.get('$ref'):
                key = re.sub(r'^#/\$defs/', '', ref)
                if key in self.recursive_refs:
                    break
                if key in self.refs_stack:
                    self.recursive_refs.add(key)
                    break  # recursive ref can't be unpacked
                self.refs_stack.append(key)
                nested_refs += 1

                def_schema = self.defs.get(key)
                if def_schema is None:  # pragma: no cover
                    raise UserError(f'Could not find $ref definition for {key}')
                schema = def_schema

        # Handle the schema based on its type / structure
        type_ = schema.get('type')
        if type_ == 'object':
            schema = self._handle_object(schema)
        elif type_ == 'array':
            schema = self._handle_array(schema)
        elif type_ is None:
            schema = self._handle_union(schema, 'anyOf')
            schema = self._handle_union(schema, 'oneOf')

        # Apply the base transform
        schema = self.transform(schema)

        if nested_refs > 0:
            # Leave the definitions entered above so sibling schemas may inline them again.
            self.refs_stack = self.refs_stack[:-nested_refs]

        return schema

    def _handle_object(self, schema: JsonSchema) -> JsonSchema:
        """Handle the sub-schemas of an object schema, storing the results back into `schema`."""
        if properties := schema.get('properties'):
            schema['properties'] = {key: self._handle(value) for key, value in properties.items()}

        # A boolean `additionalProperties` has no sub-schema to handle and is left as it is.
        additional_properties = schema.get('additionalProperties')
        if additional_properties is not None and not isinstance(additional_properties, bool):
            schema['additionalProperties'] = self._handle(additional_properties)

        if (pattern_properties := schema.get('patternProperties')) is not None:
            schema['patternProperties'] = {key: self._handle(value) for key, value in pattern_properties.items()}

        return schema

    def _handle_array(self, schema: JsonSchema) -> JsonSchema:
        """Handle the `prefixItems` and `items` sub-schemas of an array schema, storing the results back."""
        if prefix_items := schema.get('prefixItems'):
            schema['prefixItems'] = [self._handle(item) for item in prefix_items]

        if items := schema.get('items'):
            schema['items'] = self._handle(items)

        return schema

    def _handle_union(self, schema: JsonSchema, union_kind: Literal['anyOf', 'oneOf']) -> JsonSchema:
        """Handle the members of an `anyOf` / `oneOf` union, if `schema` has one.

        Returns `schema` itself when the `union_kind` key is absent. Otherwise returns a new dict: the
        single handled member merged with the remaining keys when only one member is left, or the
        remaining keys plus the handled member list.

        The input is never modified. It may be an entry of `self.defs` that is inlined at several places
        (or emitted under `$defs` for recursive refs), so removing the union from it would corrupt every
        later use of that definition.
        """
        if union_kind not in schema:
            return schema

        members = schema[union_kind]
        # Keys besides the union kind (such as title or discriminator) are kept without modifications
        remainder = {key: value for key, value in schema.items() if key != union_kind}

        handled = [self._handle(member) for member in members]

        # TODO (v2): Remove this feature, no longer used
        if self.simplify_nullable_unions:
            handled = self._simplify_nullable_union(handled)

        if len(handled) == 1:
            # In this case, no need to retain the union
            return handled[0] | remainder

        remainder[union_kind] = handled
        return remainder

    @staticmethod
    def _simplify_nullable_union(cases: list[JsonSchema]) -> list[JsonSchema]:
        """Collapse a two-member union with `{'type': 'null'}` into the other member marked `nullable`.

        Any other list of cases is returned unchanged.
        """
        # TODO (v2): Remove this method, no longer used
        if len(cases) == 2 and {'type': 'null'} in cases:
            # Find the non-null schema
            non_null_schema = next(
                (item for item in cases if item != {'type': 'null'}),
                None,
            )
            # Compare against None rather than truthiness: an empty schema `{}` is a valid non-null member.
            if non_null_schema is not None:
                # Create a new schema based on the non-null part, mark as nullable
                new_schema = deepcopy(non_null_schema)
                new_schema['nullable'] = True
                return [new_schema]
            else:  # pragma: no cover
                # they are both null, so just return one of them
                return [cases[0]]

        return cases
class InlineDefsJsonSchemaTransformer(JsonSchemaTransformer):
    """A transformer whose only effect is to inline `$defs` into the places that reference them."""

    def __init__(self, schema: JsonSchema, *, strict: bool | None = None):
        """Set up the base transformer for `schema` with inlining of definitions switched on."""
        super().__init__(schema, prefer_inlined_defs=True, strict=strict)

    def transform(self, schema: JsonSchema) -> JsonSchema:
        """Return `schema` as it is; this class makes no per-schema changes of its own."""
        return schema