Skip to main content
Glama
juanqui
by juanqui
middleware.py (13 kB)
"""FastAPI middleware for CORS, error handling, and other cross-cutting concerns."""

import logging
import math
import time
from typing import Any, Dict, List, Optional

from fastapi import FastAPI, HTTPException, Request, Response
from fastapi.encoders import jsonable_encoder
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from starlette.exceptions import HTTPException as StarletteHTTPException
from starlette.middleware.base import BaseHTTPMiddleware

from ..config import ServerConfig
from .models.web_models import ErrorResponse

logger = logging.getLogger(__name__)

# Response headers set by this module that browser JavaScript should be able to
# read on cross-origin responses. ``Access-Control-Expose-Headers`` does not
# support prefix wildcards such as ``X-RateLimit-*``, so each name is listed.
_EXPOSED_HEADERS = [
    "X-Process-Time",
    "X-RateLimit-Limit",
    "X-RateLimit-Remaining",
    "X-RateLimit-Reset",
    "Retry-After",
]

# Generic message returned for unexpected errors; internal exception text is
# only logged, never sent to the client.
_GENERIC_ERROR_MESSAGE = "Please try again or contact support if the problem persists"


def _to_jsonable(value: Any) -> Any:
    """Convert ``value`` into something ``JSONResponse`` can serialize.

    Falls back to ``str(value)`` when ``jsonable_encoder`` cannot handle the
    object, so rendering a validation error can never itself raise.
    """
    try:
        return jsonable_encoder(value)
    except Exception:  # best-effort rendering of arbitrary user input
        return str(value)


def _format_validation_errors(exc: RequestValidationError, include_input: bool = False) -> List[Dict[str, Any]]:
    """Flatten a ``RequestValidationError`` into JSON-friendly dicts.

    Args:
        exc: Validation error raised by FastAPI
        include_input: Whether to echo the offending input value back

    Returns:
        One dict per error with ``field``, ``message`` and ``type`` keys, plus
        ``input`` when ``include_input`` is true
    """
    details = []
    for error in exc.errors():
        item = {
            "field": " -> ".join(str(loc) for loc in error["loc"]),
            "message": error["msg"],
            "type": error["type"],
        }
        if include_input:
            item["input"] = _to_jsonable(error.get("input"))
        details.append(item)
    return details


class RequestLoggingMiddleware(BaseHTTPMiddleware):
    """Middleware for logging HTTP requests and responses."""

    async def dispatch(self, request: Request, call_next) -> Response:
        """Log the start and outcome of a request and record its duration.

        Successful responses get an ``X-Process-Time`` header holding the
        duration in seconds. Exceptions are logged and re-raised unchanged.

        Args:
            request: HTTP request
            call_next: Next middleware/handler

        Returns:
            HTTP response
        """
        method, path = request.method, request.url.path
        # perf_counter is monotonic, so durations are immune to wall-clock jumps.
        start_time = time.perf_counter()
        logger.info("HTTP %s %s - Started", method, path)

        try:
            response = await call_next(request)
        except Exception as e:
            duration = time.perf_counter() - start_time
            logger.error("HTTP %s %s - Error: %s - Duration: %.3fs", method, path, e, duration)
            raise

        duration = time.perf_counter() - start_time
        logger.info("HTTP %s %s - Status: %s - Duration: %.3fs", method, path, response.status_code, duration)
        response.headers["X-Process-Time"] = str(duration)
        return response


class ErrorHandlingMiddleware(BaseHTTPMiddleware):
    """Middleware for handling exceptions and returning standardized error responses."""

    async def dispatch(self, request: Request, call_next) -> Response:
        """Turn exceptions escaping the inner app into standardized JSON errors.

        Args:
            request: HTTP request
            call_next: Next middleware/handler

        Returns:
            The inner response, a 422 response for validation errors, or a
            500 response for any other unexpected exception

        Raises:
            HTTPException: Re-raised untouched so FastAPI's handlers own it
        """
        try:
            return await call_next(request)
        except HTTPException:
            # HTTP exceptions are the job of FastAPI's exception handlers.
            raise
        except RequestValidationError as e:
            logger.warning("Validation error on %s %s: %s", request.method, request.url.path, e)
            error_response = ErrorResponse(
                error="Validation failed",
                error_code="VALIDATION_ERROR",
                details={"validation_errors": _format_validation_errors(e)},
            )
            return JSONResponse(status_code=422, content=error_response.model_dump())
        except Exception:
            # Full details (including traceback) go to the log only: exception
            # messages may contain paths or other internals that must not be
            # exposed to clients.
            logger.exception("Unexpected error on %s %s", request.method, request.url.path)
            error_response = ErrorResponse(
                error="Internal server error",
                error_code="INTERNAL_ERROR",
                details={"message": _GENERIC_ERROR_MESSAGE},
            )
            return JSONResponse(status_code=500, content=error_response.model_dump())


