#!/usr/bin/env python3
"""
MCP server for Strava API integration.
Designed for marathon training and race analysis.
"""
import os
import logging
import argparse
from typing import Optional
from dotenv import load_dotenv
from mcp.server.fastmcp import FastMCP
from starlette.applications import Starlette
from starlette.routing import Route
from starlette.requests import Request
from starlette.responses import Response
from src.client import StravaClient
from src.tools import register_tools
from src.auth import (
get_authorization_url,
get_env_credentials,
complete_auth_with_code,
)
# Load variables from a .env file (if one is found) into the process
# environment; main() later reads the Strava credentials from os.environ.
load_dotenv()
# Module-level logger, named after this module so records can be filtered.
logger = logging.getLogger(__name__)
class StravaMCP(FastMCP):
    """FastMCP subclass with custom SSE and streamable-http HTTP front ends.

    ``run`` dispatches ``"streamable-http"`` to ``run_streamable_http_async``
    and hands every other transport to ``FastMCP.run``.  Both HTTP front ends
    reject empty or non-JSON POST bodies with HTTP 400 before passing the
    request to ``SseServerTransport.handle_post_message``.
    """

    def run(self, transport: str = "stdio") -> None:
        """Start the server on the requested transport.

        Args:
            transport: ``"streamable-http"`` is served by this class; any
                other value is forwarded unchanged to ``FastMCP.run``.
        """
        if transport == "streamable-http":
            import asyncio

            asyncio.run(self.run_streamable_http_async())
        else:
            # Use parent's run method for stdio and sse
            super().run(transport=transport)

    @staticmethod
    def _replaying_receive(body: bytes):
        """Build an ASGI ``receive`` callable that replays an already-read body.

        The first call yields ``body``; every later call yields an empty,
        final ``http.request`` message.
        """
        delivered = False

        async def receive_with_body():
            nonlocal delivered
            if not delivered:
                delivered = True
                return {"type": "http.request", "body": body, "more_body": False}
            return {"type": "http.request", "body": b"", "more_body": False}

        return receive_with_body

    async def _serve_starlette_routes(self, routes) -> None:
        """Serve ``routes`` with uvicorn using host/port/log level from settings."""
        import uvicorn

        starlette_app = Starlette(debug=self.settings.debug, routes=routes)
        config = uvicorn.Config(
            starlette_app,
            host=self.settings.host,
            port=self.settings.port,
            log_level=self.settings.log_level.lower(),
        )
        await uvicorn.Server(config).serve()

    async def run_streamable_http_async(self) -> None:
        """Serve MCP messages over HTTP POST on ``/mcp``.

        Each POST body is validated, replayed into
        ``SseServerTransport.handle_post_message``, and whatever that handler
        emits is captured and returned as a single response (JSON when the
        captured body parses as JSON, plain text otherwise).

        NOTE(review): this reuses the SSE transport's POST handler without an
        accompanying ``/sse`` endpoint; whether that handler can answer
        requests that have no SSE session should be confirmed against the
        MCP SDK.
        """
        import asyncio
        import json

        from mcp.server.sse import SseServerTransport

        # Use SSE transport's message handling for streamable-http
        sse = SseServerTransport("/mcp")

        async def handle_mcp_request(request: Request) -> Response:
            """Handle MCP requests via HTTP POST for streamable-http transport."""
            try:
                body = await request.body()
                if not body.strip():
                    return Response(
                        "Request body is empty. JSON body is required.",
                        status_code=400,
                    )
                try:
                    json.loads(body)
                except (json.JSONDecodeError, UnicodeDecodeError) as e:
                    # Undecodable bytes are a client error too, not a 500.
                    return Response(
                        f"Invalid JSON in request body: {str(e)}",
                        status_code=400,
                    )

                # request.body() consumed the stream, so replay it downstream.
                receive_with_body = self._replaying_receive(body)

                status_code = 200
                body_chunks = []
                response_complete = asyncio.Event()

                async def collect_response(message):
                    """Capture the handler's response instead of sending it.

                    The messages must NOT also be forwarded to the client:
                    Starlette sends the Response returned below, and emitting
                    a second ``http.response.start`` on the same connection
                    is an ASGI protocol error.
                    """
                    nonlocal status_code
                    kind = message.get("type")
                    if kind == "http.response.start":
                        status_code = message.get("status", 200)
                    elif kind == "http.response.body":
                        body_chunks.append(message.get("body", b""))
                        if not message.get("more_body", False):
                            response_complete.set()

                try:
                    await sse.handle_post_message(
                        request.scope,
                        receive_with_body,
                        collect_response,
                    )
                    # Wait until the final body chunk has been captured.
                    try:
                        await asyncio.wait_for(response_complete.wait(), timeout=10.0)
                    except asyncio.TimeoutError:
                        return Response("Request timeout", status_code=504)

                    response_body = b"".join(body_chunks)
                    if not response_body:
                        return Response("No response body", status_code=status_code)
                    try:
                        response_json = json.loads(response_body.decode())
                    except (json.JSONDecodeError, UnicodeDecodeError):
                        return Response(
                            response_body.decode("utf-8", errors="ignore"),
                            media_type="text/plain",
                            status_code=status_code,
                        )
                    return Response(
                        json.dumps(response_json),
                        media_type="application/json",
                        status_code=status_code,
                    )
                except Exception as e:
                    logger.error("Error processing MCP request: %s", e, exc_info=True)
                    return Response(
                        f"Error processing request: {str(e)}",
                        status_code=500,
                    )
            except Exception as e:
                logger.error(
                    "Error handling streamable-http request: %s", e, exc_info=True
                )
                return Response(
                    f"Internal server error: {str(e)}",
                    status_code=500,
                )

        await self._serve_starlette_routes(
            [Route("/mcp", endpoint=handle_mcp_request, methods=["POST"])]
        )

    async def run_sse_async(self, mount_path: Optional[str] = None) -> None:
        """Run the server using SSE transport with improved error handling.

        ``GET /sse`` opens the event stream and runs the MCP server over it;
        ``POST /messages`` validates the JSON body and forwards it to the SSE
        transport.

        Args:
            mount_path: Not used by this implementation.  Accepted so the
                override stays call-compatible with FastMCP versions whose
                ``run`` forwards a mount path to this method.
        """
        import json

        from mcp.server.sse import SseServerTransport

        sse = SseServerTransport("/messages")

        class _AlreadySentResponse(Response):
            """Placeholder returned to Starlette once the real response was
            already written through the raw ASGI ``send``; sending it is a
            no-op."""

            def __init__(self):
                super().__init__(status_code=200)

            async def __call__(self, scope, receive, send):
                pass

        async def handle_sse_asgi(scope, receive, send):
            """Raw ASGI handler for SSE connection."""
            response_started = False

            async def tracked_send(message):
                nonlocal response_started
                if message.get("type") == "http.response.start":
                    response_started = True
                await send(message)

            try:
                async with sse.connect_sse(scope, receive, tracked_send) as streams:
                    await self._mcp_server.run(
                        streams[0],
                        streams[1],
                        self._mcp_server.create_initialization_options(),
                    )
            except Exception as e:
                logger.error("Error in SSE connection: %s", e, exc_info=True)
                if response_started:
                    # The stream's headers are already out; a second
                    # http.response.start would itself fail, so only log.
                    return
                await send({
                    "type": "http.response.start",
                    "status": 500,
                    "headers": [[b"content-type", b"text/plain"]],
                })
                await send({
                    "type": "http.response.body",
                    "body": f"SSE connection error: {str(e)}".encode(),
                })

        async def handle_sse(request: Request):
            """Wrapper to convert Request to ASGI handler for SSE."""
            await handle_sse_asgi(request.scope, request.receive, request._send)
            return _AlreadySentResponse()

        async def handle_messages_asgi(scope, receive, send):
            """Raw ASGI handler for POST messages."""
            request = Request(scope, receive)
            # Only used to tag log lines.
            session_id = request.query_params.get("session_id", "unknown")
            try:
                body = await request.body()
            except Exception as e:
                logger.warning(
                    "[session_id=%s] Error reading request body: %s", session_id, e
                )
                body = b""
            if not body.strip():
                logger.warning("[session_id=%s] Request body is empty", session_id)
                response = Response(
                    "Request body is empty. JSON body is required.",
                    status_code=400,
                )
                return await response(scope, receive, send)
            try:
                json.loads(body)
            except (json.JSONDecodeError, UnicodeDecodeError) as e:
                logger.warning(
                    "[session_id=%s] Invalid JSON in request body: %s. Body: %s",
                    session_id,
                    e,
                    body[:200],
                )
                response = Response(
                    f"Invalid JSON in request body: {str(e)}", status_code=400
                )
                return await response(scope, receive, send)
            # The body was consumed above, so replay it for the transport.
            await sse.handle_post_message(
                scope, self._replaying_receive(body), send
            )

        async def handle_messages(request: Request):
            """Wrapper to convert Request to ASGI handler."""
            response_sent = False

            async def wrapped_send(message):
                nonlocal response_sent
                if message.get("type") == "http.response.start":
                    response_sent = True
                await request._send(message)

            await handle_messages_asgi(request.scope, request.receive, wrapped_send)
            if response_sent:
                return _AlreadySentResponse()
            return Response("Internal error", status_code=500)

        await self._serve_starlette_routes(
            [
                Route("/sse", endpoint=handle_sse),
                Route("/messages", endpoint=handle_messages, methods=["POST"]),
            ]
        )
