"""
FleetMind MCP Authentication Proxy
Captures API keys from initial SSE connections and injects them into tool requests.
This proxy sits between MCP clients and the FastMCP server, solving the
multi-tenant authentication problem by:
1. Capturing api_key from initial /sse?api_key=xxx connection
2. Storing api_key mapped to session_id
3. Injecting api_key into subsequent /messages/?session_id=xxx requests
Architecture:
MCP Client -> Proxy (port 7860) -> FastMCP (port 7861)
"""
import asyncio
import logging
import os
from aiohttp import web, ClientSession, ClientTimeout
from aiohttp.client_exceptions import ClientConnectionResetError
from urllib.parse import urlencode, parse_qs, urlparse, urlunparse
import sys
# Logging: INFO and above, timestamped, emitted through a StreamHandler.
logging.basicConfig(
    handlers=[logging.StreamHandler()],
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Upstream FastMCP server: fixed internal port on the same host.
FASTMCP_HOST = "localhost"
FASTMCP_PORT = 7861

# Public-facing port. Read from the PORT environment variable (HuggingFace
# sets it to 7860); falls back to 7860 when the variable is absent.
PROXY_PORT = int(os.environ.get("PORT", 7860))

# In-memory session store: session_id -> api_key.
session_api_keys = {}
# Header fields that describe a single connection hop; a proxy must not relay
# them (RFC 9110, section 7.6.1).
_HOP_BY_HOP_HEADERS = frozenset({
    "connection",
    "keep-alive",
    "proxy-authenticate",
    "proxy-authorization",
    "te",
    "trailer",
    "transfer-encoding",
    "upgrade",
})
# Host is derived from the target URL and Content-Length from the re-sent body.
_REQUEST_HEADERS_TO_DROP = _HOP_BY_HOP_HEADERS | {"host", "content-length"}
# The upstream body is re-framed by this proxy and, with aiohttp's default
# auto-decompression, already decoded - the original length/encoding no longer
# describe the bytes that are sent on.
_RESPONSE_HEADERS_TO_DROP = _HOP_BY_HOP_HEADERS | {"content-length", "content-encoding"}
# Reserved slot in ``session_api_keys`` holding the most recent unlinked key.
_PENDING_KEY = "_pending_api_key"
# Stop looking for the session announcement after this many streamed bytes.
_SESSION_SNIFF_LIMIT = 8192


def _redact_key(secret):
    """Return a log-safe hint of *secret* exposing at most its first 4 characters."""
    if len(secret) <= 8:
        return "***"
    return f"{secret[:4]}***"


def _filter_headers(headers, excluded):
    """Copy *headers* into a dict, skipping names (case-insensitive) in *excluded*.

    Like the ``dict(headers)`` call it replaces, only the first value of a
    repeated header is kept.
    """
    kept = {}
    for name, value in headers.items():
        if name.lower() not in excluded:
            kept.setdefault(name, value)
    return kept


def _extract_session_id(buffer):
    """Return the ``session_id`` found in a complete SSE ``data:`` line, else None.

    Assumes the upstream announces its message endpoint as a URL carrying a
    ``session_id`` query parameter (as the MCP SSE transport does) - TODO
    confirm against the deployed FastMCP version.
    """
    for raw_line in buffer.splitlines(keepends=True):
        if not raw_line.endswith((b"\n", b"\r")):
            break  # trailing line is still incomplete; wait for more data
        line = raw_line.strip()
        if line.startswith(b"data:") and b"session_id=" in line:
            endpoint = line[len(b"data:"):].strip().decode("utf-8", "replace")
            values = parse_qs(urlparse(endpoint).query).get("session_id")
            if values:
                return values[0]
    return None


def _link_session(session_id, api_key):
    """Record that *session_id* authenticates with *api_key*."""
    session_api_keys[session_id] = api_key
    logger.info(f"[AUTH] Linked session {session_id[:12]}... to API key")


def _apply_session_auth(path, query_params):
    """Capture or inject the API key for this request.

    Mutates *query_params* in place (adds ``api_key`` for known sessions) and
    returns the key captured from a new ``/sse`` connection, or None.
    """
    api_key = query_params.get('api_key')
    session_id = query_params.get('session_id')
    captured_key = None

    # STEP 1: Capture API key from initial SSE connection
    if api_key and path == '/sse':
        # Never write a usable portion of the secret to the log.
        logger.info(f"[AUTH] Captured API key from SSE connection: {_redact_key(api_key)}")
        # Fallback only: normally the session is linked as soon as the
        # upstream announces it on this very connection (see _relay_event_stream).
        session_api_keys[_PENDING_KEY] = api_key
        captured_key = api_key

    # STEP 2: Link session_id to API key (from /messages requests).
    # The reserved slot name is rejected so a client cannot read the pending
    # key by sending it as its session_id.
    if session_id and session_id != _PENDING_KEY and path.startswith('/messages'):
        if session_id not in session_api_keys and _PENDING_KEY in session_api_keys:
            # Consume the pending key so it cannot leak to a later session.
            _link_session(session_id, session_api_keys.pop(_PENDING_KEY))
        # STEP 3: Inject API key into request for FastMCP
        stored_api_key = session_api_keys.get(session_id)
        if stored_api_key:
            query_params['api_key'] = stored_api_key
            logger.debug(f"[AUTH] Injected API key into request for session {session_id[:12]}...")

    return captured_key


async def _relay_event_stream(request, upstream, response, api_key):
    """Stream an SSE body from *upstream* to the client through *response*.

    When *api_key* is given, the first bytes of the stream are inspected for
    the session announcement so the session is linked to the key of this exact
    connection before the client can learn the session_id. The link is removed
    again when the stream ends.
    """
    await response.prepare(request)

    async def send_keepalive():
        # SSE comment lines are ignored by clients but keep intermediaries
        # from timing out an idle connection.
        # NOTE(review): a ping may land between two chunks of one event if the
        # upstream splits an event across chunks - confirm this is acceptable.
        try:
            while True:
                await asyncio.sleep(30)  # Send ping every 30 seconds
                await response.write(b":\n\n")
        except (asyncio.CancelledError, ConnectionResetError):
            # Cancelled on completion, or the client went away; the main
            # relay loop reports the disconnect.
            pass

    keepalive_task = asyncio.create_task(send_keepalive())
    sniffed = bytearray() if api_key else None
    linked_session = None
    try:
        async for chunk in upstream.content.iter_any():
            if sniffed is not None:
                sniffed += chunk
                linked_session = _extract_session_id(bytes(sniffed))
                if linked_session:
                    _link_session(linked_session, api_key)
                    # The fallback slot is no longer needed for this key.
                    if session_api_keys.get(_PENDING_KEY) == api_key:
                        del session_api_keys[_PENDING_KEY]
                if linked_session or len(sniffed) >= _SESSION_SNIFF_LIMIT:
                    sniffed = None
            # Forward only after linking so a fast client cannot race the link.
            await response.write(chunk)
        await response.write_eof()
    finally:
        keepalive_task.cancel()
        try:
            await keepalive_task
        except asyncio.CancelledError:
            pass
        if linked_session:
            # The session ends with its SSE connection; drop the stored key.
            session_api_keys.pop(linked_session, None)


async def proxy_handler(request):
    """
    Main proxy handler - forwards all requests to FastMCP server.
    Captures API keys from SSE connections and injects them into tool calls.

    Args:
        request: the incoming aiohttp web request.

    Returns:
        The relayed upstream response: a streamed response for
        ``text/event-stream`` bodies, a buffered one otherwise. A 499 response
        is returned when the connection is reset and a 502 response on any
        other failure, unless a streamed response was already started (then
        that response is returned as-is).
    """
    path = request.path
    # NOTE: repeated query parameters collapse to their first value here.
    query_params = dict(request.query)
    captured_key = _apply_session_auth(path, query_params)

    # Build target URL for FastMCP server
    target_url = f"http://{FASTMCP_HOST}:{FASTMCP_PORT}{path}"
    if query_params:
        target_url += f"?{urlencode(query_params)}"

    # total=None keeps long-lived SSE connections open; the socket-level
    # timeouts still bound connection setup and individual reads.
    timeout = ClientTimeout(total=None, sock_connect=30, sock_read=300)
    streamed = None  # set once an SSE response has been created

    async with ClientSession(timeout=timeout) as session:
        try:
            body = await request.read()
            async with session.request(
                request.method,
                target_url,
                data=body or None,
                headers=_filter_headers(request.headers, _REQUEST_HEADERS_TO_DROP),
            ) as upstream:
                resp_headers = _filter_headers(upstream.headers, _RESPONSE_HEADERS_TO_DROP)
                if 'text/event-stream' in upstream.content_type:
                    streamed = web.StreamResponse(
                        status=upstream.status,
                        reason=upstream.reason,
                        headers=resp_headers,
                    )
                    await _relay_event_stream(request, upstream, streamed, captured_key)
                    return streamed
                # For regular responses, read entire body
                payload = await upstream.read()
                return web.Response(
                    body=payload,
                    status=upstream.status,
                    headers=resp_headers,
                )
        except (ClientConnectionResetError, ConnectionResetError) as e:
            # Client disconnected - this is normal for SSE connections
            logger.debug(f"[SSE] Client disconnected: {e}")
            if streamed is not None and streamed.prepared:
                # Headers are already on the wire; a second response would
                # corrupt the connection.
                return streamed
            return web.Response(text="Client disconnected", status=499)
        except Exception as e:
            # logger.exception appends the traceback to the record.
            logger.exception(f"[ERROR] Proxy error: {type(e).__name__}: {e}")
            if streamed is not None and streamed.prepared:
                return streamed
            return web.Response(
                text=f"Proxy error: {type(e).__name__}: {str(e)}",
                status=502
            )
async def health_check(request):
    """Liveness probe: always answers 200 with a fixed plain-text body.

    The incoming ``request`` is not inspected.
    """
    ok_response = web.Response(status=200, text="FleetMind Proxy OK")
    return ok_response
def create_app():
    """Build the aiohttp application served by the proxy.

    Returns:
        A ``web.Application`` with ``/health`` answered by ``health_check``
        and every other path and method routed to ``proxy_handler``.
    """
    application = web.Application()
    routes = application.router
    # Dedicated health endpoint, registered before the catch-all route.
    routes.add_get('/health', health_check)
    # Catch-all: any method, any path goes through the proxy.
    routes.add_route('*', '/{path:.*}', proxy_handler)
    return application
async def main():
    """Start the proxy server and serve until this coroutine is cancelled.

    Prints a startup banner, binds the application to ``0.0.0.0:PROXY_PORT``
    and then waits indefinitely. The runner is always cleaned up on the way
    out - whether startup fails or the task is cancelled (which is how
    ``asyncio.run`` delivers Ctrl-C to the main coroutine).
    """
    print("\n" + "=" * 70)
    print("FleetMind MCP Authentication Proxy")
    print("=" * 70)
    print(f"Proxy listening on: http://0.0.0.0:{PROXY_PORT}")
    print(f"Forwarding to FastMCP: http://{FASTMCP_HOST}:{FASTMCP_PORT}")
    print("=" * 70)
    print("[OK] Multi-tenant authentication enabled")
    print("[OK] API keys captured from SSE connections")
    print("[OK] Sessions automatically linked to API keys")
    print("=" * 70 + "\n")

    app = create_app()
    runner = web.AppRunner(app)
    await runner.setup()
    try:
        site = web.TCPSite(runner, '0.0.0.0', PROXY_PORT)
        await site.start()
        logger.info(f"[OK] Proxy server started on port {PROXY_PORT}")
        logger.info(f"[OK] Forwarding to FastMCP on {FASTMCP_HOST}:{FASTMCP_PORT}")
        # Keep running: this event is never set, so the wait only ends when
        # the task is cancelled.
        await asyncio.Event().wait()
    finally:
        # A KeyboardInterrupt handler here would never fire: inside the
        # coroutine the interruption surfaces as a cancellation, so cleanup
        # must live in ``finally``.
        logger.info("Shutting down proxy server...")
        await runner.cleanup()
if __name__ == "__main__":
    # Script entry point: run the proxy until interrupted, then exit cleanly.
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        # Ctrl-C is the expected way to stop the server; report it and exit
        # with status 0.
        print("\nProxy server stopped.")
        raise SystemExit(0)