@arizeai/phoenix-mcp (Official, by Arize-ai)

transform.py (17.2 kB)
import ast
import sys
from pathlib import Path
from typing import Callable, Literal, Mapping, Sequence

# =============================================================================
# String-to-DateTime field type conversions
# =============================================================================
# Some fields in the schema are `str` but need to be converted to `datetime`
# to reconcile with the data type that the client actually returns at runtime.
#
# This mapping defines which fields in which classes should have their type
# annotations converted from string-based types to datetime during the
# dataclass-to-TypedDict transformation process.
#
# Format: {
#     "ClassName": ["field_name1", "field_name2", ...],
#     ...
# }
#
# Supported transformations:
#   - str → datetime
#   - Optional[str] → Optional[datetime]
#   - str = "default" → NotRequired[datetime] (with default removed)
#   - Optional[str] = None → NotRequired[datetime] (with Optional and default removed)
#
# Example:
#   Before: created_at: str
#   After:  created_at: datetime
#
#   Before: updated_at: Optional[str] = None
#   After:  updated_at: NotRequired[datetime]
STR_TO_DATETIME_ALTERATIONS = {
    "DatasetVersion": ["created_at"],
}


# =============================================================================
# AST Transformer to convert dataclass definitions to TypedDict definitions.
# =============================================================================
class ConvertDataClassToTypedDict(ast.NodeTransformer):
    def __init__(self):
        self.current_class_name = None

    def visit_ImportFrom(self, node: ast.ImportFrom) -> ast.AST:
        """
        Replace the dataclasses import with a TypedDict import from typing.
        """
        if node.module == "dataclasses":
            return ast.ImportFrom(
                module="typing",
                names=[ast.alias(name="TypedDict", asname=None)],
                level=0,
            )
        return node

    def visit_ClassDef(self, node: ast.ClassDef) -> ast.AST:
        """
        Convert a class definition into a TypedDict definition.

        Also reorders the "type" field (if present) to the top for readability.
        """
        # Remember the current class name before descending into fields
        old_class_name = self.current_class_name
        self.current_class_name = node.name

        # Visit and transform all statements in the class body.
        new_body = [self.visit(child) for child in node.body]

        # Look for a field named "type" defined as a Literal and move it to the front.
        for index, stmt in enumerate(new_body):
            if (
                isinstance(stmt, ast.AnnAssign)
                and isinstance(stmt.target, ast.Name)
                and stmt.target.id == "type"
                and isinstance(stmt.annotation, ast.Subscript)
                and isinstance(stmt.annotation.value, ast.Name)
                and stmt.annotation.value.id == "Literal"
                and isinstance(stmt.annotation.slice, ast.Constant)
            ):
                new_body = [new_body[index]] + new_body[:index] + new_body[index + 1 :]
                break

        # Restore the previous class name
        self.current_class_name = old_class_name

        # Redefine the class so that it inherits from TypedDict.
        return ast.ClassDef(
            name=node.name,
            bases=[ast.Name(id="TypedDict", ctx=ast.Load())],
            keywords=node.keywords,
            body=new_body,
            decorator_list=[],
        )

    def visit_AnnAssign(self, node: ast.AnnAssign) -> ast.AST:
        """
        Process annotated assignments:
          - Convert str fields to datetime based on STR_TO_DATETIME_ALTERATIONS.
          - Rename fields ending with "_" (like schema_ or json_) by stripping the underscore.
          - Convert default values on fields (when present) to a NotRequired[...] annotation.
          - Change `type: str = "xyz"` into `type: Literal["xyz"]`.
          - If a field is Optional[...] with a default value, remove the Optional.
        """
        # Convert str fields to datetime based on STR_TO_DATETIME_ALTERATIONS
        if (
            self.current_class_name in STR_TO_DATETIME_ALTERATIONS
            and isinstance(node.target, ast.Name)
            and node.target.id in STR_TO_DATETIME_ALTERATIONS[self.current_class_name]
        ):
            # Handle direct str annotation
            if isinstance(node.annotation, ast.Name) and node.annotation.id == "str":
                node = ast.AnnAssign(
                    target=node.target,
                    annotation=ast.Name(id="datetime", ctx=ast.Load()),
                    value=node.value,
                    simple=node.simple,
                )
            # Handle Optional[str] annotation
            elif (
                isinstance(node.annotation, ast.Subscript)
                and isinstance(node.annotation.value, ast.Name)
                and node.annotation.value.id == "Optional"
                and isinstance(node.annotation.slice, ast.Name)
                and node.annotation.slice.id == "str"
            ):
                node = ast.AnnAssign(
                    target=node.target,
                    annotation=ast.Subscript(
                        value=ast.Name(id="Optional", ctx=ast.Load()),
                        slice=ast.Name(id="datetime", ctx=ast.Load()),
                        ctx=ast.Load(),
                    ),
                    value=node.value,
                    simple=node.simple,
                )

        # Rename fields ending with "_" to remove the trailing underscore.
        if isinstance(node.target, ast.Name) and node.target.id in ("schema_", "json_"):
            new_target = ast.Name(id=node.target.id.rstrip("_"), ctx=ast.Store())
            node = ast.AnnAssign(
                target=new_target,
                annotation=node.annotation,
                value=node.value,
                simple=node.simple,
            )

        # If there is a default value, perform further transformations.
        if isinstance(node.value, ast.Constant):
            # Convert `type: str = "xyz"` into `type: Literal["xyz"]`
            if (
                isinstance(node.target, ast.Name)
                and node.target.id == "type"
                and isinstance(node.annotation, ast.Name)
                and node.annotation.id == "str"
            ):
                return ast.AnnAssign(
                    target=node.target,
                    annotation=ast.Subscript(
                        value=ast.Name(id="Literal", ctx=ast.Load()),
                        slice=node.value,
                        ctx=ast.Load(),
                    ),
                    value=None,  # Remove default value
                    simple=node.simple,
                )
            # Convert an Optional annotation (with a default) to the inner type.
            if (
                isinstance(node.annotation, ast.Subscript)
                and isinstance(node.annotation.value, ast.Name)
                and node.annotation.value.id == "Optional"
            ):
                node = ast.AnnAssign(
                    target=node.target,
                    annotation=node.annotation.slice,
                    value=node.value,
                    simple=node.simple,
                )
            # Remove the default value and wrap the annotation in NotRequired.
            return ast.AnnAssign(
                target=node.target,
                annotation=ast.Subscript(
                    value=ast.Name(id="NotRequired", ctx=ast.Load()),
                    slice=node.annotation,
                    ctx=ast.Load(),
                ),
                value=None,  # Default value is removed
                simple=node.simple,
            )
        return node


def transform_dataclass(code: str) -> ast.AST:
    """
    Parse the provided code, insert an import for NotRequired, and transform
    dataclass definitions into TypedDict definitions.

    Args:
        code: A string of Python source code.

    Returns:
        The transformed AST.
    """
    parsed_ast: ast.Module = ast.parse(code)

    # Insert the imports for NotRequired and datetime before the first class.
    for index, node in enumerate(parsed_ast.body):
        if isinstance(node, ast.ClassDef):
            import_notrequired = ast.ImportFrom(
                module="typing_extensions",
                names=[ast.alias(name="NotRequired", asname=None)],
                level=0,
            )
            import_datetime = ast.ImportFrom(
                module="datetime",
                names=[ast.alias(name="datetime", asname=None)],
                level=0,
            )
            parsed_ast.body.insert(index, import_notrequired)
            parsed_ast.body.insert(index + 1, import_datetime)
            break

    # Remove top-level Union type definitions
    parsed_ast.body = [
        node
        for node in parsed_ast.body
        if not (
            isinstance(node, ast.Assign)
            and isinstance(node.value, ast.Subscript)
            and isinstance(node.value.value, ast.Name)
            and node.value.value.id == "Union"
        )
    ]

    transformer = ConvertDataClassToTypedDict()
    transformed_ast = transformer.visit(parsed_ast)
    return transformed_ast


# =============================================================================
# Functions to adjust class definitions by removing inherited fields.
# =============================================================================

# Mapping from a class name to a list of its parent class names.
PARENTS: Mapping[str, Sequence[str]] = {
    "Prompt": ["PromptData"],
    "PromptVersion": ["PromptVersionData"],
    "SpanAnnotation": ["SpanAnnotationData"],
    "LocalUser": ["LocalUserData"],
    "OAuth2User": ["OAuth2UserData"],
}


def get_ancestor_fields(
    class_name: str,
    class_nodes: Mapping[str, ast.ClassDef],
    parent_map: Mapping[str, Sequence[str]] = PARENTS,
) -> set[str]:
    """
    Recursively collects the field names defined in all ancestor classes.

    Args:
        class_name: The name of the class to inspect.
        class_nodes: Mapping of class names to their AST ClassDef nodes.
        parent_map: Mapping from a class name to a list of its parent class names.

    Returns:
        A set of field names from all ancestors.
    """
    if class_name not in parent_map:
        return set()
    fields: set[str] = set()
    for parent_name in parent_map[class_name]:
        parent_node: ast.ClassDef = class_nodes[parent_name]
        for stmt in parent_node.body:
            if isinstance(stmt, ast.AnnAssign) and isinstance(stmt.target, ast.Name):
                fields.add(stmt.target.id)
        # Recursively add fields from higher ancestors.
        fields |= get_ancestor_fields(parent_name, class_nodes, parent_map)
    return fields


def remove_inherited_fields(
    class_nodes: Mapping[str, ast.ClassDef],
    parent_map: Mapping[str, Sequence[str]] = PARENTS,
) -> Mapping[str, ast.ClassDef]:
    """
    For each class that inherits from others, remove any field that is already
    defined in its ancestors to improve readability.

    Args:
        class_nodes: Mapping from class name to its AST ClassDef node.
        parent_map: Mapping from class name to a list of its parent class names.

    Returns:
        A new mapping from class name to a modified AST ClassDef node.
    """
    new_class_nodes: dict[str, ast.ClassDef] = {}
    for class_name, node in class_nodes.items():
        # If the class has no parents, leave it as-is.
        if class_name not in parent_map:
            new_class_nodes[class_name] = node
            continue

        # Create explicit bases for the class from its parent names.
        bases: list[ast.expr] = [
            ast.Name(id=parent, ctx=ast.Load()) for parent in parent_map[class_name]
        ]

        # Collect the field names defined in the class.
        child_field_names: set[str] = {
            stmt.target.id
            for stmt in node.body
            if isinstance(stmt, ast.AnnAssign) and isinstance(stmt.target, ast.Name)
        }
        # Ensure every statement in the body is an AnnAssign.
        assert len(child_field_names) == len(node.body), "Every field must be an AnnAssign"

        # Collect all ancestor field names.
        ancestor_field_names: set[str] = get_ancestor_fields(class_name, class_nodes, parent_map)
        assert ancestor_field_names < child_field_names, (
            "Ancestor fields must be a subset of child fields"
        )

        # Remove any inherited field from the class body.
        inherited_fields: set[str] = ancestor_field_names.intersection(child_field_names)
        new_body: list[ast.stmt] = [
            stmt
            for stmt in node.body
            if not (
                isinstance(stmt, ast.AnnAssign)
                and isinstance(stmt.target, ast.Name)
                and stmt.target.id in inherited_fields
            )
        ]
        new_class_nodes[class_name] = ast.ClassDef(
            name=node.name,
            bases=bases,
            keywords=node.keywords,
            body=new_body,
            decorator_list=node.decorator_list,
        )
    return new_class_nodes


def topologically_sort_classes(
    class_nodes: Mapping[str, ast.ClassDef],
    parent_map: Mapping[str, Sequence[str]] = PARENTS,
) -> list[ast.ClassDef]:
    """
    Sort class definitions so that each parent class appears before its children.

    Args:
        class_nodes: Mapping from class name to its AST ClassDef node.
        parent_map: Mapping from class name to a list of its parent class names.

    Returns:
        A list of AST ClassDef nodes in topological order.

    Raises:
        ValueError: If a cycle is detected in the inheritance hierarchy.
    """
    sorted_classes: list[ast.ClassDef] = []
    visit_state: dict[str, Literal["visiting", "visited"]] = {}

    def visit(node: ast.ClassDef) -> None:
        class_name: str = node.name
        if class_name in visit_state:
            if visit_state[class_name] == "visiting":
                raise ValueError(f"Cycle detected at class: {class_name}")
            return  # Already visited
        visit_state[class_name] = "visiting"
        for parent_name in parent_map.get(class_name, []):
            visit(class_nodes[parent_name])
        visit_state[class_name] = "visited"
        sorted_classes.append(node)

    for node in class_nodes.values():
        visit(node)
    return sorted_classes


# =============================================================================
# File rewriting logic.
# =============================================================================
def rewrite_file(
    directory: Path,
    input_filename: str,
    output_filename: str,
    transform: Callable[[str], ast.AST],
) -> None:
    """
    Reads a Python file, applies the AST transformation and class adjustments,
    then writes the updated code to an output file.

    Args:
        directory: The directory containing the input file.
        input_filename: The name of the input file.
        output_filename: The name of the output file.
        transform: A function that converts a code string to an AST.
    """
    file_path: Path = directory / input_filename
    with open(file_path, "r") as file:
        code: str = file.read()

    transformed_ast: ast.AST = transform(code)
    assert isinstance(transformed_ast, ast.Module), "The transformed AST must be a module."

    # Extract all class definitions from the AST.
    class_nodes: dict[str, ast.ClassDef] = {
        node.name: node for node in transformed_ast.body if isinstance(node, ast.ClassDef)
    }

    # Remove inherited fields from subclasses.
    cleaned_classes: Mapping[str, ast.ClassDef] = remove_inherited_fields(class_nodes)

    # Topologically sort classes so that parent classes come first.
    sorted_classes: list[ast.ClassDef] = topologically_sort_classes(cleaned_classes)

    # Preserve non-class statements and then append the sorted class definitions.
    non_class_statements: list[ast.stmt] = [
        stmt for stmt in transformed_ast.body if not isinstance(stmt, ast.ClassDef)
    ]
    new_body: list[ast.stmt] = non_class_statements + sorted_classes
    new_module: ast.Module = ast.Module(body=new_body, type_ignores=[])
    new_module = ast.fix_missing_locations(new_module)
    output_code: str = ast.unparse(new_module)

    with open(directory / output_filename, "w") as file:
        file.write('"""Do not edit"""\n\n')
        file.write(output_code)


# =============================================================================
# Main entry point.
# =============================================================================
if __name__ == "__main__":
    if len(sys.argv) != 2:
        print("Usage: python transform.py <directory>")
        sys.exit(1)
    directory: Path = Path(sys.argv[1])
    rewrite_file(
        directory,
        ".dataclass.py",
        "__init__.py",
        transform_dataclass,
    )
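For orientation, here is a minimal usage sketch (not part of the repository): it runs transform_dataclass on a small hypothetical dataclass module and prints the generated code. The DatasetVersion input below is an assumption chosen to exercise the str-to-datetime, Optional, and default-value rules described in the comments above; demo.py and the import path are likewise illustrative.

# demo.py -- hypothetical usage sketch; assumes transform.py is importable
import ast

from transform import transform_dataclass

SOURCE = '''
from dataclasses import dataclass
from typing import Optional


@dataclass
class DatasetVersion:
    id: str
    created_at: str
    description: Optional[str] = None
'''

# Parse and transform the module, then print the resulting TypedDict definition.
module = transform_dataclass(SOURCE)
print(ast.unparse(ast.fix_missing_locations(module)))

# Expected output (roughly):
#
#   from typing import TypedDict
#   from typing import Optional
#   from typing_extensions import NotRequired
#   from datetime import datetime
#
#   class DatasetVersion(TypedDict):
#       id: str
#       created_at: datetime
#       description: NotRequired[str]

rewrite_file applies the same transformation end to end: it also strips inherited fields, topologically sorts the classes so parents come first, and writes the result to __init__.py behind a "Do not edit" header.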
