import os
import logging
from contextvars import ContextVar
from typing import Dict, Optional
from datetime import datetime, timedelta
import httpx
from awslabs.openapi_mcp_server.api.config import Config
from awslabs.openapi_mcp_server.auth.auth_provider import AuthProvider
from awslabs.openapi_mcp_server.auth.auth_factory import register_auth_provider
# Module-level logger shared by everything below.
logger = logging.getLogger(__name__)

# Per-request OAuth credentials. ZohoAuthProvider.set_credentials_for_request()
# writes them and ZohoAuth.auth_flow() reads them; each defaults to None.
_current_client_id: ContextVar[Optional[str]] = ContextVar(
    'zoho_client_id', default=None
)
_current_client_secret: ContextVar[Optional[str]] = ContextVar(
    'zoho_client_secret', default=None
)
_current_refresh_token: ContextVar[Optional[str]] = ContextVar(
    'zoho_refresh_token', default=None
)

# Access-token cache: "<client_id>:<refresh_token>" -> (access_token, expiry).
# Keying on both values lets several Zoho clients share one process.
_access_token_cache: Dict[str, tuple[str, datetime]] = dict()
class ZohoAuth(httpx.Auth):
    """Custom HTTPX Auth class that dynamically adds Zoho auth headers per-request.

    The flow runs for every request, so credentials placed in this module's
    context variables (e.g. by middleware) are picked up at send time.
    """

    def __init__(self, provider: 'ZohoAuthProvider'):
        """Keep the provider used to obtain (and cache) access tokens."""
        self.provider = provider

    def auth_flow(self, request: httpx.Request):
        """Attach ``Authorization: Zoho-oauthtoken <token>`` and yield the request once.

        Credentials are read from the module's context variables. If any of
        them is missing, or no access token can be obtained, the request is
        still yielded, just without the Authorization header.
        """
        logger.info("🔐 ZohoAuth.auth_flow - Starting per-request auth")
        # Header values and query strings can carry credentials or personal
        # data, so INFO shows only host and path; DEBUG adds header/param names.
        logger.info("🔐 Original request : %s %s%s", request.method, request.url.host, request.url.path)
        logger.debug("   header names : %s", list(request.headers.keys()))
        logger.debug("   param names  : %s", list(request.url.params.keys()))

        # Read credentials from context variables (set by middleware).
        credentials = {
            "client_id": _current_client_id.get(),
            "client_secret": _current_client_secret.get(),
            "refresh_token": _current_refresh_token.get(),
        }
        logger.info(
            "🔐 Context values - client_id: %s, client_secret: %s, refresh_token: %s",
            bool(credentials["client_id"]),
            bool(credentials["client_secret"]),
            bool(credentials["refresh_token"]),
        )

        missing = [name for name, value in credentials.items() if not value]
        if missing:
            logger.warning("🔐 Missing OAuth credentials in request context: %s", ", ".join(missing))
            yield request
            return

        # Get a cached access token or refresh a new one.
        access_token = self.provider._get_access_token(**credentials)
        if access_token:
            logger.info("🔐 Adding Zoho auth header to request")
            request.headers["Authorization"] = f"Zoho-oauthtoken {access_token}"
        else:
            logger.error("🔐 Failed to obtain access token")
        yield request
