import logging
import os
from contextlib import asynccontextmanager
from typing import Optional

import redis
import sqlalchemy
from fastapi import FastAPI, Request, HTTPException, Depends
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.trustedhost import TrustedHostMiddleware
from fastapi.responses import JSONResponse
from sqlalchemy import create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker

from .core.router import router as core_router
# Database setup
DATABASE_URL = os.getenv("DATABASE_URL", "sqlite:///./app.db")
redis_url = os.getenv("REDIS_URL", "redis://localhost:6379")
# By default SQLite refuses to let a connection be used from a thread other
# than the one that created it. FastAPI may run sync code in a threadpool, so
# SQLite URLs need check_same_thread=False. Other database drivers do not
# accept this argument, so it is only passed for sqlite URLs.
_engine_connect_args = (
    {"check_same_thread": False} if DATABASE_URL.startswith("sqlite") else {}
)
engine = create_engine(DATABASE_URL, connect_args=_engine_connect_args)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
# NOTE(review): on SQLAlchemy 2.x, sqlalchemy.ext.declarative.declarative_base
# is deprecated in favour of sqlalchemy.orm.declarative_base.
Base = declarative_base()
# Redis setup
# The client is created here; connectivity is first exercised by the
# ping() performed during application startup.
redis_client = redis.from_url(redis_url)
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Revoked bearer tokens. This is a plain in-memory set: it is local to this
# process, not shared between workers, and emptied on restart.
# (In production, keep revocations in Redis instead.)
blacklisted_tokens = set()
class AuthMiddleware:
    """Bearer-token gate for HTTP requests, registered via ``app.middleware("http")``.

    Pass-through rules:
    - Requests whose path is in ``excluded_paths`` are passed through untouched.
    - Requests whose path does not start with one of ``protected_prefixes``
      (``"/mcp/"`` by default) are also passed through untouched.

    Every other request must send ``Authorization: Bearer <token>``. The token
    has the shape ``<user_id>:<signature>``, and ``user_id`` must consist of
    ASCII digits. On success the id is stored as ``request.state.user_id``.

    Rejections are *returned* as 401 ``JSONResponse`` objects with a
    ``{"detail": ...}`` body rather than raised as ``HTTPException``.
    Exceptions that escape an HTTP middleware are not handled by FastAPI's
    exception handlers and reach the client as 500 errors.

    SECURITY WARNING: the ``<signature>`` part of the token is never checked,
    so any client can claim any numeric user id. Replace the parsing in
    ``__call__`` with real signature/JWT verification before relying on it.

    NOTE(review): by default only paths beginning with ``/mcp/`` are
    protected. ``core_router`` is mounted under ``/api``, so its routes are
    not authenticated here. Confirm that is intended.
    """

    def __init__(self, protected_prefixes=("/mcp/",), excluded_paths=None):
        """Configure which request paths require authentication.

        Args:
            protected_prefixes: Path prefix, or iterable of prefixes, whose
                requests need a bearer token. Defaults to ``("/mcp/",)``.
            excluded_paths: Exact paths that always bypass authentication.
                Defaults to the health, docs, OpenAPI schema and favicon paths.
        """
        if isinstance(protected_prefixes, str):
            # A bare string would otherwise be split into single characters.
            protected_prefixes = (protected_prefixes,)
        # str.startswith needs a tuple (not a list/set) to test many prefixes.
        self.protected_prefixes = tuple(protected_prefixes)
        if excluded_paths is None:
            excluded_paths = {"/health", "/docs", "/openapi.json", "/favicon.ico"}
        self.excluded_paths = set(excluded_paths)

    @staticmethod
    def _unauthorized(path, detail):
        """Log a rejected request and build the 401 response sent back."""
        logger.warning("Rejected request to %s: %s", path, detail)
        return JSONResponse(
            status_code=401,
            content={"detail": detail},
            # RFC 6750: bearer-protected resources advertise the scheme.
            headers={"WWW-Authenticate": "Bearer"},
        )

    async def __call__(self, request: Request, call_next):
        """Authenticate *request* when its path is protected, then continue.

        Returns either the downstream response from ``call_next`` or a 401
        ``JSONResponse`` describing why the request was rejected.
        """
        path = request.url.path
        if path in self.excluded_paths or not path.startswith(self.protected_prefixes):
            return await call_next(request)

        # Extract token from the Authorization header.
        auth_header = request.headers.get("Authorization")
        if not auth_header or not auth_header.startswith("Bearer "):
            return self._unauthorized(path, "Missing or invalid authorization header")
        token = auth_header[len("Bearer "):].strip()
        if not token:
            return self._unauthorized(path, "Missing or invalid authorization header")

        # Revocation check against the module-level, process-local set.
        if token in blacklisted_tokens:
            return self._unauthorized(path, "Token has been revoked")

        # Token shape "<user_id>:<signature>". The signature is NOT verified
        # (see class docstring).
        parts = token.split(":")
        if len(parts) != 2:
            return self._unauthorized(path, "Invalid token format")
        user_id = parts[0]
        # str.isdigit() also accepts characters such as "²" that int()
        # rejects; restricting to ASCII digits keeps int() below from raising.
        if not (user_id.isascii() and user_id.isdigit()):
            return self._unauthorized(path, "Invalid user ID in token")

        # Expose the caller's id to endpoints.
        request.state.user_id = int(user_id)
        return await call_next(request)
