import logging
from dataclasses import dataclass
from typing import Optional

from fastmcp import Context, FastMCP
from fastmcp.server.dependencies import get_http_request
from fastmcp.server.middleware import Middleware, MiddlewareContext
from langchain_google_genai import ChatGoogleGenerativeAI
from mcp.types import InitializeRequest
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("AIServer")

# Create a FastMCP server
mcp = FastMCP("AIServer")

# NOTE(review): these imports sit mid-file; consider moving them to the
# top-of-file import block.
from langchain_core.chat_history import InMemoryChatMessageHistory
from langchain_core.messages import AIMessage, HumanMessage

# Global session storage for LLMs, tokens, and histories.
# Both dicts are keyed by the MCP session_id and populated lazily in
# get_session_llm() / get_session_history().
_session_llms: dict[str, ChatGoogleGenerativeAI] = {}
_session_histories: dict[str, InMemoryChatMessageHistory] = {}
@dataclass
class Configure:
    """Server-wide configuration values."""

    # Gemini model used for every session's ChatGoogleGenerativeAI client.
    gemini_model_name: str = "gemini-2.0-flash"


# Single module-level configuration instance shared by all sessions.
config = Configure()
class AuthMiddleware(Middleware):
    """Middleware that validates a Google API key during MCP initialization.

    The key is looked up first in the HTTP "Bearer" request header, then in
    the client's initializationOptions under "api_key".
    """

    async def on_initialize(
        self, context: MiddlewareContext[InitializeRequest], call_next
    ):
        """Extract and validate the API key during the initialization phase.

        Note: session_id might not be available yet here, so nothing is
        cached; the per-session client is built later in get_session_llm().

        Raises:
            ValueError: if a key was supplied but a Gemini client could not
                be constructed from it.
        """
        token = None
        # 1. Try to extract from HTTP headers.
        # NOTE(review): clients send the key in a literal "Bearer" header,
        # not the conventional "Authorization: Bearer <key>" — confirm this
        # matches the client implementation before changing it.
        try:
            request = get_http_request()
            if request:
                token = request.headers.get("Bearer")
        except Exception:
            logger.exception("Failed to extract token from HTTP headers")

        # 2. Fall back to initializationOptions. Guard with getattr:
        # params is not guaranteed to carry that attribute.
        if not token:
            init_opts = getattr(
                context.message.params, "initializationOptions", None
            )
            if init_opts:
                token = init_opts.get("api_key")

        if token:
            try:
                # Constructing the client checks that the key is locally
                # usable; the instance is intentionally discarded — the
                # session-scoped client is created lazily in get_session_llm.
                ChatGoogleGenerativeAI(
                    model=config.gemini_model_name, google_api_key=token
                )
                logger.info("Authorization successful")
            except Exception as e:
                # Lazy %-interpolation; chain the cause for debuggability.
                logger.error("Failed to process API key: %s", e)
                raise ValueError("Unauthorized: API key processing failed.") from e

        return await call_next(context)


mcp.add_middleware(AuthMiddleware())
def get_session_llm(ctx: Context) -> Optional[ChatGoogleGenerativeAI]:
    """Retrieve or lazily initialize the LLM for the current session.

    Args:
        ctx: FastMCP request context; ctx.session_id keys the caches.

    Returns:
        The cached or newly created ChatGoogleGenerativeAI client, or None
        when no "Bearer" token is present or initialization fails.
    """
    session_id = ctx.session_id

    # 1. Check if already initialized for this session.
    if session_id in _session_llms:
        return _session_llms[session_id]

    # 2. Try to read a Bearer token from the current request headers.
    # The try block is kept narrow so that a failure here is logged with
    # the correct message (the original also swallowed LLM-construction
    # errors under this message).
    token = None
    try:
        request = get_http_request()
        if request:
            token = request.headers.get("Bearer")
    except Exception:
        logger.exception("Failed to extract token from HTTP headers")
        return None

    if not token:
        return None

    # 3. Build and cache the client plus a fresh history for this session.
    try:
        llm = ChatGoogleGenerativeAI(
            model=config.gemini_model_name, google_api_key=token
        )
    except Exception:
        logger.exception("Failed to initialize LLM for session %s", session_id)
        return None

    _session_llms[session_id] = llm
    _session_histories[session_id] = InMemoryChatMessageHistory()
    return llm
def get_session_history(session_id: str) -> InMemoryChatMessageHistory:
    """Return the chat history for *session_id*, creating one on first use."""
    history = _session_histories.get(session_id)
    if history is None:
        # First message for this session: start an empty in-memory history.
        history = InMemoryChatMessageHistory()
        _session_histories[session_id] = history
    return history
@mcp.tool()
async def ask_ai(prompt: str, ctx: Context) -> str:
    """
    Asks the session-specific AI a question, maintaining conversation memory.
    Args:
        prompt: The question to ask the AI
    """
    llm = get_session_llm(ctx)
    if not llm:
        return "Error: No API key was provided for this session. Please provide a 'Bearer' token in headers or 'api_key' in initialization options."

    history = get_session_history(ctx.session_id)
    user_message = HumanMessage(content=prompt)
    try:
        # Build the LLM input as a NEW list. The original appended to
        # history.messages (the live list inside the history object) and
        # then persisted the human message again via aadd_messages, which
        # duplicated every user turn in the stored conversation.
        messages = [*history.messages, user_message]

        # Invoke LLM with full conversation context.
        response = await llm.ainvoke(messages)

        # Persist this turn (user prompt + AI reply) to memory exactly once.
        await history.aadd_messages([user_message, response])
        return response.content
    except Exception as e:
        return f"Error invoking AI: {str(e)}"
if __name__ == "__main__":
    # Run the server using streamable-http to support sessions better
    # (the per-session caches above rely on a session_id being available).
    mcp.run(transport="streamable-http")