Skip to main content
Glama
oauth_flow.py (18.7 kB)
#
# Copyright (C) 2024 Billy Bryant
# Portions copyright (C) 2024 Sergey Parfenyuk (original MIT-licensed author)
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
#
# MIT License attribution: Portions of this file were originally licensed
# under the MIT License by Sergey Parfenyuk (2024).
#
"""Complete OAuth flow implementation for MCP Remote Python."""

import asyncio
import json
import threading
import time
from typing import Any
from urllib.parse import urlparse

import httpx

from mcp_foxxy_bridge.utils.logging import get_logger

from .config import OAUTH_USER_AGENT
from .coordination import cleanup_lockfile, coordinate_auth
from .events import EventEmitter
from .oauth_client_provider import OAuthClientProvider
from .types import OAuthClientInformation, OAuthProviderOptions, OAuthTokens
from .utils import setup_signal_handlers

logger = get_logger(__name__, facility="OAUTH")

# Per-request timeout (seconds) for every HTTP call made by this module.
_HTTP_TIMEOUT_SECONDS = 10

# How long authenticate() waits for the callback to store tokens, and how
# often it emits a "still pending" progress line (both in seconds).
_AUTH_TIMEOUT_SECONDS = 300
_AUTH_PROGRESS_INTERVAL_SECONDS = 10

# Metadata discovery paths, tried in order. The hyphenated OpenID path is the
# registered one; the underscore spelling is what this module historically
# requested and is kept last purely for backward compatibility.
_WELL_KNOWN_PATHS = (
    "/.well-known/openid-configuration",
    "/.well-known/oauth-authorization-server",
    "/.well-known/openid_configuration",
)

# Endpoint keys copied out of a discovery document.
_ENDPOINT_KEYS = (
    "authorization_endpoint",
    "token_endpoint",
    "registration_endpoint",
    "userinfo_endpoint",
)

# Form fields whose values must never be written to logs.
_SENSITIVE_FORM_FIELDS = frozenset({"code", "code_verifier", "client_secret", "refresh_token"})


def _redact_form(form: dict[str, Any]) -> dict[str, Any]:
    """Return a copy of *form* that is safe to log (secret values masked)."""
    return {key: ("<redacted>" if key in _SENSITIVE_FORM_FIELDS else value) for key, value in form.items()}