class SecurityHeadersMiddleware(BaseHTTPMiddleware):
    """Middleware for adding security headers to responses."""

    # Headers applied to every response when enabled; they overwrite any
    # same-named headers the route already set.
    # NOTE(review): X-XSS-Protection is ignored by modern browsers and OWASP now
    # recommends "0" or omitting it; value kept unchanged here.
    _SECURITY_HEADERS = {
        "X-Content-Type-Options": "nosniff",
        "X-Frame-Options": "DENY",
        "X-XSS-Protection": "1; mode=block",
        "Referrer-Policy": "strict-origin-when-cross-origin",
        "Content-Security-Policy": (
            "default-src 'self'; script-src 'self' 'unsafe-inline'; " "style-src 'self' 'unsafe-inline'"
        ),
    }

    def __init__(self, app, add_security_headers: bool = True):
        """Initialize security headers middleware.

        Args:
            app: FastAPI application
            add_security_headers: Whether to add security headers
        """
        super().__init__(app)
        self.add_security_headers = add_security_headers

    async def dispatch(self, request: Request, call_next) -> Response:
        """Process the request and, if enabled, stamp security headers on the response.

        Args:
            request: HTTP request
            call_next: Next middleware/handler

        Returns:
            HTTP response with security headers
        """
        response = await call_next(request)
        if self.add_security_headers:
            response.headers.update(self._SECURITY_HEADERS)
        return response


class RateLimitingMiddleware(BaseHTTPMiddleware):
    """Simple fixed-window rate limiting middleware.

    Requests are counted per client in calendar-minute windows. Counts live in
    this instance's memory and are not shared across processes.
    """

    def __init__(self, app, max_requests_per_minute: int = 100, trust_forwarded_headers: bool = True):
        """Initialize rate limiting middleware.

        Args:
            app: FastAPI application
            max_requests_per_minute: Maximum requests per minute per client
            trust_forwarded_headers: Identify clients by ``X-Forwarded-For`` /
                ``X-Real-IP`` when present. Only safe behind a reverse proxy that
                overwrites these headers; otherwise clients can spoof them to
                evade the limit.
        """
        super().__init__(app)
        self.max_requests = max_requests_per_minute
        self.trust_forwarded_headers = trust_forwarded_headers
        # client id -> {"minute": window index, "count": requests in that window}
        self.request_counts: Dict[str, Dict[str, Any]] = {}
        # Window of the last stale-entry sweep; sweeping once per window keeps
        # dispatch from scanning every client on every request.
        self._last_cleanup_minute: Optional[int] = None

    async def dispatch(self, request: Request, call_next) -> Response:
        """Reject the request with 429 if the client exceeded its quota.

        Args:
            request: HTTP request
            call_next: Next middleware/handler

        Returns:
            HTTP response (with ``X-RateLimit-*`` headers) or a 429 error
        """
        client_ip = self._get_client_ip(request)
        current_time = time.time()
        current_minute = int(current_time // 60)
        reset_at = (current_minute + 1) * 60

        if self._last_cleanup_minute != current_minute:
            self._cleanup_old_entries(current_minute)
            self._last_cleanup_minute = current_minute

        client_data = self.request_counts.get(client_ip)
        if client_data is None or client_data["minute"] != current_minute:
            # New client, or first request from this client in the current window.
            client_data = {"minute": current_minute, "count": 0}
            self.request_counts[client_ip] = client_data

        if client_data["count"] >= self.max_requests:
            # Seconds until the next window opens; header and body agree.
            retry_after = max(1, math.ceil(reset_at - current_time))
            logger.warning("Rate limit exceeded for client %s", client_ip)
            error_response = ErrorResponse(
                error="Rate limit exceeded",
                error_code="RATE_LIMIT_EXCEEDED",
                details={"max_requests_per_minute": self.max_requests, "retry_after": retry_after},
            )
            headers = {"Retry-After": str(retry_after), **self._rate_limit_headers(0, reset_at)}
            return JSONResponse(status_code=429, content=error_response.model_dump(), headers=headers)

        client_data["count"] += 1
        # Snapshot before awaiting: during a long request the entry may be
        # replaced by a later window or swept, so it must not be re-read after.
        remaining = max(0, self.max_requests - client_data["count"])

        response = await call_next(request)
        response.headers.update(self._rate_limit_headers(remaining, reset_at))
        return response

    def _rate_limit_headers(self, remaining: int, reset_at: int) -> Dict[str, str]:
        """Build the ``X-RateLimit-*`` headers for a response.

        Args:
            remaining: Requests the client may still make in this window
            reset_at: Unix timestamp (seconds) at which the window resets

        Returns:
            Header name -> value mapping
        """
        return {
            "X-RateLimit-Limit": str(self.max_requests),
            "X-RateLimit-Remaining": str(remaining),
            "X-RateLimit-Reset": str(reset_at),
        }

    def _get_client_ip(self, request: Request) -> str:
        """Get client IP address from request.

        Args:
            request: HTTP request

        Returns:
            Client IP address, or ``"unknown"`` if none can be determined
        """
        if self.trust_forwarded_headers:
            # Reverse-proxy headers: client-controlled unless a trusted proxy
            # overwrites them (see ``trust_forwarded_headers``).
            forwarded_for = request.headers.get("X-Forwarded-For")
            if forwarded_for:
                first_hop = forwarded_for.split(",")[0].strip()
                if first_hop:
                    return first_hop

            real_ip = request.headers.get("X-Real-IP")
            if real_ip:
                return real_ip

        if request.client:
            return request.client.host

        return "unknown"

    def _cleanup_old_entries(self, current_minute: int) -> None:
        """Drop entries from windows before ``current_minute``.

        Only the current window's count is ever consulted by ``dispatch``, so
        older entries are dead weight.

        Args:
            current_minute: Current minute timestamp (``int(time.time() // 60)``)
        """
        stale_clients = [client_ip for client_ip, data in self.request_counts.items() if data["minute"] < current_minute]
        for client_ip in stale_clients:
            del self.request_counts[client_ip]


def setup_middleware(app: FastAPI, config: ServerConfig) -> None:
    """Set up all middleware for the FastAPI application.

    Starlette wraps each newly added middleware around those added before it,
    so the LAST one added is the OUTERMOST. Resulting order (outer -> inner):
    CORS, request logging, error handling, rate limiting (optional), security
    headers, application.

    Args:
        app: FastAPI application
        config: Server configuration
    """
    # Security headers
    app.add_middleware(SecurityHeadersMiddleware, add_security_headers=True)

    # Rate limiting (optional - can be disabled for development)
    if getattr(config, "enable_rate_limiting", False):
        app.add_middleware(
            RateLimitingMiddleware, max_requests_per_minute=getattr(config, "max_requests_per_minute", 100)
        )

    # Error handling
    app.add_middleware(ErrorHandlingMiddleware)

    # Request logging
    app.add_middleware(RequestLoggingMiddleware)

    # CORS is added LAST so it is outermost: preflight requests are answered
    # before rate limiting, and CORS headers also reach the 429/500 responses
    # generated by the middleware above, so browsers can read them.
    app.add_middleware(
        CORSMiddleware,
        allow_origins=config.web_cors_origins,
        allow_credentials=True,
        allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],
        allow_headers=["*"],
        expose_headers=_EXPOSED_HEADERS,
    )

    logger.info("All middleware configured successfully")


