from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Generic, Literal
from mcp.server.elicitation import (
CancelledElicitation,
DeclinedElicitation,
)
from pydantic import BaseModel
from pydantic.json_schema import GenerateJsonSchema, JsonSchemaValue
from pydantic_core import core_schema
from typing_extensions import TypeVar
from fastmcp.utilities.json_schema import compress_schema
from fastmcp.utilities.logging import get_logger
from fastmcp.utilities.types import get_cached_typeadapter
__all__ = [
"AcceptedElicitation",
"CancelledElicitation",
"DeclinedElicitation",
"get_elicitation_schema",
"ScalarElicitationType",
]
logger = get_logger(__name__)
T = TypeVar("T", default=Any)
class ElicitationJsonSchema(GenerateJsonSchema):
"""Custom JSON schema generator for MCP elicitation that always inlines enums.
MCP elicitation requires inline enum schemas without $ref/$defs references.
This generator ensures enums are always generated inline for compatibility.
Optionally adds enumNames for better UI display when available.
"""
def generate_inner(self, schema: core_schema.CoreSchema) -> JsonSchemaValue:
"""Override to prevent ref generation for enums."""
# For enum schemas, bypass the ref mechanism entirely
if schema["type"] == "enum":
# Directly call our custom enum_schema without going through handler
# This prevents the ref/defs mechanism from being invoked
return self.enum_schema(schema)
# For all other types, use the default implementation
return super().generate_inner(schema)
def enum_schema(self, schema: core_schema.EnumSchema) -> JsonSchemaValue:
"""Generate inline enum schema with optional enumNames for better UI.
If enum members have a _display_name_ attribute or custom __str__,
we'll include enumNames for better UI representation.
"""
# Get the base schema from parent
result = super().enum_schema(schema)
# Try to add enumNames if the enum has display-friendly names
enum_cls = schema.get("cls")
if enum_cls:
members = schema.get("members", [])
enum_names = []
has_custom_names = False
for member in members:
# Check if member has a custom display name attribute
if hasattr(member, "_display_name_"):
enum_names.append(member._display_name_)
has_custom_names = True
# Or use the member name with better formatting
else:
# Convert SNAKE_CASE to Title Case for display
display_name = member.name.replace("_", " ").title()
enum_names.append(display_name)
if display_name != member.value:
has_custom_names = True
# Only add enumNames if they differ from the values
if has_custom_names:
result["enumNames"] = enum_names
return result
# we can't use the low-level AcceptedElicitation because it only works with BaseModels
class AcceptedElicitation(BaseModel, Generic[T]):
"""Result when user accepts the elicitation."""
action: Literal["accept"] = "accept"
data: T
@dataclass
class ScalarElicitationType(Generic[T]):
value: T
def get_elicitation_schema(response_type: type[T]) -> dict[str, Any]:
"""Get the schema for an elicitation response.
Args:
response_type: The type of the response
"""
# Use custom schema generator that inlines enums for MCP compatibility
schema = get_cached_typeadapter(response_type).json_schema(
schema_generator=ElicitationJsonSchema
)
schema = compress_schema(schema)
# Validate the schema to ensure it follows MCP elicitation requirements
validate_elicitation_json_schema(schema)
return schema
def validate_elicitation_json_schema(schema: dict[str, Any]) -> None:
"""Validate that a JSON schema follows MCP elicitation requirements.
This ensures the schema is compatible with MCP elicitation requirements:
- Must be an object schema
- Must only contain primitive field types (string, number, integer, boolean)
- Must be flat (no nested objects or arrays of objects)
- Allows const fields (for Literal types) and enum fields (for Enum types)
- Only primitive types and their nullable variants are allowed
Args:
schema: The JSON schema to validate
Raises:
TypeError: If the schema doesn't meet MCP elicitation requirements
"""
ALLOWED_TYPES = {"string", "number", "integer", "boolean"}
# Check that the schema is an object
if schema.get("type") != "object":
raise TypeError(
f"Elicitation schema must be an object schema, got type '{schema.get('type')}'. "
"Elicitation schemas are limited to flat objects with primitive properties only."
)
properties = schema.get("properties", {})
for prop_name, prop_schema in properties.items():
prop_type = prop_schema.get("type")
# Handle nullable types
if isinstance(prop_type, list):
if "null" in prop_type:
prop_type = [t for t in prop_type if t != "null"]
if len(prop_type) == 1:
prop_type = prop_type[0]
elif prop_schema.get("nullable", False):
continue # Nullable with no other type is fine
# Handle const fields (Literal types)
if "const" in prop_schema:
continue # const fields are allowed regardless of type
# Handle enum fields (Enum types)
if "enum" in prop_schema:
continue # enum fields are allowed regardless of type
# Handle references to definitions (like Enum types)
if "$ref" in prop_schema:
# Get the referenced definition
ref_path = prop_schema["$ref"]
if ref_path.startswith("#/$defs/"):
def_name = ref_path[8:] # Remove "#/$defs/" prefix
ref_def = schema.get("$defs", {}).get(def_name, {})
# If the referenced definition has an enum, it's allowed
if "enum" in ref_def:
continue
# If the referenced definition has a type that's allowed, it's allowed
ref_type = ref_def.get("type")
if ref_type in ALLOWED_TYPES:
continue
# If we can't determine what the ref points to, reject it for safety
raise TypeError(
f"Elicitation schema field '{prop_name}' contains a reference '{ref_path}' "
"that could not be validated. Only references to enum types or primitive types are allowed."
)
# Handle union types (oneOf/anyOf)
if "oneOf" in prop_schema or "anyOf" in prop_schema:
union_schemas = prop_schema.get("oneOf", []) + prop_schema.get("anyOf", [])
for union_schema in union_schemas:
# Allow const and enum in unions
if "const" in union_schema or "enum" in union_schema:
continue
union_type = union_schema.get("type")
if union_type not in ALLOWED_TYPES:
raise TypeError(
f"Elicitation schema field '{prop_name}' has union type '{union_type}' which is not "
f"a primitive type. Only {ALLOWED_TYPES} are allowed in elicitation schemas."
)
continue
# Check if it's a primitive type
if prop_type not in ALLOWED_TYPES:
raise TypeError(
f"Elicitation schema field '{prop_name}' has type '{prop_type}' which is not "
f"a primitive type. Only {ALLOWED_TYPES} are allowed in elicitation schemas."
)
# Check for nested objects or arrays of objects (not allowed)
if prop_type == "object":
raise TypeError(
f"Elicitation schema field '{prop_name}' is an object, but nested objects are not allowed. "
"Elicitation schemas must be flat objects with primitive properties only."
)
if prop_type == "array":
items_schema = prop_schema.get("items", {})
if items_schema.get("type") == "object":
raise TypeError(
f"Elicitation schema field '{prop_name}' is an array of objects, but arrays of objects are not allowed. "
"Elicitation schemas must be flat objects with primitive properties only."
)