"""MCP Server Core implementation.
Implements the MCP server using FastMCP with HTTP Streamable transport
including SSE streaming support per FR-006b.
Supports dual-mode authentication:
- LOCAL mode: OAuth client flow with browser SSO
- CLOUD mode: Resource Server with Bearer token validation
"""
from __future__ import annotations
import asyncio
import signal
from typing import TYPE_CHECKING, Any
from mcp.server.fastmcp import FastMCP
from sso_mcp_server import get_logger
from sso_mcp_server.auth import (
AuthManager,
JWKSClient,
TokenStore,
TokenValidator,
require_auth,
set_auth_manager,
set_token_validator,
)
from sso_mcp_server.checklists import ChecklistService
from sso_mcp_server.config import AuthMode
from sso_mcp_server.processes import ProcessService
from sso_mcp_server.tools import (
get_checklist_impl,
get_process_impl,
list_checklists_impl,
list_processes_impl,
search_processes_impl,
)
if TYPE_CHECKING:
from sso_mcp_server.config import Settings
# Mutable module state. Every handle below starts as None and is populated at
# runtime: _mcp by run_server(), _jwks_client by _initialize_cloud_auth(), and
# _shutdown_event by run_server().

# FastMCP server instance (initialized during run_server).
_mcp: FastMCP | None = None

# JWKS client, only created when CLOUD authentication is initialized; closed
# and reset by _cleanup_resources().
_jwks_client: JWKSClient | None = None

# Event set by _handle_signal when SIGINT/SIGTERM is received.
_shutdown_event: asyncio.Event | None = None

# Structured logger for this module.
_logger = get_logger("server")
def _initialize_local_auth(settings: Settings) -> AuthManager:
    """Set up the components used for LOCAL (browser SSO) authentication.

    Builds a TokenStore backed by ``settings.token_cache_path``, wraps it in an
    AuthManager and registers that manager globally via ``set_auth_manager`` so
    the ``@require_auth`` decorator can reach it. No login happens here:
    authentication is deferred until the first tool call.

    Args:
        settings: Application settings.

    Returns:
        The AuthManager that was registered.
    """
    _logger.info("initializing_local_auth")

    store = TokenStore(settings.token_cache_path)
    manager = AuthManager(settings, store)
    set_auth_manager(manager)  # make it visible to the @require_auth middleware

    _logger.info("local_auth_ready", has_cached_tokens=store.has_cached_tokens())
    return manager
def _initialize_cloud_auth(settings: Settings) -> TokenValidator:
    """Set up the components used for CLOUD (Bearer token) authentication.

    Creates a JWKS client using ``settings.jwks_cache_ttl`` and stores it in the
    module-level ``_jwks_client`` so ``_cleanup_resources`` can close it later.
    It then builds a TokenValidator bound to the configured resource identifier
    and allowed issuers, and registers it via ``set_token_validator`` for the
    auth middleware.

    Args:
        settings: Application settings.

    Returns:
        The TokenValidator that was registered.
    """
    global _jwks_client  # noqa: PLW0603

    _logger.info("initializing_cloud_auth")

    resource = settings.resource_identifier
    issuers = settings.allowed_issuers

    _jwks_client = JWKSClient(cache_ttl=settings.jwks_cache_ttl)
    validator = TokenValidator(
        resource_identifier=resource,
        allowed_issuers=issuers,
        jwks_client=_jwks_client,
    )
    set_token_validator(validator)  # make it visible to the auth middleware

    _logger.info(
        "cloud_auth_ready",
        resource_identifier=resource,
        issuers_count=len(issuers),
    )
    return validator
def _initialize_auth(settings: Settings) -> None:
    """Initialize authentication according to ``settings.auth_mode``.

    - LOCAL: only LOCAL (browser SSO) auth is set up.
    - CLOUD: only CLOUD (Bearer token) auth is set up.
    - anything else (AUTO): LOCAL auth is always set up. CLOUD auth is added
      only when both ``resource_identifier`` and ``allowed_issuers`` are set.

    Args:
        settings: Application settings.
    """
    mode = settings.auth_mode

    if mode == AuthMode.LOCAL:
        _initialize_local_auth(settings)
        return

    if mode == AuthMode.CLOUD:
        _initialize_cloud_auth(settings)
        return

    # AUTO mode: cloud validation needs both values to be configured.
    _logger.info("initializing_auto_auth_mode")
    _initialize_local_auth(settings)
    cloud_configured = bool(settings.resource_identifier and settings.allowed_issuers)
    if cloud_configured:
        _initialize_cloud_auth(settings)
