"""Command-line interface for the Keboola MCP server."""
import argparse
import asyncio
import contextlib
import json
import logging.config
import os
import pathlib
import sys
import traceback
from typing import Optional
import pydantic
import requests
from fastmcp import FastMCP
from starlette.exceptions import HTTPException
from starlette.middleware import Middleware
from starlette.requests import Request
from starlette.responses import JSONResponse
from starlette.routing import Route
from keboola_mcp_server.config import Config, ServerRuntimeInfo
from keboola_mcp_server.mcp import ForwardSlashMiddleware
from keboola_mcp_server.server import CustomRoutes, create_server
LOG = logging.getLogger(__name__)
def parse_args(args: Optional[list[str]] = None) -> argparse.Namespace:
"""Parses command line arguments."""
parser = argparse.ArgumentParser(
prog='python -m keboola-mcp-server',
description='Keboola MCP Server',
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
'--transport',
choices=['stdio', 'sse', 'streamable-http', 'http-compat'],
default='stdio',
help='Transport to use for MCP communication',
)
parser.add_argument(
'--log-level',
choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'],
default='INFO',
help='Logging level',
)
parser.add_argument(
'--api-url',
metavar='URL',
help=(
'Keboola Storage API URL using format of https://connection.<REGION>.keboola.com. Example: For AWS region '
'"eu-central-1", use: https://connection.eu-central-1.keboola.com'
),
)
parser.add_argument('--storage-token', metavar='STR', help='Keboola Storage API token.')
parser.add_argument('--workspace-schema', metavar='STR', help='Keboola Storage API workspace schema.')
parser.add_argument('--host', default='localhost', metavar='STR', help='The host to listen on.')
parser.add_argument('--port', type=int, default=8000, metavar='INT', help='The port to listen on.')
parser.add_argument('--log-config', type=pathlib.Path, metavar='PATH', help='Logging config file.')
return parser.parse_args(args)
def _create_exception_handler(status_code: int = 500, log_exception: bool = False):
"""
Returns a JSON message response for all unhandled errors from request handlers. The response JSON body
will show exception message and traceback (if the app runs in the debug mode).
:param status_code: the HTTP status code to return; if not specified 500 (Server Error) status code is used
"""
async def _exception_handler(request: Request, exc):
exc_str = f'{type(exc).__name__}: {exc}'
if log_exception:
LOG.exception(f'Unhandled error: {exc_str}')
if request.app.debug:
exc_type, exc_value, exc_traceback = sys.exc_info()
exc_text = ''.join(traceback.format_exception(exc_type, exc_value, exc_traceback))
return JSONResponse({'message': exc_str, 'exception': exc_text}, status_code)
else:
return JSONResponse({'message': exc_str}, status_code)
return _exception_handler
async def _http_exception_handler(request: Request, exc: HTTPException):
return JSONResponse({'message': exc.detail}, status_code=exc.status_code)
_bad_request_handler = _create_exception_handler(status_code=400)
_exception_handlers = {
HTTPException: _http_exception_handler,
json.JSONDecodeError: _bad_request_handler,
requests.JSONDecodeError: _bad_request_handler,
pydantic.ValidationError: _bad_request_handler,
ValueError: _bad_request_handler,
Exception: _create_exception_handler(status_code=500, log_exception=True),
}
async def run_server(args: Optional[list[str]] = None) -> None:
"""Runs the MCP server in async mode."""
parsed_args = parse_args(args)
log_config: pathlib.Path | None = parsed_args.log_config
if not log_config and os.environ.get('LOG_CONFIG'):
log_config = pathlib.Path(os.environ.get('LOG_CONFIG'))
if log_config and not log_config.is_file():
LOG.warning(f'Invalid log config file: {log_config}. Using default logging configuration.')
log_config = None
if log_config:
# remove fastmcp's rich handler, which is aggressively set up during "import fastmcp"
fastmcp_logger = logging.getLogger('fastmcp')
for hdlr in fastmcp_logger.handlers[:]:
fastmcp_logger.removeHandler(hdlr)
fastmcp_logger.propagate = True
fastmcp_logger.setLevel(logging.NOTSET)
logging.config.fileConfig(log_config, disable_existing_loggers=False)
else:
logging.basicConfig(
format='%(asctime)s %(name)s %(levelname)s: %(message)s',
level=parsed_args.log_level,
stream=sys.stderr,
)
# Create config from the CLI arguments
config = Config(
storage_api_url=parsed_args.api_url,
storage_token=parsed_args.storage_token,
workspace_schema=parsed_args.workspace_schema,
)
try:
# Create and run the server
if parsed_args.transport == 'stdio':
runtime_config = ServerRuntimeInfo(transport=parsed_args.transport)
keboola_mcp_server: FastMCP = create_server(config, runtime_info=runtime_config)
if config.oauth_client_id or config.oauth_client_secret:
raise RuntimeError('OAuth authorization can only be used with HTTP-based transports.')
await keboola_mcp_server.run_async(transport=parsed_args.transport)
else:
# 'http-compat' is a compatibility mode to support both Streamable-HTTP and SSE transports.
# SSE transport is deprecated and will be removed in the future.
# Supporting both transports is implemented by creating a parent app and mounting
# two apps (SSE and Streamable-HTTP) to it. The custom routes (like health check)
# are added to the parent app. We use local imports here due to temporary nature of this code.
from contextlib import asynccontextmanager
import uvicorn
from fastmcp.server.http import StarletteWithLifespan
from starlette.applications import Starlette
mount_paths: dict[str, StarletteWithLifespan] = {}
custom_routes: CustomRoutes | None = None
transports: list[str] = []
mcp_server: FastMCP | None = None
if parsed_args.transport in ['http-compat', 'streamable-http']:
http_runtime_config = ServerRuntimeInfo('http-compat/streamable-http')
mcp_server, custom_routes = create_server(
config, runtime_info=http_runtime_config, custom_routes_handling='return'
)
http_app: StarletteWithLifespan = mcp_server.http_app(
path='/',
transport='streamable-http',
)
mount_paths['/mcp'] = http_app
transports.append('Streamable-HTTP')
if parsed_args.transport in ['http-compat', 'sse']:
sse_runtime_config = ServerRuntimeInfo('http-compat/sse')
mcp_server, custom_routes = create_server(
config, runtime_info=sse_runtime_config, custom_routes_handling='return'
)
sse_app: StarletteWithLifespan = mcp_server.http_app(
path='/',
transport='sse',
)
log_messages: list[str] = []
for route in sse_app.routes:
# make sure that the root path is available for GET requests only
# (i.e. POST requests are not allowed)
if isinstance(route, Route) and route.path == '/' and not route.methods:
route.methods = ['GET', 'HEAD']
log_messages.append(str(route))
LOG.info('SSE Routes:\n{}\n'.format('\n'.join(log_messages)))
mount_paths['/sse'] = sse_app # serves /sse/ and /messages
transports.append('SSE')
@asynccontextmanager
async def lifespan(_app: Starlette):
async with contextlib.AsyncExitStack() as stack:
for _inner_app in mount_paths.values():
await stack.enter_async_context(_inner_app.lifespan(_app))
yield
app = Starlette(
middleware=[Middleware(ForwardSlashMiddleware)],
lifespan=lifespan,
exception_handlers=_exception_handlers,
)
for path, inner_app in mount_paths.items():
app.mount(path, inner_app)
custom_routes.add_to_starlette(app)
assert isinstance(mcp_server, FastMCP)
app.state.mcp_tools_input_schema = {
tool.name: tool.parameters for tool in (await mcp_server.get_tools()).values()
}
config = uvicorn.Config(
app,
host=parsed_args.host,
port=parsed_args.port,
log_config=log_config,
timeout_graceful_shutdown=0,
lifespan='on',
)
server = uvicorn.Server(config)
LOG.info(
f'Starting MCP server with {", ".join(transports)} transport{"s" if len(transports) > 1 else ""}'
f' on http://{parsed_args.host}:{parsed_args.port}/'
)
await server.serve()
except Exception as e:
LOG.exception(f'Server failed: {e}')
sys.exit(1)
def main(args: Optional[list[str]] = None) -> None:
asyncio.run(run_server(args))
if __name__ == '__main__':
main()