"""Validation middleware for input validation and sanitization."""
from typing import Any, Callable, Dict, List
import mcp.types as types
from .base import BaseMiddleware
from ..utils.validation import validate_json_schema, sanitize_input
from ..utils.logging import get_logger
# Module-level logger named after this module. In the code visible here,
# ValidationMiddleware logs through its own ``self._logger`` rather than this one.
logger = get_logger(__name__)
class ValidationMiddleware(BaseMiddleware):
    """Middleware for input validation and sanitization.

    Every tool call, resource read and prompt request is checked before it is
    forwarded to ``next_handler``:

    * names must be non-blank and at most ``_MAX_NAME_LENGTH`` characters;
    * string arguments (at any nesting depth inside dicts and lists) must not
      exceed ``max_input_length`` and must not contain null bytes;
    * when ``sanitize_inputs`` is enabled, every string argument is passed
      through the project's ``sanitize_input`` helper.

    Validation failures are logged and re-raised. Errors raised by
    ``next_handler`` itself propagate untouched and are not reported as
    validation failures.
    """

    # Upper bound on tool/prompt name length (a reasonable limit for names).
    _MAX_NAME_LENGTH = 100

    # URI schemes that trigger a warning. They are logged, not rejected.
    _DANGEROUS_SCHEMES = ("javascript:", "data:", "vbscript:")

    def __init__(self, sanitize_inputs: bool = True, max_input_length: int = 10000) -> None:
        """Initialize validation middleware.

        Args:
            sanitize_inputs: Whether to sanitize string inputs
            max_input_length: Maximum allowed input length
        """
        super().__init__("validation")
        self.sanitize_inputs = sanitize_inputs
        self.max_input_length = max_input_length
        self._logger = get_logger(f"{__name__}.ValidationMiddleware")

    def _sanitize_value(self, value: Any) -> Any:
        """Sanitize a single argument value, recursing into containers.

        Args:
            value: Any argument value

        Returns:
            Strings are passed through ``sanitize_input``; dicts and lists are
            rebuilt with their contents sanitized; every other value is
            returned unchanged.
        """
        if isinstance(value, str):
            return sanitize_input(value, self.max_input_length)
        if isinstance(value, dict):
            return self._sanitize_arguments(value)
        if isinstance(value, list):
            # Recurse so dicts/lists nested inside a list are sanitized too.
            return [self._sanitize_value(item) for item in value]
        return value

    def _sanitize_arguments(self, arguments: Dict[str, Any]) -> Dict[str, Any]:
        """Sanitize arguments dictionary.

        Args:
            arguments: Arguments to sanitize

        Returns:
            Sanitized arguments. When sanitization is disabled the input
            dictionary itself is returned; otherwise a new dictionary is built.
        """
        if not self.sanitize_inputs:
            return arguments
        return {key: self._sanitize_value(value) for key, value in arguments.items()}

    def _validate_value(self, path: str, value: Any) -> None:
        """Validate one argument value, recursing into dicts and lists.

        Args:
            path: Name used in error messages (e.g. ``key``, ``key.sub``, ``key[0]``)
            value: Value to validate

        Raises:
            ValueError: If a string is too long or contains null bytes
        """
        if isinstance(value, str):
            # Check for extremely large inputs
            if len(value) > self.max_input_length:
                raise ValueError(f"Argument '{path}' exceeds maximum length of {self.max_input_length}")
            # Check for null bytes and other potentially harmful content
            if "\x00" in value:
                raise ValueError(f"Argument '{path}' contains null bytes")
        elif isinstance(value, dict):
            for key, item in value.items():
                self._validate_value(f"{path}.{key}", item)
        elif isinstance(value, list):
            for index, item in enumerate(value):
                self._validate_value(f"{path}[{index}]", item)

    def _validate_argument_types(self, arguments: Dict[str, Any]) -> None:
        """Basic argument type validation.

        String values are checked at every nesting level so that values
        hidden inside nested dicts or lists cannot bypass the checks.

        Args:
            arguments: Arguments to validate

        Raises:
            ValueError: If validation fails
        """
        for key, value in arguments.items():
            self._validate_value(f"{key}", value)

    def _validate_name(self, name: str, kind: str) -> None:
        """Validate a tool or prompt name.

        Args:
            name: Name to validate
            kind: Label used in the error message (``"Tool"`` or ``"Prompt"``)

        Raises:
            ValueError: If the name is empty, blank, or too long
        """
        if not name or not name.strip():
            raise ValueError(f"{kind} name cannot be empty")
        if len(name) > self._MAX_NAME_LENGTH:
            raise ValueError(f"{kind} name is too long")

    def _validate_uri(self, uri: Any) -> None:
        """Validate resource URI.

        Args:
            uri: URI to validate (can be string or AnyUrl)

        Raises:
            ValueError: If the URI is empty, too long, or contains null bytes
        """
        # Convert URI to string if it's not already
        uri_str = str(uri) if uri else ""
        if not uri_str.strip():
            raise ValueError("URI cannot be empty")
        if len(uri_str) > self.max_input_length:
            raise ValueError(f"URI exceeds maximum length of {self.max_input_length}")
        # Basic URI validation
        if "\x00" in uri_str:
            raise ValueError("URI contains null bytes")
        # Check for potentially dangerous schemes. Leading whitespace is
        # ignored so that " javascript:..." is still detected. This only
        # warns; such URIs are not rejected.
        if uri_str.lstrip().lower().startswith(self._DANGEROUS_SCHEMES):
            self._logger.warning(f"Potentially dangerous URI scheme detected: {uri_str}")

    async def process_tool_call(
        self, name: str, arguments: Dict[str, Any], next_handler: Callable[[str, Dict[str, Any]], Any]
    ) -> List[types.ContentBlock]:
        """Process tool call with validation.

        Args:
            name: Tool name
            arguments: Tool arguments
            next_handler: Next handler in the chain; awaited with the
                (possibly sanitized) arguments

        Returns:
            Whatever ``next_handler`` returns

        Raises:
            ValueError: If the name or an argument fails validation
        """
        # Validate tool name
        self._validate_name(name, "Tool")
        # Validate and sanitize arguments. The handler call stays outside the
        # try block so its failures are not logged as validation errors.
        try:
            self._validate_argument_types(arguments)
            sanitized_args = self._sanitize_arguments(arguments)
        except Exception as e:
            self._logger.error(f"Validation failed for tool {name}: {str(e)}")
            raise
        self._logger.debug(f"Validated tool call: {name}")
        return await next_handler(name, sanitized_args)

    async def process_resource_read(self, uri: Any, next_handler: Callable[[Any], Any]) -> str:
        """Process resource read with validation.

        Args:
            uri: Resource URI (string or AnyUrl); passed on unchanged
            next_handler: Next handler in the chain

        Returns:
            Whatever ``next_handler`` returns

        Raises:
            ValueError: If the URI fails validation
        """
        try:
            self._validate_uri(uri)
        except Exception as e:
            self._logger.error(f"Validation failed for resource {uri}: {str(e)}")
            raise
        self._logger.debug(f"Validated resource URI: {uri}")
        return await next_handler(uri)

    async def process_prompt_get(
        self, name: str, arguments: Dict[str, str] | None, next_handler: Callable[[str, Dict[str, str] | None], Any]
    ) -> types.GetPromptResult:
        """Process prompt get with validation.

        Args:
            name: Prompt name
            arguments: Prompt arguments; ``None`` or an empty dict is passed
                through to ``next_handler`` unchanged
            next_handler: Next handler in the chain

        Returns:
            Whatever ``next_handler`` returns

        Raises:
            ValueError: If the name or an argument fails validation
        """
        # Validate prompt name
        self._validate_name(name, "Prompt")
        sanitized_args = arguments
        try:
            # Validate and sanitize arguments if present
            if arguments:
                # Widen to Any for the shared helpers, then convert back to str
                args_any: Dict[str, Any] = dict(arguments)
                self._validate_argument_types(args_any)
                sanitized_any = self._sanitize_arguments(args_any)
                sanitized_args = {k: str(v) for k, v in sanitized_any.items()}
        except Exception as e:
            self._logger.error(f"Validation failed for prompt {name}: {str(e)}")
            raise
        self._logger.debug(f"Validated prompt: {name}")
        return await next_handler(name, sanitized_args)