async def _cleanup_resources() -> None:
    """Release resources held by the server before the process exits.

    Closes the JWKS client if one was created. A failure while closing is
    logged as a warning rather than raised. The module-level ``_jwks_client``
    reference is cleared whether or not the close succeeded.
    """
    global _jwks_client  # noqa: PLW0603

    _logger.info("cleaning_up_resources")

    client = _jwks_client
    if client is None:
        return  # nothing was created (e.g. LOCAL-only auth)

    try:
        await client.close()
        _logger.debug("jwks_client_closed")
    except Exception as exc:
        # Best-effort during shutdown: report, don't propagate.
        _logger.warning("jwks_client_close_error", error=str(exc))
    finally:
        _jwks_client = None
def _handle_signal(signum: int, _frame: Any) -> None:
    """Signal callback installed by ``run_server`` for SIGINT and SIGTERM.

    Logs the received signal and sets the module shutdown event if one exists.
    For SIGINT it then raises KeyboardInterrupt, mirroring Python's default
    SIGINT behaviour so the running server unwinds.

    Args:
        signum: Number of the received signal.
        _frame: Interrupted stack frame (unused).

    Raises:
        KeyboardInterrupt: When ``signum`` is SIGINT.
    """
    _logger.info("received_shutdown_signal", signal=signal.Signals(signum).name)

    event = _shutdown_event
    if event is not None:
        event.set()

    if signum != signal.SIGINT:
        return
    raise KeyboardInterrupt
def run_server(settings: Settings) -> None:
    """Run the MCP server with HTTP Streamable transport.

    Startup sequence:
    1. Install SIGINT/SIGTERM handlers.
    2. Create the FastMCP instance and register the tools.
    3. Initialize authentication and the checklist/process services.
    4. Serve on ``127.0.0.1:settings.mcp_port`` with the ``streamable-http``
       transport, which includes SSE streaming per FR-006b.

    Returns once the FastMCP run loop exits.

    Cleanup runs even if a startup step fails part-way. For example, the JWKS
    client created for CLOUD/AUTO auth is still closed if a later service
    initialization raises. The signal handlers that were in place before this
    call are restored on exit.

    Args:
        settings: Application settings containing port and other configuration.
    """
    global _mcp, _shutdown_event  # noqa: PLW0603

    # Install graceful-shutdown handlers, remembering the previous ones. Without
    # restoring them, SIGTERM would only be logged (never terminate the process)
    # after run_server returns, because _handle_signal does not raise for it.
    previous_handlers = {
        sig: signal.signal(sig, _handle_signal)
        for sig in (signal.SIGINT, signal.SIGTERM)
    }

    # Create shutdown event
    _shutdown_event = asyncio.Event()

    # Everything after this point may allocate resources (e.g. the JWKS client),
    # so it all lives inside the try to guarantee cleanup on any failure.
    try:
        # Initialize FastMCP with configured port
        _mcp = FastMCP(
            name="sso-mcp-server",
            host="127.0.0.1",
            port=settings.mcp_port,
        )

        _register_tools(_mcp)
        _initialize_auth(settings)
        _initialize_checklist_service(settings)
        _initialize_process_service(settings)

        _logger.info(
            "server_starting",
            port=settings.mcp_port,
            transport="http-streamable",
            auth_mode=settings.auth_mode.value,
        )

        # Run with HTTP Streamable transport (includes SSE support per FR-006b)
        _mcp.run(transport="streamable-http")
    finally:
        # The server's own event loop has exited by now, so cleanup gets a
        # fresh loop via asyncio.run().
        _logger.info("server_stopped_running_cleanup")
        try:
            asyncio.run(_cleanup_resources())
        except Exception as e:
            _logger.warning("cleanup_error", error=str(e))
        finally:
            for sig, handler in previous_handlers.items():
                # signal.signal() returns None when the previous handler was not
                # installed from Python; None cannot be re-installed.
                if handler is not None:
                    signal.signal(sig, handler)
def get_mcp_server() -> FastMCP:
    """Return the FastMCP instance created by ``run_server``.

    Returns:
        The configured FastMCP server instance.

    Raises:
        RuntimeError: If ``run_server`` has not created the instance yet.
    """
    server = _mcp
    if server is not None:
        return server
    raise RuntimeError("MCP server not initialized. Call run_server first.")
