#!/usr/bin/env python3
#
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
"""Standalone convenience script used to massage the typing.py.
The `genkit/typing.py` file is generated by datamodel-codegen. However, since
the tool doesn't currently provide options to generate exactly the kind of code
we need, we use this convenience script to parse the Python source code, walk
the AST, modify it to include the bits we need and regenerate the code for
eventual use within our codebase.
Transformations applied:
- We remove the model_config attribute from classes that ineherit from
RootModel.
- We add the `populate_by_name=True` parameter to ensure serialization uses
camelCase for attributes since the JS implementation uses camelCase and Python
uses snake_case. The codegen pass is configured to generate snake_case for a
Pythonic API but serialize to camelCase in order to be compatible with
runtimes.
- We add a license header
- We add a header indicating that this file has been generated by a code
generator pass.
- We add the ability to use forward references.
- Add docstrings if missing.
"""
from __future__ import annotations
import ast
import re
import sys
from datetime import datetime
from pathlib import Path
from typing import cast
class ClassTransformer(ast.NodeTransformer):
"""AST transformer that modifies class definitions."""
def __init__(self) -> None:
"""Initialize the ClassTransformer."""
self.modified = False
self.schema_fields_to_suppress: list[ast.AnnAssign] = []
def is_rootmodel_class(self, node: ast.ClassDef) -> bool:
"""Check if a class definition is a RootModel class."""
for base in node.bases:
if isinstance(base, ast.Name) and base.id == 'RootModel':
return True
elif isinstance(base, ast.Subscript):
value = base.value
if isinstance(value, ast.Name) and value.id == 'RootModel':
return True
return False
def create_model_config(
self, existing_config: ast.Call | None = None, frozen: bool = False, has_schema_field: bool = False
) -> ast.AnnAssign:
"""Create or update a model_config assignment with proper type annotation.
Creates: model_config: ClassVar[ConfigDict] = ConfigDict(...)
Ensures alias_generator=to_camel, populate_by_name=True, and extra='forbid',
keeping other existing settings.
Args:
existing_config: Existing ConfigDict call to preserve settings from.
frozen: Whether to add frozen=True for immutable models.
has_schema_field: Whether the class has a 'schema' field. If True,
adds protected_namespaces=() to allow using 'schema' as a field name.
"""
keywords = []
found_populate = False
found_frozen = False
# Preserve existing keywords if present, but override 'extra' and 'alias_generator'
if existing_config:
for kw in existing_config.keywords:
if kw.arg == 'populate_by_name':
# Ensure it's set to True
keywords.append(
ast.keyword(
arg='populate_by_name',
value=ast.Constant(value=True),
)
)
found_populate = True
elif kw.arg == 'extra':
# Skip the existing 'extra', we will enforce 'forbid'
continue
elif kw.arg == 'alias_generator':
# Skip existing alias_generator, we will add our own
continue
elif kw.arg == 'protected_namespaces':
# Skip existing protected_namespaces, we will add our own if needed
continue
elif kw.arg == 'frozen':
# Use the provided 'frozen' value
keywords.append(
ast.keyword(
arg='frozen',
value=ast.Constant(value=frozen),
)
)
found_frozen = True
else:
keywords.append(kw) # Keep other existing settings
# Always add extra='forbid'
keywords.append(ast.keyword(arg='extra', value=ast.Constant(value='forbid')))
# Add populate_by_name=True if it wasn't found
if not found_populate:
keywords.append(ast.keyword(arg='populate_by_name', value=ast.Constant(value=True)))
# Add frozen=True if it was requested and not found
if frozen and not found_frozen:
keywords.append(ast.keyword(arg='frozen', value=ast.Constant(value=True)))
# Always add alias_generator=to_camel for snake_case -> camelCase serialization
keywords.append(
ast.keyword(
arg='alias_generator',
value=ast.Name(id='to_camel', ctx=ast.Load()),
)
)
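        # With alias_generator=to_camel, snake_case fields get camelCase aliases,
        # e.g. a hypothetical finish_reason field serializes as finishReason when
        # dumping with by_alias=True.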
# Add protected_namespaces=() if class has a 'schema' field
# This allows using 'schema' as a field name without Pydantic warnings
# (following the same pattern as dotpromptz library)
if has_schema_field:
keywords.append(
ast.keyword(
arg='protected_namespaces',
value=ast.Tuple(elts=[], ctx=ast.Load()),
)
)
# Sort keywords for consistent output (optional but good practice)
keywords.sort(key=lambda kw: kw.arg or '')
# Create ClassVar[ConfigDict] annotation
annotation = ast.Subscript(
value=ast.Name(id='ClassVar', ctx=ast.Load()),
slice=ast.Name(id='ConfigDict', ctx=ast.Load()),
ctx=ast.Load(),
)
return ast.AnnAssign(
target=ast.Name(id='model_config', ctx=ast.Store()),
annotation=annotation,
            value=ast.Call(func=ast.Name(id='ConfigDict', ctx=ast.Load()), args=[], keywords=keywords),
simple=1,
)
def has_model_config(self, node: ast.ClassDef) -> ast.Assign | ast.AnnAssign | None:
"""Check if class already has model_config assignment and return it."""
for item in node.body:
if isinstance(item, ast.Assign):
targets = item.targets
if len(targets) == 1 and isinstance(targets[0], ast.Name):
if targets[0].id == 'model_config':
return item
elif isinstance(item, ast.AnnAssign):
if isinstance(item.target, ast.Name) and item.target.id == 'model_config':
return item
return None
def has_schema_field(self, node: ast.ClassDef) -> bool:
"""Check if class has a 'schema' or 'schema_' field.
This is used to determine if we need to add protected_namespaces=()
to the model_config to allow 'schema' as a field name.
"""
for item in node.body:
if isinstance(item, ast.AnnAssign) and isinstance(item.target, ast.Name):
if item.target.id in ('schema', 'schema_'):
return True
return False
def visit_AnnAssign(self, node: ast.AnnAssign) -> ast.AnnAssign:
"""Visit and transform annotated assignment.
- Transform Role type to Role | str for flexibility
- Remove 'alias' keyword from Field() calls since alias_generator handles it
- Rename schema_ to schema (with protected_namespaces=() in model_config)
"""
# Transform Role type
if isinstance(node.annotation, ast.Name) and node.annotation.id == 'Role':
node.annotation = ast.BinOp(
left=ast.Name(id='Role', ctx=ast.Load()),
op=ast.BitOr(),
right=ast.Name(id='str', ctx=ast.Load()),
)
self.modified = True
# Get the field name if this is a simple name target
field_name = None
if isinstance(node.target, ast.Name):
field_name = node.target.id
# Rename schema_ to schema
# We use protected_namespaces=() in model_config to allow 'schema' as a field name
# This follows the same pattern as the dotpromptz library
# Mark this field for pyrefly suppression comment (added in post-processing)
if field_name == 'schema_':
node.target = ast.Name(id='schema', ctx=ast.Store())
self.modified = True
self.schema_fields_to_suppress.append(node)
# Handle Field() calls
if isinstance(node.value, ast.Call):
func = node.value.func
if isinstance(func, ast.Name) and func.id == 'Field':
original_keywords = node.value.keywords
# Remove 'alias' keyword since alias_generator=to_camel handles it
new_keywords = [kw for kw in original_keywords if kw.arg != 'alias']
if len(new_keywords) != len(original_keywords):
node.value.keywords = new_keywords
self.modified = True
return node
def visit_ClassDef(self, node: ast.ClassDef) -> object:
"""Visit and transform a class definition node.
Args:
node: The ClassDef AST node to transform.
Returns:
The transformed ClassDef node.
"""
# First apply base class transformations recursively
node = cast(ast.ClassDef, super().generic_visit(node))
        new_body: list[ast.stmt] = []
# Handle Docstrings
if (
not node.body
or not isinstance(node.body[0], ast.Expr)
or not isinstance(node.body[0].value, ast.Constant)
or not isinstance(node.body[0].value.value, str)
):
# Generate a more descriptive docstring based on class type
if self.is_rootmodel_class(node):
docstring = f'Root model for {node.name.lower().replace("_", " ")}.'
elif any(isinstance(base, ast.Name) and base.id == 'BaseModel' for base in node.bases):
docstring = f'Model for {node.name.lower().replace("_", " ")} data.'
elif any(isinstance(base, ast.Name) and base.id == 'Enum' for base in node.bases):
n = node.name.lower().replace('_', ' ')
docstring = f'Enumeration of {n} values.'
else:
docstring = f'{node.name} data type class.'
new_body.append(ast.Expr(value=ast.Constant(value=docstring)))
self.modified = True
else: # Ensure existing docstring is kept
new_body.append(node.body[0])
# Handle model_config for BaseModel and RootModel
existing_model_config_assign = self.has_model_config(node)
existing_model_config_call = None
if existing_model_config_assign and isinstance(existing_model_config_assign.value, ast.Call):
existing_model_config_call = existing_model_config_assign.value
# Determine start index for iterating original body (skip docstring)
body_start_index = (
1
if (
node.body
and isinstance(node.body[0], ast.Expr)
and isinstance(node.body[0].value, ast.Constant)
and isinstance(node.body[0].value.value, str)
)
else 0
)
if self.is_rootmodel_class(node):
# Remove model_config from RootModel classes
for stmt in node.body[body_start_index:]:
# Skip existing model_config (both Assign and AnnAssign)
if isinstance(stmt, ast.Assign) and any(
isinstance(target, ast.Name) and target.id == 'model_config' for target in stmt.targets
):
self.modified = True # Mark modified even if removing
continue
if (
isinstance(stmt, ast.AnnAssign)
and isinstance(stmt.target, ast.Name)
and stmt.target.id == 'model_config'
):
self.modified = True
continue
new_body.append(stmt)
elif any(isinstance(base, ast.Name) and base.id == 'BaseModel' for base in node.bases):
# Add or update model_config for BaseModel classes
added_config = False
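            # PathMetadata is emitted as a frozen (immutable) model; frozen=True
            # makes Pydantic generate __hash__, so any manually generated
            # __hash__ assignment is dropped below.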
frozen = node.name == 'PathMetadata'
has_schema = self.has_schema_field(node)
for stmt in node.body[body_start_index:]:
# Check for model_config (both Assign and AnnAssign)
is_model_config = False
if (
isinstance(stmt, ast.Assign)
and any(isinstance(target, ast.Name) and target.id == 'model_config' for target in stmt.targets)
) or (
isinstance(stmt, ast.AnnAssign)
and isinstance(stmt.target, ast.Name)
and stmt.target.id == 'model_config'
):
is_model_config = True
if is_model_config:
# Update existing model_config
updated_config = self.create_model_config(
existing_model_config_call, frozen=frozen, has_schema_field=has_schema
)
# Check if the config actually changed
if ast.dump(updated_config) != ast.dump(stmt):
new_body.append(updated_config)
self.modified = True
else:
new_body.append(stmt) # No change needed
added_config = True
elif (
isinstance(stmt, ast.Assign)
and any(isinstance(target, ast.Name) and target.id == '__hash__' for target in stmt.targets)
and frozen
):
# Skip manual __hash__ for PathMetadata
self.modified = True
continue
else:
new_body.append(stmt)
if not added_config:
# Add model_config if it wasn't present
# Insert after potential docstring
insert_pos = 1 if len(new_body) > 0 and isinstance(new_body[0], ast.Expr) else 0
new_body.insert(insert_pos, self.create_model_config(frozen=frozen, has_schema_field=has_schema))
self.modified = True
elif any(isinstance(base, ast.Name) and base.id == 'Enum' for base in node.bases):
# Uppercase Enum members
for stmt in node.body[body_start_index:]:
if isinstance(stmt, ast.Assign) and len(stmt.targets) == 1 and isinstance(stmt.targets[0], ast.Name):
target_name = stmt.targets[0].id
uppercase_name = target_name.upper()
if target_name != uppercase_name:
stmt.targets[0].id = uppercase_name
self.modified = True
new_body.append(stmt)
else:
# For other classes, just copy the rest of the body
new_body.extend(node.body[body_start_index:])
# PYTHON EXTENSION: Add resources field to GenerateActionOptions
if node.name == 'GenerateActionOptions':
self._inject_resources_field(new_body)
node.body = cast(list[ast.stmt], new_body)
return node
    def _inject_resources_field(self, body: list[ast.stmt]) -> None:
"""Inject resources field after tools field in GenerateActionOptions.
This adds the resources field to match the JS SDK implementation without
modifying the shared schema file. The JS SDK manually adds this field in
model-types.ts line 398.
"""
tools_index = -1
for i, stmt in enumerate(body):
if isinstance(stmt, ast.AnnAssign) and isinstance(stmt.target, ast.Name):
if stmt.target.id == 'tools':
tools_index = i
break
if tools_index == -1:
return # tools field not found, skip injection
# Check if resources field already exists
for stmt in body:
if isinstance(stmt, ast.AnnAssign) and isinstance(stmt.target, ast.Name):
if stmt.target.id == 'resources':
return # Already exists, don't add again
# Create the resources field: resources: list[str] | None = None
resources_field = ast.AnnAssign(
target=ast.Name(id='resources', ctx=ast.Store()),
annotation=ast.BinOp(
left=ast.Subscript(
value=ast.Name(id='list', ctx=ast.Load()), slice=ast.Name(id='str', ctx=ast.Load()), ctx=ast.Load()
),
op=ast.BitOr(),
right=ast.Constant(value=None),
),
value=ast.Constant(value=None),
simple=1,
)
# Insert after tools field
body.insert(tools_index + 1, resources_field)
self.modified = True
def fix_field_defaults(content: str) -> str:
"""Fix Field(None) and Field(None, ...) to use default=None for pyright compatibility.
Pyright doesn't recognize Field(None) as providing a default value,
but it does recognize Field(default=None).
"""
# Replace Field(None) with Field(default=None)
content = content.replace('Field(None)', 'Field(default=None)')
# Replace Field(None, other_args) with Field(default=None, other_args)
content = re.sub(r'Field\(None,', 'Field(default=None,', content)
return content
def add_schema_field_suppression(content: str) -> str:
"""Add pyrefly suppression comment for 'schema' fields.
The 'schema' field name shadows a method in Pydantic's BaseModel.
While protected_namespaces=() allows this in Pydantic, pyrefly still
reports it as a bad-override. We add a suppression comment.
"""
# Find lines that define a 'schema' field and add the suppression comment
# Pattern: " schema: ... = Field(...)"
pattern = r'^( schema: .+= Field\(.+\))$'
replacement = r' # pyrefly: ignore[bad-override] - Pydantic protected_namespaces=() allows schema field\n\1'
content = re.sub(pattern, replacement, content, flags=re.MULTILINE)
return content
def add_header(content: str) -> str:
"""Add the generated header to the content."""
header = '''# Copyright {year} Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
#
# DO NOT EDIT: Generated by `generate_schema_typing` from `genkit-schemas.json`.
"""Schema types module defining the core data models for Genkit.
This module contains Pydantic models that define the structure and validation
for various data types used throughout the Genkit framework, including messages,
actions, tools, and configuration options.
"""
'''
# Ensure there's exactly one newline between header and content
# and future import is right after the header block's closing quotes.
future_import = 'from __future__ import annotations'
compat_import_block = """
import sys
from typing import ClassVar
from genkit.core._compat import StrEnum
from pydantic.alias_generators import to_camel
"""
header_text = header.format(year=datetime.now().year)
# Remove existing future import and StrEnum import from content.
lines = content.splitlines()
filtered_lines = [
line for line in lines if line.strip() != future_import and line.strip() != 'from enum import StrEnum'
]
cleaned_content = '\n'.join(filtered_lines)
final_output = header_text + future_import + '\n' + compat_import_block + '\n\n' + cleaned_content
if not final_output.endswith('\n'):
final_output += '\n'
return final_output
def process_file(filename: str) -> None:
"""Process a Python file to remove model_config from RootModel classes.
This function reads a Python file, processes its AST to remove model_config
from RootModel classes, and writes the modified code back to the file.
Args:
filename: Path to the Python file to process.
Raises:
FileNotFoundError: If the input file does not exist.
SyntaxError: If the input file contains invalid Python syntax.
"""
path = Path(filename)
if not path.is_file():
        sys.exit(f'error: no such file: {path}')
try:
        with path.open(encoding='utf-8') as f:
source = f.read()
tree = ast.parse(source)
class_transformer = ClassTransformer()
modified_tree = class_transformer.visit(tree)
# Generate source from potentially modified AST
ast.fix_missing_locations(modified_tree)
modified_source_no_header = ast.unparse(modified_tree)
# Fix Field(None) to Field(default=None) for pyright compatibility
modified_source_no_header = fix_field_defaults(modified_source_no_header)
# Add pyrefly suppression for 'schema' fields that shadow BaseModel.schema
modified_source_no_header = add_schema_field_suppression(modified_source_no_header)
# Add header and specific imports correctly
final_source = add_header(modified_source_no_header)
# Write back only if content has changed (header or AST)
if final_source != source:
            with path.open('w', encoding='utf-8') as f:
f.write(final_source)
    except SyntaxError as e:
        sys.exit(f'error: failed to parse {path}: {e}')
def main() -> None:
"""Main entry point for the script.
This function processes command line arguments and calls the appropriate
functions to process the schema types file.
Usage:
python script.py <filename>
"""
if len(sys.argv) != 2:
        sys.exit(f'usage: {sys.argv[0]} <filename>')
process_file(sys.argv[1])
if __name__ == '__main__':
main()