class OAuthFlow:
    """Complete OAuth authentication flow manager."""

    def __init__(self, options: OAuthProviderOptions) -> None:
        """Create the provider, wire event handlers and register exit cleanup.

        Args:
            options: Provider options; this class reads ``server_url``,
                ``server_name``, ``callback_port`` and, when present,
                ``verify_ssl`` and ``oauth_issuer``.
        """
        self.options = options
        self.provider = OAuthClientProvider(options)
        self.events = EventEmitter()

        # Bridge server handles OAuth callbacks, no separate callback server needed
        self._auth_result: dict[str, Any] | None = None
        self._auth_completed = threading.Event()

        # Set up event handlers
        self.events.on("auth_success", self._handle_auth_success)
        self.events.on("auth_error", self._handle_auth_error)

        # Set up cleanup on exit
        setup_signal_handlers(self._cleanup)

    # ------------------------------------------------------------------
    # Small internal helpers
    # ------------------------------------------------------------------

    def _server_label(self) -> str:
        """Return the configured server name, or "UNKNOWN" when unset, for log lines."""
        return self.options.server_name or "UNKNOWN"

    def _verify_ssl(self) -> Any:
        """Return the TLS verification setting; defaults to True when not configured."""
        return getattr(self.options, "verify_ssl", True)

    def _fetch_discovery_document(self, base_url: str, required_key: str | None = None) -> dict[str, Any] | None:
        """Fetch the first usable metadata document below *base_url*.

        Each path in ``_WELL_KNOWN_PATHS`` is requested in order. A candidate
        is skipped when the request fails, the status is not 200, the body is
        not a JSON object, or (if *required_key* is given) that key is missing
        or empty.

        Args:
            base_url: Origin or issuer URL without a trailing slash.
            required_key: Optional key the document must contain to be accepted.

        Returns:
            The parsed document, or ``None`` when no candidate was usable.
        """
        for path in _WELL_KNOWN_PATHS:
            try:
                response = httpx.get(
                    f"{base_url}{path}",
                    timeout=_HTTP_TIMEOUT_SECONDS,
                    verify=self._verify_ssl(),
                    headers={"User-Agent": OAUTH_USER_AGENT},
                )
                if response.status_code != 200:
                    continue
                document = response.json()
            except (httpx.HTTPError, json.JSONDecodeError):
                continue

            # A 200 with a JSON array/scalar body is not a metadata document.
            if not isinstance(document, dict):
                continue
            if required_key and not document.get(required_key):
                continue
            return document
        return None

    # ------------------------------------------------------------------
    # Event handlers / cleanup
    # ------------------------------------------------------------------

    def _handle_auth_success(self, data: dict[str, Any]) -> None:
        """Handle successful authorization callback."""
        self._auth_result = {"success": True, "data": data}
        self._auth_completed.set()

    def _handle_auth_error(self, data: dict[str, Any]) -> None:
        """Handle authorization error callback."""
        self._auth_result = {"success": False, "error": data}
        self._auth_completed.set()

    def _cleanup(self) -> None:
        """Clean up resources."""
        # No callback server to stop - bridge server handles OAuth callbacks
        cleanup_lockfile(self.provider.server_url_hash)

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def discover_endpoints(self) -> dict[str, str]:
        """Discover OAuth endpoints from the server.

        Discovery is attempted against the scheme and host of
        ``options.server_url``. If no metadata document can be fetched, the
        conventional ``/oauth/*`` paths on the same origin are returned.

        Returns:
            A mapping with the keys ``authorization_endpoint``,
            ``token_endpoint``, ``registration_endpoint`` and
            ``userinfo_endpoint``. When the values come from a discovery
            document, any endpoint the document omits is ``None``.
        """
        parsed_url = urlparse(self.options.server_url)
        base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"

        config = self._fetch_discovery_document(base_url)
        if config is not None:
            logger.info(
                "Discovered OAuth endpoints via well-known URL for server '%s'",
                self._server_label(),
            )
            return {key: config.get(key) for key in _ENDPOINT_KEYS}

        # Fallback to common paths (original behavior)
        logger.debug("Using fallback OAuth endpoints for server '%s'", self._server_label())
        return {
            "authorization_endpoint": f"{base_url}/oauth/authorize",
            "token_endpoint": f"{base_url}/oauth/token",
            "registration_endpoint": f"{base_url}/oauth/register",
            "userinfo_endpoint": f"{base_url}/oauth/userinfo",
        }

    def register_client(self, registration_endpoint: str) -> OAuthClientInformation:
        """Register OAuth client with the server.

        Client information already held by the provider is returned as-is and
        no request is made.

        Args:
            registration_endpoint: URL of the dynamic client registration endpoint.

        Returns:
            The stored or newly registered client information.

        Raises:
            RuntimeError: If the request fails or the server does not answer 201.
        """
        existing_client = self.provider.client_information()
        if existing_client:
            return existing_client

        metadata = self.provider.client_metadata()

        try:
            response = httpx.post(
                registration_endpoint,
                json={
                    "redirect_uris": metadata.redirect_uris,
                    "client_name": metadata.client_name,
                    "client_uri": metadata.client_uri,
                    "token_endpoint_auth_method": metadata.token_endpoint_auth_method,
                    "grant_types": metadata.grant_types,
                    "response_types": metadata.response_types,
                    "scope": metadata.scope,
                    "software_id": metadata.software_id,
                    "software_version": metadata.software_version,
                },
                headers={"Content-Type": "application/json", "User-Agent": OAUTH_USER_AGENT},
                timeout=_HTTP_TIMEOUT_SECONDS,
                verify=self._verify_ssl(),
            )

            if response.status_code == 201:
                client_data = response.json()
                client_info = OAuthClientInformation(
                    client_id=client_data["client_id"],
                    client_secret=client_data.get("client_secret"),
                    client_id_issued_at=client_data.get("client_id_issued_at"),
                    client_secret_expires_at=client_data.get("client_secret_expires_at"),
                )
                self.provider.save_client_information(client_info)
                return client_info

            msg = f"Client registration failed: {response.status_code} - {response.text}"
            raise RuntimeError(msg)

        except httpx.HTTPError as e:
            msg = f"Failed to register OAuth client: {e}"
            raise RuntimeError(msg) from e

    def exchange_code_for_tokens(
        self, token_endpoint: str, auth_code: str, client_info: OAuthClientInformation
    ) -> OAuthTokens:
        """Exchange authorization code for access tokens.

        The PKCE code verifier is read from the provider. When the client has
        a secret it is sent in the POST body; otherwise the request relies on
        PKCE alone. Tokens from a 200 response are saved via the provider.

        Args:
            token_endpoint: URL of the token endpoint.
            auth_code: Authorization code received on the callback.
            client_info: Registered client identity.

        Returns:
            The tokens issued by the server.

        Raises:
            RuntimeError: If no code verifier is stored, the request fails, or
                the server does not answer 200.
        """
        code_verifier = self.provider.code_verifier()
        if not code_verifier:
            raise RuntimeError("Code verifier not found")

        # Debug logging for token exchange (lengths only, never the values)
        logger.debug("Token exchange - Authorization code length: %s", len(auth_code))
        logger.debug("Token exchange - Code verifier length: %s", len(code_verifier))
        logger.debug("Token exchange - Redirect URI: %s", self.provider.redirect_url)
        logger.debug("Token exchange - Client ID: %s", client_info.client_id)
        logger.debug("Token exchange - Token endpoint: %s", token_endpoint)

        data = {
            "grant_type": "authorization_code",
            "code": auth_code,
            "redirect_uri": self.provider.redirect_url,
            "client_id": client_info.client_id,
            "code_verifier": code_verifier,
        }

        # CRITICAL FIX: Some OAuth servers require client_secret in POST body, not Basic auth
        # Try POST body authentication first for broader compatibility
        if client_info.client_secret:
            logger.debug("Using client secret in POST body for token exchange")
            data["client_secret"] = client_info.client_secret
        else:
            logger.debug("Using public client authentication (PKCE only)")

        # The form carries the authorization code, PKCE verifier and possibly
        # the client secret, so only a redacted copy may reach the log.
        logger.debug("Token exchange request data: %s", _redact_form(data))
        logger.debug(
            "Token exchange auth method: %s",
            "post_body" if client_info.client_secret else "public",
        )

        try:
            response = httpx.post(
                token_endpoint,
                data=data,
                headers={"Content-Type": "application/x-www-form-urlencoded", "User-Agent": OAUTH_USER_AGENT},
                timeout=_HTTP_TIMEOUT_SECONDS,
                verify=self._verify_ssl(),
            )

            # Enhanced debug logging for response
            logger.debug("Token exchange response status: %s", response.status_code)
            logger.debug("Token exchange response headers: %s", dict(response.headers))

            if response.status_code == 200:
                token_data = response.json()
                tokens = OAuthTokens(
                    access_token=token_data["access_token"],
                    refresh_token=token_data.get("refresh_token"),
                    token_type=token_data.get("token_type", "Bearer"),
                    expires_in=token_data.get("expires_in"),
                    scope=token_data.get("scope"),
                )
                self.provider.save_tokens(tokens)
                logger.info("Token exchange successful")
                return tokens

            # Enhanced error logging
            logger.error("Token exchange failed with status %s", response.status_code)
            logger.error("Response body: %s", response.text)
            msg = f"Token exchange failed: {response.status_code} - {response.text}"
            raise RuntimeError(msg)

        except httpx.HTTPError as e:
            logger.exception("HTTP error during token exchange: %s", e)
            msg = f"Failed to exchange code for tokens: {e}"
            raise RuntimeError(msg) from e

    async def authenticate(self, skip_browser: bool = False) -> OAuthTokens:
        """Perform complete OAuth authentication flow.

        Steps: reuse stored tokens if present (unless *skip_browser*), discover
        endpoints, coordinate with other processes, register the client if a
        registration endpoint is known, send the user to the authorization
        URL, then poll the provider once per second until tokens appear.

        Args:
            skip_browser: When true, the authorization URL is logged instead
                of being opened, and stored tokens are not returned up front.

        Returns:
            The tokens found in the provider.

        Raises:
            RuntimeError: If endpoints cannot be discovered, the coordinating
                process produced no tokens, no client information is
                available, or no tokens appear within the timeout.
        """
        # Check for existing valid tokens
        existing_tokens = self.provider.tokens()
        if existing_tokens and not skip_browser:
            logger.info("Using existing tokens for server '%s'", self._server_label())
            return existing_tokens

        logger.info("Starting OAuth authentication flow for server '%s'", self._server_label())

        # Discover OAuth endpoints
        endpoints = self.discover_endpoints()
        if not endpoints.get("authorization_endpoint") or not endpoints.get("token_endpoint"):
            raise RuntimeError("Could not discover OAuth endpoints")

        # Check if we should coordinate with other processes
        # Note: callback_port is now the bridge server port since OAuth is integrated
        should_start_auth, _ = coordinate_auth(self.provider.server_url_hash, self.options.callback_port, self.events)

        if not should_start_auth:
            # Another process handled authentication
            tokens = self.provider.tokens()
            if tokens:
                return tokens
            raise RuntimeError("Authentication failed in coordinating process")

        # Register OAuth client if needed
        if endpoints.get("registration_endpoint"):
            client_info = self.register_client(endpoints["registration_endpoint"])
        else:
            retrieved_client_info: OAuthClientInformation | None = self.provider.client_information()
            if not retrieved_client_info:
                raise RuntimeError("No OAuth client information available and no registration endpoint")
            client_info = retrieved_client_info

        # No need to start callback server - bridge server handles OAuth callbacks
        try:
            # Build authorization URL and redirect user
            auth_url = self.provider.build_authorization_url(endpoints["authorization_endpoint"], client_info.client_id)

            if not skip_browser:
                self.provider.redirect_to_authorization(auth_url)
            else:
                logger.info("Visit this URL to authorize server '%s': %s", self._server_label(), auth_url)

            # Wait for tokens to be saved by bridge server's OAuth callback
            logger.info("Waiting for authorization callback for server '%s'", self._server_label())

            # A monotonic clock keeps the timeout correct even if the system
            # wall clock is adjusted while we wait.
            start_time = time.monotonic()
            last_log_time = start_time

            # NOTE(review): with skip_browser=True and tokens already stored,
            # the first poll below returns those stored tokens - confirm this
            # is the intended behaviour.
            while time.monotonic() - start_time < _AUTH_TIMEOUT_SECONDS:
                # Check if tokens were saved by the bridge server's callback handler
                tokens = self.provider.tokens()
                if tokens:
                    elapsed = time.monotonic() - start_time
                    logger.success("Authentication successful after %.1f seconds!", elapsed)  # type: ignore[attr-defined]
                    return tokens

                current_time = time.monotonic()
                elapsed = current_time - start_time
                if current_time - last_log_time >= _AUTH_PROGRESS_INTERVAL_SECONDS:
                    logger.info(
                        "OAuth callback pending for server '%s' (%.1fs elapsed, %.1fs remaining)",
                        self._server_label(),
                        elapsed,
                        _AUTH_TIMEOUT_SECONDS - elapsed,
                    )
                    last_log_time = current_time

                await asyncio.sleep(1.0)

            # Provide helpful error message for redirect issues
            elapsed = time.monotonic() - start_time
            error_msg = f"OAuth authentication timed out after {elapsed:.1f} seconds"
            logger.error("[%s] OAuth timeout: %s", self._server_label(), error_msg)
            raise RuntimeError(error_msg)

        finally:
            self._cleanup()

    def refresh_tokens(self, refresh_token: str) -> OAuthTokens:
        """Refresh access tokens using refresh token.

        The token endpoint is resolved from ``options.oauth_issuer`` when that
        is configured (discovery first, then ``<issuer>/oauth/token``);
        otherwise from :meth:`discover_endpoints`. New tokens are saved via
        the provider; if the response carries no new refresh token, the one
        passed in is kept.

        Args:
            refresh_token: The refresh token to present.

        Returns:
            The refreshed tokens.

        Raises:
            RuntimeError: If no token endpoint or client information is
                available, the request fails, or the server does not answer 200.
        """
        token_endpoint = None

        # For token refresh, prefer the configured OAuth issuer if there is one
        oauth_issuer = getattr(self.options, "oauth_issuer", None)
        if oauth_issuer:
            logger.debug("[%s] Using configured OAuth issuer for token refresh", self._server_label())

            # Strip a trailing slash so the joined URLs do not contain "//".
            issuer_base = str(oauth_issuer).rstrip("/")
            config = self._fetch_discovery_document(issuer_base, required_key="token_endpoint")
            if config is not None:
                token_endpoint = config["token_endpoint"]
                logger.debug("[%s] Found token endpoint via OAuth issuer discovery", self._server_label())
            else:
                # If discovery failed, try common OAuth issuer paths
                token_endpoint = f"{issuer_base}/oauth/token"
                logger.debug("[%s] Using fallback OAuth issuer token endpoint", self._server_label())

        # If no OAuth issuer is configured, fall back to server endpoints
        if not token_endpoint:
            logger.debug("[%s] Falling back to server endpoints for token refresh", self._server_label())
            endpoints = self.discover_endpoints()
            token_endpoint = endpoints.get("token_endpoint")

        if not token_endpoint:
            raise RuntimeError("Token endpoint not available")

        client_info = self.provider.client_information()
        if not client_info:
            raise RuntimeError("Client information not available")

        data = {
            "grant_type": "refresh_token",
            "refresh_token": refresh_token,
            "client_id": client_info.client_id,
        }

        # Confidential clients authenticate with HTTP Basic on refresh.
        auth = None
        if client_info.client_secret:
            auth = (client_info.client_id, client_info.client_secret)

        try:
            response = httpx.post(
                token_endpoint,
                data=data,
                auth=auth,
                headers={"Content-Type": "application/x-www-form-urlencoded", "User-Agent": OAUTH_USER_AGENT},
                timeout=_HTTP_TIMEOUT_SECONDS,
                verify=self._verify_ssl(),
            )

            if response.status_code == 200:
                token_data = response.json()
                tokens = OAuthTokens(
                    access_token=token_data["access_token"],
                    refresh_token=token_data.get("refresh_token", refresh_token),
                    token_type=token_data.get("token_type", "Bearer"),
                    expires_in=token_data.get("expires_in"),
                    scope=token_data.get("scope"),
                )
                self.provider.save_tokens(tokens)
                return tokens

            msg = f"Token refresh failed: {response.status_code} - {response.text}"
            raise RuntimeError(msg)

        except httpx.HTTPError as e:
            msg = f"Failed to refresh tokens: {e}"
            raise RuntimeError(msg) from e

    def invalidate_credentials(self) -> None:
        """Invalidate and remove all stored credentials."""
        self.provider.invalidate_credentials()
        self._cleanup()

Latest Blog Posts

MCP directory API

We provide all the information about MCP servers via our MCP API.

curl -X GET 'https://glama.ai/api/mcp/v1/servers/billyjbryant/mcp-foxxy-bridge'

If you have feedback or need assistance with the MCP directory API, please join our Discord server.