# Service singletons. Both stay None until run_server() calls the matching
# _initialize_*_service() helper; the registered tools check for None and
# report SERVICE_NOT_INITIALIZED in that case.
_process_service: ProcessService | None = None
_checklist_service: ChecklistService | None = None
def _initialize_checklist_service(settings: Settings) -> ChecklistService:
    """Create the checklist service and publish it module-wide.

    The service reads from ``settings.checklist_dir`` and is stored in the
    module-level ``_checklist_service`` used by the checklist tools.

    Args:
        settings: Application settings.

    Returns:
        The newly created ChecklistService.
    """
    global _checklist_service  # noqa: PLW0603

    directory = settings.checklist_dir
    _checklist_service = ChecklistService(directory)
    _logger.info("checklist_service_initialized", directory=str(directory))
    return _checklist_service
def _initialize_process_service(settings: Settings) -> ProcessService:
    """Create the process service and publish it module-wide.

    The service reads from ``settings.process_dir`` and is stored in the
    module-level ``_process_service`` used by the process tools.

    Args:
        settings: Application settings.

    Returns:
        The newly created ProcessService.
    """
    global _process_service  # noqa: PLW0603

    directory = settings.process_dir
    _process_service = ProcessService(directory)
    _logger.info("process_service_initialized", directory=str(directory))
    return _process_service
def _register_tools(mcp_instance: FastMCP) -> None:
    """Register every MCP tool exposed by this server on ``mcp_instance``.

    Each tool is wrapped by ``@require_auth`` and then registered through
    ``mcp_instance.tool()``. If a tool's backing module-level service is still
    None, the tool returns a ``SERVICE_NOT_INITIALIZED`` payload instead of
    calling its ``*_impl`` function.

    NOTE(review): FastMCP presumably derives tool metadata from each function's
    name, parameters and docstring, so treat those as part of the tool contract.

    Args:
        mcp_instance: The FastMCP server instance to register tools on.
    """

    def _checklist_unavailable() -> dict[str, Any]:
        # Built fresh on every call so no caller shares a mutable payload.
        return {
            "error": "SERVICE_NOT_INITIALIZED",
            "message": "Checklist service not initialized.",
        }

    def _process_unavailable() -> dict[str, Any]:
        # Built fresh on every call so no caller shares a mutable payload.
        return {
            "error": "SERVICE_NOT_INITIALIZED",
            "message": "Process service not initialized.",
        }

    @mcp_instance.tool()
    @require_auth
    async def get_checklist(name: str) -> dict[str, Any]:
        """Get a development checklist by name.

        Retrieves the full content of a specific checklist.

        Args:
            name: Name of the checklist to retrieve.

        Returns:
            Dictionary containing checklist name, description, and content.
            If not found, returns error with list of available checklists.
        """
        service = _checklist_service
        if service is None:
            return _checklist_unavailable()
        return get_checklist_impl(name, service)

    @mcp_instance.tool()
    @require_auth
    async def list_checklists() -> dict[str, Any]:
        """List all available development checklists.

        Returns metadata (name and description) for all checklists
        without including the full content.

        Returns:
            Dictionary containing list of checklists with name and description.
        """
        service = _checklist_service
        if service is None:
            return _checklist_unavailable()
        return list_checklists_impl(service)

    @mcp_instance.tool()
    @require_auth
    async def get_process(name: str) -> dict[str, Any]:
        """Get a development process by name.

        Retrieves the full content of a specific development process
        including steps, guidelines, or workflows.

        Args:
            name: Name of the process to retrieve (case-insensitive).

        Returns:
            Dictionary containing process name, description, and content.
            If not found, returns error with list of available processes.
        """
        service = _process_service
        if service is None:
            return _process_unavailable()
        return get_process_impl(name, service)

    @mcp_instance.tool()
    @require_auth
    async def list_processes() -> dict[str, Any]:
        """List all available development processes.

        Returns metadata (name and description) for all processes
        without including the full content.

        Returns:
            Dictionary containing list of processes with name and description.
        """
        service = _process_service
        if service is None:
            return _process_unavailable()
        return list_processes_impl(service)

    @mcp_instance.tool()
    @require_auth
    async def search_processes(query: str) -> dict[str, Any]:
        """Search across all development processes for a keyword.

        Returns matching processes ranked by relevance (title matches
        rank higher than content matches). Maximum 50 results.

        Args:
            query: Keyword or phrase to search for (case-insensitive).

        Returns:
            Dictionary containing search results, query, count, and total_processes.
        """
        service = _process_service
        if service is None:
            return _process_unavailable()
        return search_processes_impl(query, service)