import secrets

from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import Response
class AuthenticationMiddleware(BaseHTTPMiddleware):
    """Reject requests to protected routes unless they carry the expected token.

    A request is protected when its URL path starts with any entry of
    ``protected_routes``. A protected request must send an ``Authorization``
    header whose value is exactly the configured token. No scheme prefix such
    as ``Bearer`` is stripped. Otherwise a plain-text 401 response is returned
    and the downstream app is not called.

    NOTE: matching is a raw string-prefix test. ``"/"`` (in the default list)
    is a prefix of every path, so with the defaults every route is protected.
    Likewise ``"/mcp"`` also matches ``"/mcpfoo"``.
    """

    def __init__(self, app, protected_routes=None, *, token: str = "test-token") -> None:
        """Wrap *app* with token authentication.

        Args:
            app: The downstream application, forwarded to ``BaseHTTPMiddleware``.
            protected_routes: Path prefixes that require authentication.
                Any falsy value (``None`` or an empty sequence) falls back to
                the built-in defaults. Because those include ``"/"``, the
                fallback protects everything (fail-closed). The value is
                copied into a list so a one-shot iterator is not exhausted
                by the first request.
            token: The exact ``Authorization`` header value to accept.
        """
        super().__init__(app)
        # NOTE(review): "test-token" looks like a test placeholder. Kept as
        # the default for backward compatibility; pass a real secret
        # (e.g. from configuration) outside of tests.
        self.token = token
        self.protected_routes = list(
            protected_routes or ["/docs", "/", "/mcp", "/health", "/info", "/tools"]
        )

    async def dispatch(self, request: Request, call_next) -> Response:
        """Process the request and check authentication for protected routes.

        Returns:
            The downstream response for unprotected or authenticated requests.
            Otherwise a 401 ``Response`` whose body says whether the header
            was missing or the token was wrong.
        """
        path = request.url.path
        # str.startswith accepts a tuple and tests every prefix in one call.
        if not path.startswith(tuple(self.protected_routes)):
            return await call_next(request)

        auth_header = request.headers.get("Authorization")
        # An empty header value is treated the same as an absent one.
        if not auth_header:
            return Response("Authorization header missing", status_code=401)

        # Constant-time comparison, so response timing does not reveal how
        # much of the token a caller guessed correctly. Both sides are
        # UTF-8-encoded, for two reasons:
        #   - compare_digest raises TypeError on non-ASCII str input.
        #   - UTF-8 is injective, so this keeps exact string-equality semantics.
        if not secrets.compare_digest(auth_header.encode(), self.token.encode()):
            return Response(
                "Unauthorized: Invalid or missing authentication token",
                status_code=401,
            )

        return await call_next(request)