# Shared MCP server instance: tools are registered on it at import time and
# main() configures and runs it.
mcp = StravaMCP("Strava Training MCP")
# Module-level Strava client. Stays None until main() builds one from the
# refresh token and client credentials found in the environment.
strava_client: Optional[StravaClient] = None
def get_client() -> Optional[StravaClient]:
    """Return the module-level Strava client, or ``None`` if none is set yet."""
    return strava_client
# Register all tools on the server. The accessor function is passed rather
# than the client object itself because strava_client is still None at import
# time and is only assigned inside main(); presumably the tools call the
# accessor when they run -- confirm in src.tools.
register_tools(mcp, get_client)
def main() -> None:
    """Parse CLI options, set up Strava authentication, and run the MCP server.

    Flow:
      1. Parse ``--transport``, ``--host``, ``--port``, ``--mount-path``,
         ``--log-level`` and ``--auth-code``.
      2. If a refresh token, client id and client secret are all available,
         create the module-level ``strava_client``; otherwise print setup
         instructions.
      3. If ``--auth-code`` was given, complete the authorization and return
         without starting the server.
      4. Otherwise copy the networking/logging options onto ``mcp.settings``
         and run the server on the selected transport.

    All human-readable status output goes to stderr: with the default stdio
    transport, stdout carries the MCP protocol messages, so printing there
    would corrupt the stream the client is parsing.
    """
    import sys

    global strava_client

    def notify(message: str = "") -> None:
        """Write a status line to stderr, keeping stdout free for the protocol."""
        print(message, file=sys.stderr)

    # NOTE(review): the leading symbols in several messages below look like
    # mis-encoded emoji. They are kept byte-for-byte; restore the intended
    # characters once the original encoding is known.
    notify("Starting Strava Training MCP server!")

    parser = argparse.ArgumentParser(description="Strava Training MCP Server")
    parser.add_argument(
        "--transport",
        type=str,
        default="stdio",
        choices=["stdio", "sse", "streamable-http"],
        help="Transport to use for the MCP server (default: stdio)",
    )
    parser.add_argument(
        "--host",
        type=str,
        default="127.0.0.1",
        help="Host for HTTP/SSE transports (default: 127.0.0.1)",
    )
    parser.add_argument(
        "--port",
        type=int,
        default=8000,
        help="Port for HTTP/SSE transports (default: 8000)",
    )
    parser.add_argument(
        "--mount-path",
        type=str,
        default=None,
        help="Optional mount path for SSE (e.g., /my/sse). Ignored for other transports.",
    )
    parser.add_argument(
        "--log-level",
        type=str,
        default="INFO",
        choices=["CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG"],
        help="Log level for the HTTP server (default: INFO)",
    )
    parser.add_argument(
        "--auth-code",
        type=str,
        default=None,
        help="Authorization code to complete authentication (optional)",
    )
    args = parser.parse_args()

    # Initialize Strava client and handle authentication if needed
    credentials = get_env_credentials()
    refresh_token = os.environ.get("STRAVA_REFRESH_TOKEN")
    client_id = credentials.get("client_id")
    client_secret = os.environ.get("STRAVA_CLIENT_SECRET")

    if refresh_token and client_id and client_secret:
        strava_client = StravaClient(refresh_token, client_id, client_secret)
        notify("ā Strava client initialized successfully")
    else:
        notify("\n" + "=" * 60)
        notify("Strava Authentication Required")
        notify("=" * 60)
        if client_id and client_secret:
            notify("\nā Credentials found. Generating authorization URL...")
            try:
                auth_url = get_authorization_url(client_id)
                notify("\nš Authorization URL:")
                notify(f" {auth_url}\n")
                notify("š Instructions:")
                notify(" 1. Visit the URL above in your browser")
                notify(" 2. Authorize the application")
                notify(" 3. Copy the 'code' parameter from the redirect URL")
                notify(" 4. Restart the server with: --auth-code YOUR_CODE")
                notify("\n Or use the MCP tool: complete_strava_auth(code)")
            except Exception as e:
                notify(f"ā Error generating auth URL: {e}")
        else:
            notify("\nā ļø No Strava credentials found.")
            notify("\nTo get started:")
            notify(" 1. Get your Client ID and Client Secret from:")
            notify(" https://www.strava.com/settings/api")
            notify(" 2. Set them as environment variables:")
            notify(" export STRAVA_CLIENT_ID=your_client_id")
            notify(" export STRAVA_CLIENT_SECRET=your_client_secret")
            notify(" 3. Restart the server")
            notify("\n Or use the MCP tool: save_credentials(client_id, client_secret)")
        notify("\n" + "=" * 60 + "\n")

    if args.auth_code:
        # One-shot mode: exchange the code, report the outcome, and exit
        # without starting the server either way.
        try:
            notify("Completing authentication with provided code...")
            complete_auth_with_code(args.auth_code)
            notify("ā Authentication completed! Refresh token saved.")
            notify(" Restart the server to use the new credentials.")
        except Exception as e:
            notify(f"ā Authentication failed: {e}")
        return

    # Apply networking and logging settings to FastMCP for HTTP transports.
    # Best-effort: a failure here must not stop the server, but it is logged
    # instead of being silently discarded.
    try:
        mcp.settings.host = args.host
        mcp.settings.port = args.port
        mcp.settings.log_level = args.log_level
        if args.mount_path:
            mcp.settings.mount_path = args.mount_path
    except Exception as e:
        logger.warning("Could not apply server settings: %s", e)

    mcp.run(transport=args.transport)
# Run the server only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()