sanitize_schema_typing.py•13.4 kB
#!/usr/bin/env python3
#
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
"""Standalone convenience script used to massage the typing.py.
The `genkit/typing.py` file is generated by datamodel-codegen. However, since
the tool doesn't currently provide options to generate exactly the kind of code
we need, we use this convenience script to parse the Python source code, walk
the AST, modify it to include the bits we need and regenerate the code for
eventual use within our codebase.
Transformations applied:
- We remove the model_config attribute from classes that ineherit from
  RootModel.
- We add the `populate_by_name=True` parameter to ensure serialization uses
  camelCase for attributes since the JS implementation uses camelCase and Python
  uses snake_case. The codegen pass is configured to generate snake_case for a
  Pythonic API but serialize to camelCase in order to be compatible with
  runtimes.
- We add a license header
- We add a header indicating that this file has been generated by a code
  generator pass.
- We add the ability to use forward references.
- Add docstrings if missing.
"""
import ast
import sys
from _ast import AST
from datetime import datetime
from pathlib import Path
from typing import Type, cast
class ClassTransformer(ast.NodeTransformer):
    """AST transformer that modifies class definitions."""
    def __init__(self) -> None:
        """Initialize the ClassTransformer."""
        self.modified = False
    def is_rootmodel_class(self, node: ast.ClassDef) -> bool:
        """Check if a class definition is a RootModel class."""
        for base in node.bases:
            if isinstance(base, ast.Name) and base.id == 'RootModel':
                return True
            elif isinstance(base, ast.Subscript):
                value = base.value
                if isinstance(value, ast.Name) and value.id == 'RootModel':
                    return True
        return False
    def create_model_config(self, existing_config: ast.Call | None = None) -> ast.Assign:
        """Create or update a model_config assignment.
        Ensures populate_by_name=True and extra='forbid', keeping other existing
        settings.
        """
        keywords = []
        found_populate = False
        # Preserve existing keywords if present, but override 'extra'
        if existing_config:
            for kw in existing_config.keywords:
                if kw.arg == 'populate_by_name':
                    # Ensure it's set to True
                    keywords.append(
                        ast.keyword(
                            arg='populate_by_name',
                            value=ast.Constant(value=True),
                        )
                    )
                    found_populate = True
                elif kw.arg == 'extra':
                    # Skip the existing 'extra', we will enforce 'forbid'
                    continue
                else:
                    keywords.append(kw)  # Keep other existing settings
        # Always add extra='forbid'
        keywords.append(ast.keyword(arg='extra', value=ast.Constant(value='forbid')))
        # Add populate_by_name=True if it wasn't found
        if not found_populate:
            keywords.append(ast.keyword(arg='populate_by_name', value=ast.Constant(value=True)))
        # Sort keywords for consistent output (optional but good practice)
        keywords.sort(key=lambda kw: kw.arg or '')
        return ast.Assign(
            targets=[ast.Name(id='model_config')],
            value=ast.Call(func=ast.Name(id='ConfigDict'), args=[], keywords=keywords),
        )
    def has_model_config(self, node: ast.ClassDef) -> ast.Assign | None:
        """Check if class already has model_config assignment and return it."""
        for item in node.body:
            if isinstance(item, ast.Assign):
                targets = item.targets
                if len(targets) == 1 and isinstance(targets[0], ast.Name):
                    if targets[0].id == 'model_config':
                        return item
        return None
    def visit_ClassDef(self, _node: ast.ClassDef) -> ast.ClassDef:  # noqa: N802
        """Visit and transform a class definition node.
        Args:
            node: The ClassDef AST node to transform.
        Returns:
            The transformed ClassDef node.
        """
        # First apply base class transformations recursively
        node = super().generic_visit(_node)
        new_body: list[ast.stmt | ast.Constant | ast.Assign] = []
        # Handle Docstrings
        if not node.body or not isinstance(node.body[0], ast.Expr) or not isinstance(node.body[0].value, ast.Constant):
            # Generate a more descriptive docstring based on class type
            if self.is_rootmodel_class(node):
                docstring = f'Root model for {node.name.lower().replace("_", " ")}.'
            elif any(isinstance(base, ast.Name) and base.id == 'BaseModel' for base in node.bases):
                docstring = f'Model for {node.name.lower().replace("_", " ")} data.'
            elif any(isinstance(base, ast.Name) and base.id == 'Enum' for base in node.bases):
                n = node.name.lower().replace('_', ' ')
                docstring = f'Enumeration of {n} values.'
            else:
                docstring = f'{node.name} data type class.'
            new_body.append(ast.Expr(value=ast.Constant(value=docstring)))
            self.modified = True
        else:  # Ensure existing docstring is kept
            new_body.append(node.body[0])
        # Handle model_config for BaseModel and RootModel
        existing_model_config_assign = self.has_model_config(node)
        existing_model_config_call = None
        if existing_model_config_assign and isinstance(existing_model_config_assign.value, ast.Call):
            existing_model_config_call = existing_model_config_assign.value
        # Determine start index for iterating original body (skip docstring)
        body_start_index = (
            1 if (node.body and isinstance(node.body[0], ast.Expr) and isinstance(node.body[0].value, ast.Str)) else 0
        )
        if self.is_rootmodel_class(node):
            # Remove model_config from RootModel classes
            for stmt in node.body[body_start_index:]:
                # Skip existing model_config
                if isinstance(stmt, ast.Assign) and any(
                    isinstance(target, ast.Name) and target.id == 'model_config' for target in stmt.targets
                ):
                    self.modified = True  # Mark modified even if removing
                    continue
                new_body.append(stmt)
        elif any(isinstance(base, ast.Name) and base.id == 'BaseModel' for base in node.bases):
            # Add or update model_config for BaseModel classes
            added_config = False
            for stmt in node.body[body_start_index:]:
                if isinstance(stmt, ast.Assign) and any(
                    isinstance(target, ast.Name) and target.id == 'model_config' for target in stmt.targets
                ):
                    # Update existing model_config
                    updated_config = self.create_model_config(existing_model_config_call)
                    # Check if the config actually changed
                    if ast.dump(updated_config) != ast.dump(stmt):
                        new_body.append(updated_config)
                        self.modified = True
                    else:
                        new_body.append(stmt)  # No change needed
                    added_config = True
                else:
                    new_body.append(stmt)
            if not added_config:
                # Add model_config if it wasn't present
                # Insert after potential docstring
                insert_pos = 1 if len(new_body) > 0 and isinstance(new_body[0], ast.Expr) else 0
                new_body.insert(insert_pos, self.create_model_config())
                self.modified = True
        elif any(isinstance(base, ast.Name) and base.id == 'Enum' for base in node.bases):
            # Uppercase Enum members
            for stmt in node.body[body_start_index:]:
                if isinstance(stmt, ast.Assign) and len(stmt.targets) == 1 and isinstance(stmt.targets[0], ast.Name):
                    target_name = stmt.targets[0].id
                    uppercase_name = target_name.upper()
                    if target_name != uppercase_name:
                        stmt.targets[0].id = uppercase_name
                        self.modified = True
                new_body.append(stmt)
        else:
            # For other classes, just copy the rest of the body
            new_body.extend(node.body[body_start_index:])
        node.body = cast(list[ast.stmt], new_body)
        return node