def _close_connections():
    """Close the Redis client and dispose of the SQLAlchemy engine's pool.

    Each resource is released in its own ``try`` so a failure on one does not
    prevent the other from being closed. Errors are logged, never raised.
    """
    all_closed = True
    try:
        redis_client.close()
    except Exception as e:
        all_closed = False
        logger.error(f"Error during shutdown: {e}")
    try:
        engine.dispose()
    except Exception as e:
        all_closed = False
        logger.error(f"Error during shutdown: {e}")
    if all_closed:
        logger.info("Connections closed")


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: initialise resources on startup, release them on shutdown.

    Startup creates any missing tables for models registered on ``Base`` and
    pings Redis. If either step fails, the error is logged with its traceback,
    resources are released, and the exception is re-raised so the application
    does not start half-initialised.

    Shutdown runs in a ``finally`` so connections are released even if an
    exception is thrown into the context at ``yield``.
    """
    # Startup
    logger.info("Starting up application...")
    try:
        Base.metadata.create_all(bind=engine)
        logger.info("Database initialized")
        redis_client.ping()
        logger.info("Redis connection established")
    except Exception as e:
        logger.exception(f"Failed to initialize: {e}")
        _close_connections()
        raise
    try:
        yield
    finally:
        # Shutdown
        logger.info("Shutting down application...")
        _close_connections()
def _env_list(name, default):
    """Return environment variable *name* split on commas, or *default*.

    Surrounding whitespace and empty items are dropped. An unset or blank
    variable yields a copy of *default*.
    """
    raw = os.getenv(name, "")
    items = [item.strip() for item in raw.split(",") if item.strip()]
    return items or list(default)


# Create FastAPI app
app = FastAPI(
    title="MCP FastAPI Backend",
    description="FastAPI backend for MCP (Model Context Protocol) integration with external services",
    version="1.0.0",
    lifespan=lifespan,
)

# Middleware registration order matters. Starlette wraps each newly added
# middleware around those added before it, so requests flow
# TrustedHost -> CORS -> auth -> routes. Keeping CORS outside auth lets browser
# preflight (OPTIONS) requests be answered without a bearer token.
auth_middleware = AuthMiddleware()
app.middleware("http")(auth_middleware)

# CORS. Origins come from CORS_ALLOW_ORIGINS (comma-separated); the default
# "*" allows any origin.
# NOTE(review): with allow_credentials=True and "*", Starlette echoes the
# caller's Origin back, so any website can make credentialed requests. Set
# explicit origins in production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=_env_list("CORS_ALLOW_ORIGINS", ["*"]),
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Trusted hosts come from ALLOWED_HOSTS (comma-separated); the default "*"
# accepts any Host header. Configure explicit hosts in production.
app.add_middleware(
    TrustedHostMiddleware,
    allowed_hosts=_env_list("ALLOWED_HOSTS", ["*"]),
)

# Include routers.
# NOTE(review): these routes are served under /api/..., which AuthMiddleware
# does not protect with its default "/mcp/" prefix. Confirm that is intended.
app.include_router(core_router, prefix="/api")
@app.get("/health")
async def health_check():
"""Health check endpoint"""
return {"status": "ok"}
@app.get("/")
async def root():
"""Root endpoint"""
return {
"message": "MCP FastAPI Backend",
"version": "1.0.0",
"docs": "/docs"
}