
Synapse MCP Server

MIT License
  • Linux
  • Apple
proxy.py (16.7 kB)
"""FastMCP OAuth proxy extensions for Synapse.""" import logging import os from typing import Any, List, Optional from urllib.parse import parse_qsl, urlencode, urlparse, urlunparse from fastmcp.server.auth import OAuthProxy from fastmcp.server.auth.oauth_proxy import ProxyDCRClient from pydantic import AnyUrl, TypeAdapter from ..session_storage import create_session_storage from .client_registry import ( ClientRegistration, create_client_registry, load_static_registrations, ) logger = logging.getLogger("synapse_mcp.oauth") class SessionAwareOAuthProxy(OAuthProxy): """OAuth proxy that mirrors tokens into session storage.""" def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self._session_storage = create_session_storage() self._session_tokens: dict[str, tuple[str, Optional[str]]] = {} self._code_sessions: dict[str, str] = {} self._client_registry = create_client_registry(os.environ) if not hasattr(self, "_clients"): # Guard against older fastmcp versions where OAuthProxy skipped initialization. self._clients = {} self._restore_registered_clients() logger.debug( "SessionAwareOAuthProxy initialized with session storage %s and client registry %s", type(self._session_storage).__name__, type(self._client_registry).__name__, ) def _restore_registered_clients(self) -> None: try: registrations = list(self._client_registry.load_all()) except Exception as exc: # pragma: no cover - defensive logger.warning("Failed to load persisted OAuth clients: %s", exc) return # Merge statically configured clients (highest priority) try: registrations.extend(load_static_registrations()) except Exception as exc: # pragma: no cover - defensive logger.warning("Failed to load static OAuth clients: %s", exc) default_grants = ["authorization_code", "refresh_token"] for record in registrations: if record.client_id in self._clients: continue try: adapter = TypeAdapter(List[AnyUrl]) redirect_source = record.redirect_uris if record.redirect_uris else ["http://127.0.0.1"] redirect_uris = adapter.validate_python(redirect_source) proxy_client = ProxyDCRClient( client_id=record.client_id, client_secret=record.client_secret, redirect_uris=redirect_uris, grant_types=record.grant_types or default_grants, scope=self._default_scope_str, token_endpoint_auth_method="none", allowed_redirect_uri_patterns=self._allowed_client_redirect_uris, ) self._clients[record.client_id] = proxy_client logger.info("Restored registered OAuth client %s", record.client_id) except Exception as exc: # pragma: no cover - defensive logger.warning("Failed to restore OAuth client %s: %s", record.client_id, exc) async def register_client(self, client_info): await super().register_client(client_info) try: registration = ClientRegistration( client_id=client_info.client_id, client_secret=_extract_secret(client_info.client_secret), redirect_uris=[str(uri) for uri in (client_info.redirect_uris or [])], grant_types=list(client_info.grant_types or ["authorization_code", "refresh_token"]), ) self._client_registry.save(registration) logger.debug("Persisted OAuth client %s", client_info.client_id) except Exception as exc: # pragma: no cover - defensive logger.warning("Unable to persist OAuth client %s: %s", client_info.client_id, exc) async def _handle_idp_callback(self, request, *args, **kwargs): session_id = _extract_session_id(request) if session_id: logger.debug("OAuth callback processing for session: %s", session_id) existing_tokens = set(getattr(self, "_access_tokens", {}).keys()) existing_codes = set(getattr(self, "_client_codes", 
{}).keys()) result = await super()._handle_idp_callback(request, *args, **kwargs) if result and hasattr(result, "headers"): location = result.headers.get("location") if location: parsed = urlparse(location) query_pairs = parse_qsl(parsed.query, keep_blank_values=True) filtered_pairs = [ (key, value) for key, value in query_pairs if not ( key == "state" and ( value is None or value == "" or (isinstance(value, str) and value.lower() == "none") ) ) ] if len(filtered_pairs) != len(query_pairs): new_query = urlencode(filtered_pairs, doseq=True) new_location = urlunparse(parsed._replace(query=new_query)) result.headers["location"] = new_location logger.debug( "Removed empty state parameter from callback redirect (session=%s)", session_id, ) if result: if session_id: client_codes = getattr(self, "_client_codes", {}) new_codes = [code for code in client_codes if code not in existing_codes] for code in new_codes: self._code_sessions[code] = session_id logger.debug("Cached authorization code %s for session %s", code[:8], session_id) try: await self._map_new_tokens_to_users() except Exception as exc: # pragma: no cover - defensive logger.warning("Failed to map tokens to users: %s", exc) if session_id: access_tokens = getattr(self, "_access_tokens", {}) new_tokens = [token for token in access_tokens if token not in existing_tokens] logger.debug( "Session %s received %d new tokens (existing=%d)", session_id, len(new_tokens), len(existing_tokens), ) for token_key in new_tokens: subject = await self._session_storage.find_user_by_token(token_key) self._session_tokens[session_id] = (token_key, subject) logger.debug( "Associated session %s with token %s*** (subject=%s)", session_id, token_key[:20], subject, ) return result async def exchange_authorization_code( self, client: Any, authorization_code: Any, ): existing_tokens = set(getattr(self, "_access_tokens", {}).keys()) token_response = await super().exchange_authorization_code(client, authorization_code) try: await self._map_new_tokens_to_users() except Exception as exc: # pragma: no cover - defensive logger.warning("Failed to map tokens to users after exchange: %s", exc) access_tokens = getattr(self, "_access_tokens", {}) new_tokens = [token for token in access_tokens if token not in existing_tokens] session_id = self._code_sessions.pop(authorization_code.code, None) if session_id: token_key: Optional[str] = None if new_tokens: token_key = new_tokens[-1] else: token_key = next((token for token, data in access_tokens.items() if data.client_id == client.client_id), None) if token_key: subject = await self._session_storage.find_user_by_token(token_key) self._session_tokens[session_id] = (token_key, subject) logger.debug( "Associated session %s with token %s*** (subject=%s) via code exchange", session_id, token_key[:20], subject, ) else: logger.debug("No access token recorded for session %s during code exchange", session_id) return token_response async def _map_new_tokens_to_users(self) -> None: existing_users = await self._session_storage.get_all_user_subjects() access_tokens = getattr(self, "_access_tokens", {}) known_attrs = [attr for attr in dir(self) if "token" in attr.lower() and not attr.startswith("__")] logger.debug( "_map_new_tokens_to_users: existing_users=%s tokens=%s token_attrs=%s", existing_users, [t[:8] + "***" for t in access_tokens], {attr: _summarize_token_attr(attr, getattr(self, attr, None)) for attr in known_attrs}, ) unmapped_tokens = [token for token in access_tokens if await self._session_storage.find_user_by_token(token) is None] 
for token_key in unmapped_tokens: try: import jwt decoded = jwt.decode(token_key, options={"verify_signature": False}) user_subject = decoded.get("sub") if user_subject: await self._session_storage.set_user_token(user_subject, token_key, ttl_seconds=3600) logger.info("Mapped token %s*** to user %s", token_key[:20], user_subject) else: logger.warning("Token %s*** has no subject claim", token_key[:20]) except Exception as exc: # pragma: no cover - decoding failures logger.warning("Failed to decode token %s***: %s", token_key[:20], exc) async def get_user_token(self, user_subject: str) -> Optional[str]: token_key = await self._session_storage.get_user_token(user_subject) if token_key and token_key in self._access_tokens: return token_key return None async def cleanup_user_tokens(self, user_subject: str) -> None: token_key = await self._session_storage.get_user_token(user_subject) if token_key: if token_key in self._access_tokens: del self._access_tokens[token_key] await self._session_storage.remove_user_token(user_subject) logger.info("Cleaned up token for user %s", user_subject) for session_id, (mapped_token, _) in list(self._session_tokens.items()): if mapped_token == token_key: self._session_tokens.pop(session_id, None) async def cleanup_expired_tokens(self) -> None: await self._session_storage.cleanup_expired_tokens() existing_users = await self._session_storage.get_all_user_subjects() mapped_tokens = { token for user_subject in existing_users for token in [await self._session_storage.get_user_token(user_subject)] if token } orphaned = [token for token in list(self._access_tokens.keys()) if token not in mapped_tokens] for token in orphaned: if self._is_token_old_enough_to_cleanup(token): del self._access_tokens[token] if orphaned: logger.info("Cleaned up %s orphaned tokens from OAuth proxy", len(orphaned)) for session_id, (mapped_token, _) in list(self._session_tokens.items()): if mapped_token in orphaned: self._session_tokens.pop(session_id, None) def _is_token_old_enough_to_cleanup(self, token: str, min_age_seconds: int = 30) -> bool: try: import jwt from datetime import datetime, timezone decoded = jwt.decode(token, options={"verify_signature": False}) issued_at = decoded.get("iat") if not issued_at: return True token_age = datetime.now(timezone.utc).timestamp() - issued_at if token_age <= min_age_seconds: logger.debug("Token is only %.1fs old, keeping for now", token_age) return False return True except Exception as exc: # pragma: no cover - decoding failures logger.debug("Error checking token age, assuming old enough: %s", exc) return True async def iter_user_tokens(self) -> list[tuple[str, str]]: """Return all known (subject, token) pairs from storage.""" tokens: list[tuple[str, str]] = [] subjects = await self._session_storage.get_all_user_subjects() for subject in subjects: token = await self._session_storage.get_user_token(subject) if token: tokens.append((subject, token)) logger.debug("iter_user_tokens -> %s", [(sub, tok[:8] + "***") for sub, tok in tokens]) return tokens async def get_token_for_current_user(self) -> Optional[tuple[str, Optional[str]]]: """Return a token/subject pair when a single active user is known.""" tokens = await self.iter_user_tokens() if len(tokens) == 1: subject, token = tokens[0] return token, subject return None def get_session_token_info(self, session_id: str) -> Optional[tuple[str, Optional[str]]]: info = self._session_tokens.get(session_id) logger.debug("get_session_token_info(%s) -> %s", session_id, (info[0][:8] + "***", info[1]) if info else 
None) return info async def get_token_for_session(self, session_id: str) -> Optional[tuple[str, Optional[str]]]: info = self.get_session_token_info(session_id) if info: return info subjects = await self._session_storage.get_all_user_subjects() logger.debug("get_token_for_session fallback subjects=%s", subjects) return None def _extract_session_id(request) -> Optional[str]: try: if hasattr(request, "headers"): session_id = request.headers.get("mcp-session-id") if session_id: return session_id if hasattr(request, "state"): session_context = getattr(request.state, "session_context", None) if session_context and hasattr(session_context, "session_id"): return session_context.session_id except Exception as exc: # pragma: no cover - defensive logger.debug("Could not extract session ID from callback: %s", exc) return None def _extract_secret(secret: Any) -> Optional[str]: if secret is None: return None try: return secret.get_secret_value() # type: ignore[attr-defined] except AttributeError: return secret # type: ignore[return-value] def _mask_token(token: Optional[str]) -> Optional[str]: if not token: return token return token[:8] + "***" def _summarize_token_attr(attr: str, value: Any) -> Any: if value is None: return None if attr == "_access_tokens" and isinstance(value, dict): summary: dict[str, dict[str, Any]] = {} for token, data in value.items(): masked = _mask_token(token) or "<missing>" summary[masked] = { "client_id": getattr(data, "client_id", None), "scopes": getattr(data, "scopes", None), "expires_at": getattr(data, "expires_at", None), } return summary if attr == "_refresh_tokens" and isinstance(value, dict): summary = {} for token, data in value.items(): masked = _mask_token(token) or "<missing>" summary[masked] = { "client_id": getattr(data, "client_id", None), "scopes": getattr(data, "scopes", None), } return summary if attr == "_session_tokens" and isinstance(value, dict): return { session: {"token": _mask_token(token), "subject": subject} for session, (token, subject) in value.items() } if isinstance(value, dict): return {"type": "dict", "count": len(value)} if isinstance(value, (list, set, tuple)): return {"type": type(value).__name__, "count": len(value)} return type(value).__name__ __all__ = ["SessionAwareOAuthProxy"]

MCP directory API

We provide all the information about MCP servers via our MCP API.

curl -X GET 'https://glama.ai/api/mcp/v1/servers/susheel/synapse-mcp'
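The same lookup can be scripted. Below is a minimal Python sketch using the requests library; it assumes the endpoint returns a JSON payload, whose exact schema is not documented here.

import requests

# Fetch this server's entry from the Glama MCP directory API.
response = requests.get(
    "https://glama.ai/api/mcp/v1/servers/susheel/synapse-mcp",
    timeout=10,
)
response.raise_for_status()
print(response.json())  # JSON payload; schema not shown here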

If you have feedback or need assistance with the MCP directory API, please join our Discord server.