Skip to main content
Glama

Doris MCP Server

Official
by apache
oauth_client.py • 21 kB
#!/usr/bin/env python3
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
OAuth 2.0/OIDC Client Manager

Provides OAuth authentication client implementation with PKCE and OIDC support
"""

import base64
import hashlib
import secrets
import uuid
from datetime import datetime, timedelta
from typing import Dict, Optional, Any, Tuple
from urllib.parse import urlencode, parse_qs, urlparse
import asyncio
import json

try:
    import aiohttp
except ImportError as exc:
    raise ImportError(
        "aiohttp is required for OAuth functionality. Install with: pip install aiohttp"
    ) from exc

from .oauth_types import (
    OAuthProvider, OAuthState, OAuthTokens, OAuthUserInfo,
    OIDCDiscovery, OAuthError, OAuthProviderConfig, OAUTH_PROVIDERS
)
from ..utils.logger import get_logger

logger = get_logger(__name__)


class OAuthStateManager:
    """Manages OAuth state parameters for CSRF protection.

    Each authorization request gets a random ``state`` value, plus an optional
    OIDC nonce and PKCE verifier/challenge. States live in an in-memory dict,
    are single-use via :meth:`consume_state`, and expire ``state_expiry``
    seconds after creation. Timestamps are naive UTC (``datetime.utcnow()``).
    """

    def __init__(self, state_expiry: int = 600, cleanup_interval: int = 300):
        """Initialize state manager

        Args:
            state_expiry: State expiry time in seconds
            cleanup_interval: Seconds between background sweeps that drop
                expired states (default 300, i.e. every 5 minutes)
        """
        self.state_expiry = state_expiry
        self.cleanup_interval = cleanup_interval
        self._states: Dict[str, OAuthState] = {}
        self._cleanup_task: Optional[asyncio.Task] = None
        logger.info("OAuthStateManager initialized")

    async def start(self):
        """Start the periodic cleanup task (no-op if it is already running)."""
        if self._cleanup_task is not None and not self._cleanup_task.done():
            # Avoid spawning a second sweeper when start() is called twice.
            return
        self._cleanup_task = asyncio.create_task(self._periodic_cleanup())
        logger.info("OAuth state manager started")

    async def stop(self):
        """Cancel the periodic cleanup task and wait for it to finish."""
        task, self._cleanup_task = self._cleanup_task, None
        if task:
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass
        logger.info("OAuth state manager stopped")

    def create_state(self, redirect_uri: str, pkce_enabled: bool = True,
                     nonce_enabled: bool = True) -> OAuthState:
        """Create and store a new OAuth state

        Args:
            redirect_uri: OAuth redirect URI
            pkce_enabled: Whether to generate a PKCE verifier/S256 challenge
            nonce_enabled: Whether to generate a nonce (for OIDC)

        Returns:
            OAuth state object
        """
        state = secrets.token_urlsafe(32)
        nonce = secrets.token_urlsafe(32) if nonce_enabled else None

        pkce_verifier = None
        pkce_challenge = None
        if pkce_enabled:
            # base64url without padding, as PKCE (RFC 7636) requires.
            pkce_verifier = base64.urlsafe_b64encode(secrets.token_bytes(32)).decode('utf-8').rstrip('=')
            challenge_bytes = hashlib.sha256(pkce_verifier.encode()).digest()
            pkce_challenge = base64.urlsafe_b64encode(challenge_bytes).decode('utf-8').rstrip('=')

        # One timestamp so expires_at - created_at is exactly state_expiry.
        now = datetime.utcnow()
        oauth_state = OAuthState(
            state=state,
            nonce=nonce,
            pkce_verifier=pkce_verifier,
            pkce_challenge=pkce_challenge,
            redirect_uri=redirect_uri,
            created_at=now,
            expires_at=now + timedelta(seconds=self.state_expiry)
        )

        self._states[state] = oauth_state
        logger.debug(f"Created OAuth state: {state}")
        return oauth_state

    def get_state(self, state: str) -> Optional[OAuthState]:
        """Get OAuth state by state parameter without consuming it

        Args:
            state: State parameter

        Returns:
            OAuth state object, or None if unknown or expired (an expired
            entry is removed as a side effect)
        """
        oauth_state = self._states.get(state)
        if oauth_state is None:
            return None
        if oauth_state.expires_at > datetime.utcnow():
            return oauth_state
        self._states.pop(state, None)
        logger.debug(f"Removed expired OAuth state: {state}")
        return None

    def consume_state(self, state: str) -> Optional[OAuthState]:
        """Get and remove OAuth state so it cannot be replayed

        Args:
            state: State parameter

        Returns:
            OAuth state object or None if not found/expired
        """
        oauth_state = self.get_state(state)
        if oauth_state:
            self._states.pop(state, None)
            logger.debug(f"Consumed OAuth state: {state}")
        return oauth_state

    def _purge_expired(self) -> int:
        """Remove every expired state; return how many were removed."""
        current_time = datetime.utcnow()
        expired_states = [
            key for key, oauth_state in self._states.items()
            if oauth_state.expires_at <= current_time
        ]
        for key in expired_states:
            del self._states[key]
        return len(expired_states)

    async def _periodic_cleanup(self):
        """Sweep expired states every ``cleanup_interval`` seconds until cancelled."""
        while True:
            try:
                await asyncio.sleep(self.cleanup_interval)
                removed = self._purge_expired()
                if removed:
                    logger.info(f"Cleaned up {removed} expired OAuth states")
            except asyncio.CancelledError:
                break
            except Exception as e:
                # Keep the sweeper alive; a single failed sweep is not fatal.
                logger.error(f"Error during OAuth state cleanup: {e}")


class OAuthClient:
    """OAuth 2.0/OIDC Client implementation

    Builds authorization URLs (with state, optional PKCE and nonce), exchanges
    authorization codes and refresh tokens at the token endpoint, and fetches
    and maps userinfo claims. Public methods raise ``ValueError`` on failure.
    """

    def __init__(self, config):
        """Initialize OAuth client

        Args:
            config: DorisConfig with OAuth configuration (either an object with
                a ``security`` attribute, or the security config itself)
        """
        self.config = config

        # Access OAuth settings through security configuration
        if hasattr(config, 'security'):
            security_config = config.security
        else:
            security_config = config

        self.enabled = security_config.oauth_enabled
        if not self.enabled:
            logger.info("OAuth client disabled by configuration")
            return

        # Build provider configuration
        self.provider_config = self._build_provider_config(security_config)
        self.state_manager = OAuthStateManager(security_config.oauth_state_expiry)

        # HTTP client session (created in initialize())
        self._session: Optional[aiohttp.ClientSession] = None

        # Discovery cache
        self._discovery_cache: Optional[OIDCDiscovery] = None
        self._discovery_cache_time: Optional[datetime] = None

        logger.info(f"OAuthClient initialized for provider: {self.provider_config.provider.value}")

    def _build_provider_config(self, security_config) -> OAuthProviderConfig:
        """Build OAuth provider configuration

        Explicitly configured values win; otherwise the defaults registered in
        ``OAUTH_PROVIDERS`` for the provider are used. Unknown provider names
        map to ``OAuthProvider.CUSTOM``.

        Args:
            security_config: Security configuration object

        Returns:
            OAuth provider configuration
        """
        try:
            provider = OAuthProvider(security_config.oauth_provider)
        except ValueError:
            provider = OAuthProvider.CUSTOM

        # Get default configuration for known providers
        defaults = OAUTH_PROVIDERS.get(provider, {})

        return OAuthProviderConfig(
            provider=provider,
            client_id=security_config.oauth_client_id,
            client_secret=security_config.oauth_client_secret,
            redirect_uri=security_config.oauth_redirect_uri,
            scopes=security_config.oauth_scopes or defaults.get("scopes", ["openid", "email", "profile"]),

            # Endpoints (use configured or defaults)
            authorization_endpoint=security_config.oauth_authorization_endpoint or defaults.get("authorization_endpoint", ""),
            token_endpoint=security_config.oauth_token_endpoint or defaults.get("token_endpoint", ""),
            userinfo_endpoint=security_config.oauth_userinfo_endpoint or defaults.get("userinfo_endpoint"),
            jwks_uri=security_config.oauth_jwks_uri or defaults.get("jwks_uri"),

            # Discovery
            discovery_url=security_config.oidc_discovery_url or defaults.get("discovery_url"),

            # Settings
            pkce_enabled=security_config.oauth_pkce_enabled,
            nonce_enabled=security_config.oauth_nonce_enabled,

            # User mapping
            user_id_claim=security_config.oauth_user_id_claim or defaults.get("user_id_claim", "sub"),
            email_claim=security_config.oauth_email_claim or defaults.get("email_claim", "email"),
            name_claim=security_config.oauth_name_claim or defaults.get("name_claim", "name"),
            roles_claim=security_config.oauth_roles_claim,
            default_roles=security_config.oauth_default_roles
        )

    def _require_session(self) -> aiohttp.ClientSession:
        """Return the open HTTP session, or raise if initialize() has not run."""
        if self._session is None or self._session.closed:
            raise ValueError("OAuth client is not initialized; call initialize() first")
        return self._session

    async def initialize(self) -> bool:
        """Initialize OAuth client

        Creates the HTTP session, starts state cleanup and, if a discovery URL
        is configured, performs OIDC discovery. On failure everything started
        here is torn down again.

        Returns:
            True if initialization successful
        """
        if not self.enabled:
            return True

        try:
            # Reuse an existing open session if initialize() is called twice.
            if self._session is None or self._session.closed:
                self._session = aiohttp.ClientSession()

            # Start state manager
            await self.state_manager.start()

            # Perform OIDC discovery if configured
            if self.provider_config.discovery_url:
                await self._discover_oidc_endpoints()

            logger.info("OAuth client initialization completed")
            return True

        except Exception as e:
            logger.error(f"Failed to initialize OAuth client: {e}")
            # Don't leak the session or leave the cleanup task running.
            await self.shutdown()
            return False

    async def shutdown(self):
        """Shutdown OAuth client: stop state cleanup and close the HTTP session."""
        if not self.enabled:
            return

        try:
            await self.state_manager.stop()
        except Exception as e:
            logger.error(f"Error during OAuth client shutdown: {e}")

        session, self._session = self._session, None
        if session is not None:
            try:
                await session.close()
            except Exception as e:
                logger.error(f"Error during OAuth client shutdown: {e}")

        logger.info("OAuth client shutdown completed")

    async def _discover_oidc_endpoints(self) -> OIDCDiscovery:
        """Discover OIDC endpoints using discovery URL

        Results are cached for one hour. Discovered endpoints only fill in
        provider endpoints that are not already configured.

        Returns:
            The (possibly cached) discovery document

        Raises:
            Exception: Any HTTP or parsing error is logged and re-raised.
        """
        try:
            # Check cache first
            if (self._discovery_cache and self._discovery_cache_time and
                    datetime.utcnow() - self._discovery_cache_time < timedelta(hours=1)):
                return self._discovery_cache

            logger.info(f"Discovering OIDC endpoints: {self.provider_config.discovery_url}")

            async with self._require_session().get(self.provider_config.discovery_url) as response:
                response.raise_for_status()
                data = await response.json()

            discovery = OIDCDiscovery(
                issuer=data["issuer"],
                authorization_endpoint=data["authorization_endpoint"],
                token_endpoint=data["token_endpoint"],
                userinfo_endpoint=data.get("userinfo_endpoint"),
                jwks_uri=data.get("jwks_uri"),
                scopes_supported=data.get("scopes_supported"),
                response_types_supported=data.get("response_types_supported"),
                subject_types_supported=data.get("subject_types_supported"),
                id_token_signing_alg_values_supported=data.get("id_token_signing_alg_values_supported")
            )

            # Update provider configuration with discovered endpoints
            if not self.provider_config.authorization_endpoint:
                self.provider_config.authorization_endpoint = discovery.authorization_endpoint
            if not self.provider_config.token_endpoint:
                self.provider_config.token_endpoint = discovery.token_endpoint
            if not self.provider_config.userinfo_endpoint:
                self.provider_config.userinfo_endpoint = discovery.userinfo_endpoint
            if not self.provider_config.jwks_uri:
                self.provider_config.jwks_uri = discovery.jwks_uri

            # Cache discovery result
            self._discovery_cache = discovery
            self._discovery_cache_time = datetime.utcnow()

            logger.info("OIDC endpoint discovery completed successfully")
            return discovery

        except Exception as e:
            logger.error(f"OIDC endpoint discovery failed: {e}")
            raise

    def build_authorization_url(self) -> Tuple[str, OAuthState]:
        """Build OAuth authorization URL

        Returns:
            Tuple of (authorization_url, oauth_state)

        Raises:
            ValueError: If OAuth is disabled or no authorization endpoint is set
        """
        if not self.enabled:
            raise ValueError("OAuth client is not enabled")

        endpoint = self.provider_config.authorization_endpoint
        if not endpoint:
            # Checked before create_state() so no orphan state is stored.
            raise ValueError("Authorization endpoint not configured")

        # Create state for CSRF protection
        oauth_state = self.state_manager.create_state(
            redirect_uri=self.provider_config.redirect_uri,
            pkce_enabled=self.provider_config.pkce_enabled,
            nonce_enabled=self.provider_config.nonce_enabled
        )

        # A scope list is space-joined; an already-joined string is used as-is.
        scopes = self.provider_config.scopes
        scope_value = scopes if isinstance(scopes, str) else ' '.join(scopes)

        # Build authorization parameters
        params = {
            'response_type': 'code',
            'client_id': self.provider_config.client_id,
            'redirect_uri': self.provider_config.redirect_uri,
            'scope': scope_value,
            'state': oauth_state.state
        }

        # Add PKCE challenge
        if oauth_state.pkce_challenge:
            params['code_challenge'] = oauth_state.pkce_challenge
            params['code_challenge_method'] = 'S256'

        # Add nonce for OIDC
        if oauth_state.nonce:
            params['nonce'] = oauth_state.nonce

        # Some endpoints already carry a query string; append rather than add a second '?'.
        if '?' not in endpoint:
            separator = '?'
        elif endpoint.endswith(('?', '&')):
            separator = ''
        else:
            separator = '&'
        authorization_url = f"{endpoint}{separator}{urlencode(params)}"

        logger.info(f"Built OAuth authorization URL for state: {oauth_state.state}")
        return authorization_url, oauth_state

    def _token_request_data(self, grant_type: str, **params: Optional[str]) -> Dict[str, str]:
        """Build the form body for a token-endpoint request.

        ``client_secret`` is only included when configured (public/PKCE-only
        clients have none), and ``None``-valued params are skipped, so that a
        ``None`` never reaches the form encoder, which would stringify it.
        """
        data = {
            'grant_type': grant_type,
            'client_id': self.provider_config.client_id,
        }
        if self.provider_config.client_secret:
            data['client_secret'] = self.provider_config.client_secret
        data.update({key: value for key, value in params.items() if value is not None})
        return data

    @staticmethod
    def _parse_token_response(body: str) -> Dict[str, Any]:
        """Parse a token-endpoint body as JSON, falling back to form encoding.

        Returns an empty dict when the body is neither a JSON object nor
        parseable form data.
        """
        try:
            parsed = json.loads(body)
        except ValueError:
            # Some providers answer with application/x-www-form-urlencoded.
            parsed = {key: values[-1] for key, values in parse_qs(body).items()}
        return parsed if isinstance(parsed, dict) else {}

    async def _post_token_request(self, data: Dict[str, str]) -> Dict[str, Any]:
        """POST a grant to the token endpoint and return the parsed response.

        Raises:
            ValueError: If no token endpoint is configured, or the provider
                answers with a non-200 status (message taken from the
                ``error_description``/``error`` fields, else the HTTP status).
        """
        token_endpoint = self.provider_config.token_endpoint
        if not token_endpoint:
            raise ValueError("Token endpoint not configured")

        async with self._require_session().post(
            token_endpoint,
            data=data,
            headers={
                'Content-Type': 'application/x-www-form-urlencoded',
                # Ask for JSON; some providers default to form-encoded bodies.
                'Accept': 'application/json',
            }
        ) as response:
            status = response.status
            body = await response.text()

        payload = self._parse_token_response(body)
        if status != 200:
            raise ValueError(
                payload.get('error_description') or payload.get('error') or f"HTTP {status}"
            )
        return payload

    @staticmethod
    def _build_tokens(payload: Dict[str, Any],
                      previous_refresh_token: Optional[str] = None) -> OAuthTokens:
        """Convert a token-endpoint payload into ``OAuthTokens``.

        Args:
            payload: Parsed token response
            previous_refresh_token: Used when the response carries no new
                refresh token (i.e. the provider does not rotate it)

        Raises:
            ValueError: If the payload has no ``access_token``
        """
        access_token = payload.get('access_token')
        if not access_token:
            raise ValueError("response did not include an access_token")

        expires_in = payload.get('expires_in')
        if isinstance(expires_in, str) and expires_in.isdigit():
            # Form-encoded responses carry every value as a string.
            expires_in = int(expires_in)

        return OAuthTokens(
            access_token=access_token,
            token_type=payload.get('token_type', 'Bearer'),
            expires_in=expires_in,
            refresh_token=payload.get('refresh_token', previous_refresh_token),
            scope=payload.get('scope'),
            id_token=payload.get('id_token')
        )

    async def exchange_code_for_tokens(self, code: str, state: str) -> Tuple[OAuthTokens, OAuthState]:
        """Exchange authorization code for tokens

        The state is consumed before the request, so it cannot be reused even
        if the exchange fails.

        Args:
            code: Authorization code
            state: State parameter

        Returns:
            Tuple of (OAuth tokens, OAuth state)

        Raises:
            ValueError: If state is invalid or exchange fails
        """
        if not self.enabled:
            raise ValueError("OAuth client is not enabled")

        # Validate and consume state
        oauth_state = self.state_manager.consume_state(state)
        if not oauth_state:
            raise ValueError("Invalid or expired state parameter")

        try:
            data = self._token_request_data(
                'authorization_code',
                code=code,
                redirect_uri=oauth_state.redirect_uri,
                # PKCE verifier (None, and therefore omitted, when PKCE is off)
                code_verifier=oauth_state.pkce_verifier
            )
            payload = await self._post_token_request(data)
            tokens = self._build_tokens(payload)
        except Exception as e:
            logger.error(f"Token exchange failed: {e}")
            raise ValueError(f"Token exchange failed: {str(e)}") from e

        logger.info("Successfully exchanged authorization code for tokens")
        return tokens, oauth_state

    def _map_user_info(self, user_data: Any) -> OAuthUserInfo:
        """Map a userinfo response onto ``OAuthUserInfo`` using the configured claims.

        Raises:
            ValueError: If the response is not an object or lacks the user-id claim
        """
        if not isinstance(user_data, dict):
            raise ValueError("userinfo response is not a JSON object")

        provider = self.provider_config
        subject = user_data.get(provider.user_id_claim)
        if subject is None or subject == '':
            # An empty subject would collapse distinct users into one identity.
            raise ValueError(f"userinfo response is missing the '{provider.user_id_claim}' claim")

        roles = user_data.get(provider.roles_claim) if provider.roles_claim else None
        if roles is None:
            roles = list(provider.default_roles or [])
        elif isinstance(roles, str):
            roles = [roles]

        return OAuthUserInfo(
            sub=str(subject),
            email=user_data.get(provider.email_claim),
            name=user_data.get(provider.name_claim),
            given_name=user_data.get('given_name'),
            family_name=user_data.get('family_name'),
            picture=user_data.get('picture'),
            locale=user_data.get('locale'),
            email_verified=user_data.get('email_verified'),
            roles=roles,
            raw_claims=user_data
        )

    async def get_user_info(self, tokens: OAuthTokens) -> OAuthUserInfo:
        """Get user information from OAuth provider

        Args:
            tokens: OAuth tokens

        Returns:
            OAuth user information

        Raises:
            ValueError: If OAuth is disabled, no userinfo endpoint is set, or
                the request/mapping fails
        """
        if not self.enabled:
            raise ValueError("OAuth client is not enabled")

        if not self.provider_config.userinfo_endpoint:
            raise ValueError("Userinfo endpoint not configured")

        try:
            headers = {'Authorization': f'{tokens.token_type} {tokens.access_token}'}
            async with self._require_session().get(
                self.provider_config.userinfo_endpoint,
                headers=headers
            ) as response:
                response.raise_for_status()
                user_data = await response.json()
            user_info = self._map_user_info(user_data)
        except Exception as e:
            logger.error(f"Failed to get user info: {e}")
            raise ValueError(f"Failed to get user info: {str(e)}") from e

        logger.info(f"Retrieved user info for user: {user_info.sub}")
        return user_info

    async def refresh_tokens(self, refresh_token: str) -> OAuthTokens:
        """Refresh OAuth tokens

        Args:
            refresh_token: Refresh token

        Returns:
            New OAuth tokens (keeping ``refresh_token`` if none is returned)

        Raises:
            ValueError: If OAuth is disabled or the refresh fails
        """
        if not self.enabled:
            raise ValueError("OAuth client is not enabled")

        try:
            data = self._token_request_data('refresh_token', refresh_token=refresh_token)
            payload = await self._post_token_request(data)
            tokens = self._build_tokens(payload, refresh_token)
        except Exception as e:
            logger.error(f"Token refresh failed: {e}")
            raise ValueError(f"Token refresh failed: {str(e)}") from e

        logger.info("Successfully refreshed OAuth tokens")
        return tokens

MCP directory API

We provide all the information about MCP servers via our MCP API.

curl -X GET 'https://glama.ai/api/mcp/v1/servers/apache/doris-mcp-server'

If you have feedback or need assistance with the MCP directory API, please join our Discord server.