class ZohoAuthProvider(AuthProvider):
    """Authenticate OpenAPI calls against Zoho with per-request OAuth2 credentials.

    Client ID, client secret and refresh token are not taken from ``config``.
    ``set_credentials_for_request`` places them in context variables, and they
    are exchanged here for access tokens at the Zoho accounts server. That
    server comes from the ``ZOHO_ACCOUNTS_SERVER`` environment variable
    (default ``https://accounts.zohocloud.ca``). Access tokens are cached in
    the module-level ``_access_token_cache``.
    """

    # Cached tokens count as expired this long before Zoho's stated expiry,
    # so a token that may lapse mid-request is not handed out.
    _EXPIRY_MARGIN = timedelta(seconds=60)

    # Token lifetime (seconds) assumed when the response has no usable
    # ``expires_in`` field (1 hour, as before).
    _DEFAULT_EXPIRES_IN = 3600

    def __init__(self, config: Config):
        """Store *config* and resolve the Zoho accounts server base URL."""
        self.config = config
        self.accounts_server = os.getenv("ZOHO_ACCOUNTS_SERVER", "https://accounts.zohocloud.ca")

    @property
    def provider_name(self) -> str:
        """Return the provider identifier, ``"zoho"``."""
        return "zoho"

    def is_configured(self) -> bool:
        """Return True when a non-empty accounts server URL is set.

        Per-request credentials are checked later, in ``ZohoAuth.auth_flow``.
        """
        return bool(self.accounts_server)

    @classmethod
    def _parse_expires_in(cls, raw: object) -> int:
        """Convert a token response's ``expires_in`` value to whole seconds.

        Falls back to ``_DEFAULT_EXPIRES_IN`` when the value is absent or cannot
        be converted, so a valid token is not discarded over a bad lifetime.
        """
        if raw is None:
            return cls._DEFAULT_EXPIRES_IN
        try:
            return int(raw)
        except (TypeError, ValueError, OverflowError):
            logger.warning("Unparseable expires_in %r; assuming %s seconds", raw, cls._DEFAULT_EXPIRES_IN)
            return cls._DEFAULT_EXPIRES_IN

    def _get_access_token(self, client_id: str, client_secret: str, refresh_token: str) -> Optional[str]:
        """Return an access token for the given credentials, refreshing it when needed.

        Args:
            client_id: Zoho OAuth client ID.
            client_secret: Zoho OAuth client secret.
            refresh_token: Refresh token exchanged for a new access token.

        Returns:
            One of:
            - a cached token that has not reached its margin-adjusted expiry;
            - a freshly issued token;
            - None if the refresh failed.

            Failures are logged, never raised, so the caller can still send
            the request.
        """
        cache_key = f"{client_id}:{refresh_token}"
        cached = _access_token_cache.get(cache_key)
        if cached is not None:
            access_token, expiry = cached
            now = datetime.now()
            if now < expiry:
                expires_in_minutes = (expiry - now).total_seconds() / 60
                logger.info("✅ Using cached access token (expires in %.0f minutes)", expires_in_minutes)
                return access_token

        # Never log the refresh token, the client secret or the token response
        # body: each of them is a credential.
        logger.info("🔐 Requesting new access token via Zoho OAuth2...")
        url = f"{self.accounts_server}/oauth/v2/token"
        form = {
            "refresh_token": refresh_token,
            "client_id": client_id,
            "client_secret": client_secret,
            "grant_type": "refresh_token",
        }
        # NOTE(review): httpx.post blocks. If ZohoAuth is attached to an
        # httpx.AsyncClient, this stalls the event loop during refresh —
        # confirm which client type the server uses.
        try:
            logger.info("🔧 Making POST request to: %s", url)
            response = httpx.post(
                url,
                data=form,
                headers={"Content-Type": "application/x-www-form-urlencoded"},
                timeout=10.0,
            )
            logger.info("🔧 Token endpoint responded with HTTP %s", response.status_code)
            if not response.is_success:
                logger.error("❌ HTTP %s: %s", response.status_code, response.text)
                return None

            data = response.json()
            if not isinstance(data, dict):
                logger.error("❌ Unexpected token response type: %s", type(data).__name__)
                return None
            # Only field names are logged; the values include the access token.
            logger.debug("🔧 Token response fields: %s", sorted(data))
            if "error" in data:
                logger.error("❌ Zoho OAuth error: %s", data["error"])
                return None

            access_token = data.get("access_token")
            if not access_token:
                logger.error("❌ No access token received: %s", data)
                return None

            # Cache the token, expiring it slightly early (never negative).
            expires_in = self._parse_expires_in(data.get("expires_in"))
            lifetime = max(timedelta(seconds=expires_in) - self._EXPIRY_MARGIN, timedelta(0))
            _access_token_cache[cache_key] = (access_token, datetime.now() + lifetime)
            logger.info("✅ Token refresh successful via Zoho API")
            return access_token
        except Exception as e:
            # Best effort: any transport/decoding failure means "no token".
            logger.error("❌ Error refreshing token: %s", e)
            return None

    def get_auth_headers(self) -> Dict[str, str]:
        """Get static authentication headers (not used - we use dynamic auth via get_httpx_auth).

        This is called during server initialization. Returns empty dict because
        Zoho auth is handled dynamically per-request via ZohoAuth.auth_flow().
        """
        logger.info("🔐 get_auth_headers() called (returning empty - using dynamic auth)")
        return {}

    def get_auth_params(self) -> Dict[str, str]:
        """Zoho uses header-based auth, not query params."""
        return {}

    def get_auth_cookies(self) -> Dict[str, str]:
        """Zoho uses header-based auth, not cookies."""
        return {}

    def get_httpx_auth(self) -> Optional[httpx.Auth]:
        """Get authentication object for HTTPX.

        Returns a custom Auth class that dynamically adds headers per-request.
        """
        logger.info("🔐 Returning ZohoAuth instance for dynamic per-request auth")
        return ZohoAuth(self)

    @staticmethod
    def set_credentials_for_request(client_id: str, client_secret: str, refresh_token: str) -> None:
        """Store Zoho credentials in the current context for ``ZohoAuth`` to read.

        The values live in ``ContextVar``s. They are visible to the current
        context and to contexts copied from it afterwards, such as asyncio
        tasks created later.
        """
        _current_client_id.set(client_id)
        _current_client_secret.set(client_secret)
        _current_refresh_token.set(refresh_token)
# Auto-register the Zoho auth provider when this module is imported.
# NOTE(review): presumably this lets auth_factory resolve the auth type
# name "zoho" to ZohoAuthProvider — confirm against register_auth_provider.
register_auth_provider("zoho", ZohoAuthProvider)