def setup_exception_handlers(app: FastAPI) -> None:
    """Set up global exception handlers for the FastAPI application.

    Args:
        app: FastAPI application
    """

    # Registered on Starlette's base HTTPException so framework-raised errors
    # (e.g. 404 for unknown routes, 405) are formatted too; fastapi.HTTPException
    # is a subclass and is still covered.
    @app.exception_handler(StarletteHTTPException)
    async def http_exception_handler(request: Request, exc: StarletteHTTPException) -> JSONResponse:
        """Handle HTTP exceptions with standardized error format.

        Args:
            request: HTTP request
            exc: HTTP exception

        Returns:
            Standardized error response
        """
        logger.warning("HTTP %s on %s %s: %s", exc.status_code, request.method, request.url.path, exc.detail)
        # NOTE(review): exc.detail may be a non-string (dict/list) when raised
        # that way; confirm ErrorResponse.error accepts it.
        error_response = ErrorResponse(
            error=exc.detail, error_code=f"HTTP_{exc.status_code}", details={"status_code": exc.status_code}
        )
        return JSONResponse(
            status_code=exc.status_code, content=error_response.model_dump(), headers=getattr(exc, "headers", None)
        )

    @app.exception_handler(RequestValidationError)
    async def validation_exception_handler(request: Request, exc: RequestValidationError) -> JSONResponse:
        """Handle validation exceptions with detailed error information.

        Args:
            request: HTTP request
            exc: Validation exception

        Returns:
            Detailed validation error response (offending inputs are converted
            to JSON-safe values before being echoed back)
        """
        logger.warning("Validation error on %s %s: %s", request.method, request.url.path, exc)
        error_details = _format_validation_errors(exc, include_input=True)
        error_response = ErrorResponse(
            error="Request validation failed",
            error_code="VALIDATION_ERROR",
            details={
                "validation_errors": error_details,
                "error_count": len(error_details),
            },
        )
        return JSONResponse(status_code=422, content=error_response.model_dump())

    @app.exception_handler(Exception)
    async def general_exception_handler(request: Request, exc: Exception) -> JSONResponse:
        """Handle unexpected exceptions with logging.

        Args:
            request: HTTP request
            exc: Unexpected exception

        Returns:
            Generic error response
        """
        logger.error(
            "Unexpected error on %s %s: %s", request.method, request.url.path, exc, exc_info=exc
        )
        error_response = ErrorResponse(
            error="An unexpected error occurred",
            error_code="INTERNAL_SERVER_ERROR",
            details={"message": _GENERIC_ERROR_MESSAGE},
        )
        return JSONResponse(status_code=500, content=error_response.model_dump())

    logger.info("Exception handlers configured successfully")

Latest Blog Posts

MCP directory API

We provide all the information about MCP servers via our MCP API.

curl -X GET 'https://glama.ai/api/mcp/v1/servers/juanqui/pdfkb-mcp'

If you have feedback or need assistance with the MCP directory API, please join our Discord server.