"""OAuth authentication using MSAL with keyring token storage.
Implements:
- Browser-based OAuth flow with PKCE
- Token storage in OS keychain (macOS Keychain, Windows Credential Manager, etc.)
- Automatic token refresh
"""
import asyncio
import base64
import hashlib
import html
import json
import logging
import secrets
import time
import webbrowser
from datetime import UTC, datetime, timedelta
from http.server import BaseHTTPRequestHandler, HTTPServer
from typing import Any
from urllib.parse import parse_qs, urlparse

import keyring
import msal  # type: ignore[import-untyped]

from .config import Config
logger = logging.getLogger(__name__)

# Service name under which tokens are stored in the OS keychain.
KEYRING_SERVICE = "sharepoint-mcp"

# Delegated Microsoft Graph permissions requested during sign-in.
# "offline_access" (like "openid" and "profile") is deliberately NOT listed:
# MSAL treats these as reserved scopes and raises ValueError if a caller passes
# them explicitly. MSAL adds them to every request itself, and the implicit
# offline_access is what makes Azure AD issue a refresh token.
SCOPES = [
    "https://graph.microsoft.com/Sites.Read.All",
    "https://graph.microsoft.com/Files.ReadWrite.All",
    "https://graph.microsoft.com/User.Read",
]
class AuthManager:
    """Manages OAuth authentication and token lifecycle.

    Tokens are persisted as a JSON blob in the OS keychain (via ``keyring``)
    under ``KEYRING_SERVICE``, keyed by the application's client ID.
    """

    # Port for the loopback callback server when config.redirect_uri has none.
    _DEFAULT_CALLBACK_PORT = 8765
    # Maximum time to wait for the browser to redirect back to us.
    _CALLBACK_TIMEOUT_SECONDS = 120.0
    # Treat a cached access token as expired this long before it really is.
    _EXPIRY_BUFFER = timedelta(minutes=5)

    def __init__(self, config: Config):
        """Initialize the auth manager.

        Args:
            config: Application configuration. ``client_id`` and ``tenant_id``
                are read here; ``redirect_uri`` is read during the OAuth flow.
        """
        self.config = config
        self.app = msal.PublicClientApplication(
            client_id=config.client_id,
            authority=f"https://login.microsoftonline.com/{config.tenant_id}",
        )
        self._token_cache: dict[str, Any] | None = None

    async def ensure_authenticated(self) -> str:
        """Ensure we have a valid access token and return it.

        Order of preference:

        1. A cached token that is not within ``_EXPIRY_BUFFER`` of expiring.
        2. A silent refresh using the cached refresh token.
        3. The interactive browser flow.

        Returns:
            Valid access token

        Raises:
            AuthenticationError: If authentication fails
        """
        token_data = self._load_token_from_keyring()
        if token_data:
            expires_at = self._parse_expiry(token_data)
            access_token = token_data.get("access_token")
            # Expiry timestamps are stored as naive UTC (see _save_token_to_keyring).
            now = datetime.now(UTC).replace(tzinfo=None)
            if (
                access_token
                and expires_at is not None
                and expires_at > now + self._EXPIRY_BUFFER
            ):
                logger.debug("Using cached access token")
                return str(access_token)

            # The stored value may be None when the server issued no refresh token.
            refresh_token = token_data.get("refresh_token")
            if refresh_token:
                logger.info("Refreshing access token")
                try:
                    return await self._refresh_token(refresh_token)
                except Exception as e:
                    # Best effort: fall through to the interactive flow.
                    logger.warning("Token refresh failed: %s", e)

        logger.info("Starting OAuth flow")
        return await self._run_oauth_flow()

    async def get_access_token(self) -> str:
        """Get current access token, refreshing if needed.

        Returns:
            Valid access token
        """
        return await self.ensure_authenticated()

    @staticmethod
    def _parse_expiry(token_data: dict[str, Any]) -> datetime | None:
        """Return cached ``expires_at`` as naive UTC, or None if missing/malformed."""
        raw = token_data.get("expires_at")
        if not isinstance(raw, str):
            return None
        try:
            expires_at = datetime.fromisoformat(raw)
        except ValueError:
            return None
        # Normalise aware values so comparing with naive UTC cannot raise TypeError.
        if expires_at.tzinfo is not None:
            expires_at = expires_at.astimezone(UTC).replace(tzinfo=None)
        return expires_at

    async def _run_oauth_flow(self) -> str:
        """Run interactive OAuth flow with browser.

        Returns:
            Access token

        Raises:
            AuthenticationError: If OAuth flow fails
        """
        # PKCE (RFC 7636): S256 challenge derived from a random verifier.
        code_verifier = secrets.token_urlsafe(32)
        code_challenge = (
            base64.urlsafe_b64encode(hashlib.sha256(code_verifier.encode()).digest())
            .decode()
            .rstrip("=")
        )
        # Random value echoed back by the authorization server; verified in the
        # callback so a forged redirect to our local server is rejected (CSRF).
        state = secrets.token_urlsafe(16)

        # NOTE(review): code_challenge / code_challenge_method / code_verifier
        # reach MSAL through **kwargs; confirm the installed MSAL version forwards
        # them (MSAL's initiate_auth_code_flow handles PKCE natively).
        auth_url = self.app.get_authorization_request_url(
            scopes=SCOPES,
            redirect_uri=self.config.redirect_uri,
            state=state,
            code_challenge=code_challenge,
            code_challenge_method="S256",
        )

        # Start local server to receive callback
        auth_code = await self._get_auth_code_from_browser(auth_url, expected_state=state)

        # MSAL performs a blocking HTTP request; keep it off the event loop.
        result = await asyncio.to_thread(
            self.app.acquire_token_by_authorization_code,
            code=auth_code,
            scopes=SCOPES,
            redirect_uri=self.config.redirect_uri,
            code_verifier=code_verifier,
        )
        if "error" in result:
            raise AuthenticationError(
                f"Token acquisition failed: {result.get('error_description', result['error'])}"
            )

        self._save_token_to_keyring(result)
        return str(result["access_token"])

    async def _get_auth_code_from_browser(
        self, auth_url: str, expected_state: str | None = None
    ) -> str:
        """Open browser and wait for OAuth callback.

        Serves HTTP on ``localhost`` at the port given in ``config.redirect_uri``
        (``_DEFAULT_CALLBACK_PORT`` if it has none). Requests are handled until
        one carries a ``code`` or ``error`` query parameter, or until
        ``_CALLBACK_TIMEOUT_SECONDS`` elapse. Unrelated requests (e.g. a
        browser's favicon probe) get a 400 and do not end the wait.

        Args:
            auth_url: OAuth authorization URL
            expected_state: If given, a callback carrying ``code`` must echo
                this ``state`` value or the flow is aborted.

        Returns:
            Authorization code

        Raises:
            AuthenticationError: If callback times out or fails
        """
        auth_code: str | None = None
        error: str | None = None
        error_desc: str | None = None

        class CallbackHandler(BaseHTTPRequestHandler):
            def do_GET(self) -> None:
                nonlocal auth_code, error, error_desc
                query = parse_qs(urlparse(self.path).query)
                if "code" in query:
                    returned_state = query.get("state", [""])[0]
                    if expected_state is not None and returned_state != expected_state:
                        error = "invalid_state"
                        error_desc = "State parameter mismatch"
                        self._send_failure_page(error, error_desc)
                        return
                    auth_code = query["code"][0]
                    self.send_response(200)
                    self.send_header("Content-type", "text/html")
                    self.end_headers()
                    self.wfile.write(
                        b"""
                        <html><body>
                        <h1>Authentication successful!</h1>
                        <p>You can close this window and return to Claude.</p>
                        </body></html>
                        """
                    )
                elif "error" in query:
                    error = query["error"][0]
                    error_desc = query.get("error_description", ["Unknown error"])[0]
                    self._send_failure_page(error, error_desc)
                else:
                    self.send_response(400)
                    self.end_headers()

            def _send_failure_page(self, err: str, desc: str) -> None:
                """Respond 400 with an HTML page describing the failure."""
                self.send_response(400)
                self.send_header("Content-type", "text/html")
                self.end_headers()
                # Query values are attacker-controllable: escape before reflecting.
                self.wfile.write(
                    f"""
                    <html><body>
                    <h1>Authentication failed</h1>
                    <p>{html.escape(err)}: {html.escape(desc)}</p>
                    </body></html>
                    """.encode()
                )

            def log_message(self, format: str, *args: Any) -> None:
                pass  # Suppress logging

        port = urlparse(self.config.redirect_uri).port or self._DEFAULT_CALLBACK_PORT
        server = HTTPServer(("localhost", port), CallbackHandler)
        timeout = self._CALLBACK_TIMEOUT_SECONDS

        def serve_until_done() -> None:
            # Runs in a worker thread; handle_request() blocks up to server.timeout.
            deadline = time.monotonic() + timeout
            while auth_code is None and error is None:
                remaining = deadline - time.monotonic()
                if remaining <= 0:
                    return
                server.timeout = remaining
                server.handle_request()

        try:
            logger.info("Opening browser for authentication...")
            webbrowser.open(auth_url)
            await asyncio.get_running_loop().run_in_executor(None, serve_until_done)
        finally:
            server.server_close()

        if error:
            detail = f"{error}: {error_desc}" if error_desc else error
            raise AuthenticationError(f"OAuth error: {detail}")
        if not auth_code:
            raise AuthenticationError("Authentication timed out or was cancelled")
        return auth_code

    async def _refresh_token(self, refresh_token: str) -> str:
        """Refresh access token using refresh token.

        Args:
            refresh_token: Refresh token

        Returns:
            New access token

        Raises:
            AuthenticationError: If refresh fails
        """
        # MSAL performs a blocking HTTP request; keep it off the event loop.
        result = await asyncio.to_thread(
            self.app.acquire_token_by_refresh_token,
            refresh_token=refresh_token,
            scopes=SCOPES,
        )
        if "error" in result:
            raise AuthenticationError(
                f"Token refresh failed: {result.get('error_description', result['error'])}"
            )

        # The token endpoint may omit a rotated refresh token; keep the one we
        # already have instead of overwriting it with None in the keychain.
        if not result.get("refresh_token"):
            result = {**result, "refresh_token": refresh_token}

        self._save_token_to_keyring(result)
        return str(result["access_token"])

    def _save_token_to_keyring(self, result: dict[str, Any]) -> None:
        """Save tokens to OS keychain.

        Persisting is best effort: a keyring failure is logged, not raised, so
        a freshly acquired token remains usable for the current session.

        Args:
            result: Token response from MSAL
        """
        # Stored as naive UTC to match the comparison in ensure_authenticated.
        expires_at = datetime.now(UTC).replace(tzinfo=None) + timedelta(
            seconds=int(result["expires_in"])
        )
        token_data = {
            "access_token": result["access_token"],
            "refresh_token": result.get("refresh_token"),
            "expires_at": expires_at.isoformat(),
        }
        try:
            keyring.set_password(
                KEYRING_SERVICE, self.config.client_id, json.dumps(token_data)
            )
        except keyring.errors.KeyringError as e:
            logger.warning("Failed to save token to keyring: %s", e)
            return
        logger.info("Tokens saved to keychain")

    def _load_token_from_keyring(self) -> dict[str, Any] | None:
        """Load tokens from OS keychain.

        Returns:
            Token data if available and well-formed, None otherwise
        """
        try:
            data = keyring.get_password(KEYRING_SERVICE, self.config.client_id)
            if data:
                parsed = json.loads(data)
                if isinstance(parsed, dict):
                    return parsed
                logger.warning("Ignoring malformed token data in keyring")
        except Exception as e:
            # Best effort: any keyring/JSON problem just means "no cached token".
            logger.warning("Failed to load token from keyring: %s", e)
        return None

    def clear_tokens(self) -> None:
        """Clear stored tokens (for re-authentication)."""
        try:
            keyring.delete_password(KEYRING_SERVICE, self.config.client_id)
            logger.info("Tokens cleared from keychain")
        except keyring.errors.PasswordDeleteError:
            logger.debug("No tokens to clear")
class AuthenticationError(Exception):
    """Raised when acquiring, refreshing, or validating an OAuth token fails."""