def add_header(content: str) -> str:
    """Add the generated header to the content."""
    header = '''# Copyright {year} Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
#
# DO NOT EDIT: Generated by `generate_schema_typing` from `genkit-schemas.json`.
"""Schema types module defining the core data models for Genkit.
This module contains Pydantic models that define the structure and validation
for various data types used throughout the Genkit framework, including messages,
actions, tools, and configuration options.
"""
'''
    # Ensure there's exactly one newline between header and content
    # and future import is right after the header block's closing quotes.
    future_import = 'from __future__ import annotations'
    str_enum_block = """
import sys # noqa
if sys.version_info < (3, 11):  # noqa
    from strenum import StrEnum  # noqa
else: # noqa
    from enum import StrEnum  # noqa
"""
    header_text = header.format(year=datetime.now().year)
    # Remove existing future import and StrEnum import from content.
    lines = content.splitlines()
    filtered_lines = [
        line for line in lines if line.strip() != future_import and line.strip() != 'from enum import StrEnum'
    ]
    cleaned_content = '\n'.join(filtered_lines)
    final_output = header_text + future_import + '\n' + str_enum_block + '\n\n' + cleaned_content
    if not final_output.endswith('\n'):
        final_output += '\n'
    return final_output
def process_file(filename: str) -> None:
    """Process a Python file to remove model_config from RootModel classes.
    This function reads a Python file, processes its AST to remove model_config
    from RootModel classes, and writes the modified code back to the file.
    Args:
        filename: Path to the Python file to process.
    Raises:
        FileNotFoundError: If the input file does not exist.
        SyntaxError: If the input file contains invalid Python syntax.
    """
    path = Path(filename)
    if not path.is_file():
        print(f'Error: File not found: {filename}')
        sys.exit(1)
    try:
        with open(path, encoding='utf-8') as f:
            source = f.read()
        tree = ast.parse(source)
        class_transformer = ClassTransformer()
        modified_tree = class_transformer.visit(tree)
        # Generate source from potentially modified AST
        ast.fix_missing_locations(modified_tree)
        modified_source_no_header = ast.unparse(modified_tree)
        # Add header and specific imports correctly
        final_source = add_header(modified_source_no_header)
        # Write back only if content has changed (header or AST)
        if final_source != source:
            with open(path, 'w', encoding='utf-8') as f:
                f.write(final_source)
            print(f'Successfully processed and updated {filename}')
        else:
            print(f'No changes needed for {filename}')
    except SyntaxError as e:
        print(f'Error: Invalid Python syntax in {filename}: {e}')
        sys.exit(1)
def main() -> None:
    """Main entry point for the script.
    This function processes command line arguments and calls the appropriate
    functions to process the schema types file.
    Usage:
        python script.py <filename>
    """
    if len(sys.argv) != 2:
        print('Usage: python script.py <filename>')
        sys.exit(1)
    process_file(sys.argv[1])
if __name__ == '__main__':
    main()