"""Local OAuth callback server for handling OAuth2 authorization flow."""
import asyncio
import html
import secrets
import urllib.parse
import webbrowser
from dataclasses import dataclass, field
from http.server import BaseHTTPRequestHandler, HTTPServer
from threading import Thread
from typing import Callable, Optional

import httpx

from .token_store import TokenData
@dataclass
class OAuthConfig:
    """OAuth2 configuration for a provider.

    Attributes:
        client_id: OAuth2 client identifier issued by the provider.
        client_secret: OAuth2 client secret. Excluded from ``repr()`` so it
            is not printed in logs, debug output or tracebacks.
        redirect_uri: Redirect (callback) URI registered with the provider.
        authorization_url: Provider endpoint the user is sent to for consent.
        token_url: Provider endpoint used to exchange/refresh tokens.
        scopes: Requested scopes.
    """

    client_id: str
    # repr=False: keep the secret out of the auto-generated __repr__.
    client_secret: str = field(repr=False)
    redirect_uri: str
    authorization_url: str
    token_url: str
    scopes: list[str]
class OAuthCallbackHandler(BaseHTTPRequestHandler):
    """HTTP request handler for the OAuth2 redirect (callback) endpoint.

    Routes:
        ``/callback``: validates the authorization response and records either
            the authorization code (``callback_result``) or a failure message
            (``callback_error``) on the class.
        ``/health``: responds ``200`` with the plain-text body ``OK``.
        anything else: responds ``404``.

    The outcome is stored in class attributes rather than on the instance
    because the HTTP server constructs a new handler object per request, so
    only one authorization flow can be tracked per process at a time.
    """

    # Class variables set by OAuthCallbackServer.
    callback_result: Optional[dict] = None  # {"code": <authorization code>} on success
    callback_error: Optional[str] = None  # human-readable failure reason
    expected_state: Optional[str] = None  # `state` sent in the authorization URL
    shutdown_event: Optional[asyncio.Event] = None  # not read by this handler

    def log_message(self, format: str, *args) -> None:
        """Suppress the default per-request logging."""

    def do_GET(self) -> None:
        """Dispatch a GET request by path."""
        parsed = urllib.parse.urlparse(self.path)
        if parsed.path == "/callback":
            self._handle_callback(urllib.parse.parse_qs(parsed.query))
        elif parsed.path == "/health":
            self.send_response(200)
            self.send_header("Content-type", "text/plain")
            self.end_headers()
            self.wfile.write(b"OK")
        else:
            self.send_response(404)
            self.end_headers()

    def _handle_callback(self, params: dict[str, list[str]]) -> None:
        """Validate an authorization response and record its outcome.

        Args:
            params: Query parameters as returned by ``urllib.parse.parse_qs``.
        """
        # Provider-reported failure (e.g. the user declined consent).
        if "error" in params:
            error = params["error"][0]
            error_description = params.get("error_description", ["Unknown error"])[0]
            OAuthCallbackHandler.callback_error = f"{error}: {error_description}"
            self._send_error_response(error_description)
            return

        # Only accept responses to the request we issued (CSRF protection).
        state = params.get("state", [None])[0]
        if not self._state_matches(state):
            OAuthCallbackHandler.callback_error = "State mismatch - possible CSRF attack"
            self._send_error_response("Security error: state mismatch")
            return

        code = params.get("code", [None])[0]
        if not code:
            OAuthCallbackHandler.callback_error = "No authorization code received"
            self._send_error_response("No authorization code received")
            return

        OAuthCallbackHandler.callback_result = {"code": code}
        self._send_success_response()

    @staticmethod
    def _state_matches(state: Optional[str]) -> bool:
        """Return True if ``state`` equals the expected state value.

        A missing value on either side never matches, so a request without
        ``state`` is rejected even when no expected state has been set.
        """
        expected = OAuthCallbackHandler.expected_state
        if state is None or expected is None:
            return False
        # Constant-time comparison. Encode first because compare_digest
        # rejects non-ASCII ``str`` arguments.
        return secrets.compare_digest(state.encode("utf-8"), expected.encode("utf-8"))

    def _send_html(self, status: int, page: str) -> None:
        """Write a complete UTF-8 HTML response with the given status code."""
        body = page.encode("utf-8")
        self.send_response(status)
        # Declare the charset: the pages contain non-ASCII glyphs.
        self.send_header("Content-type", "text/html; charset=utf-8")
        self.send_header("Content-Length", str(len(body)))
        self.end_headers()
        self.wfile.write(body)

    def _send_success_response(self) -> None:
        """Send the success HTML page (status 200)."""
        page = """
<!DOCTYPE html>
<html>
<head>
    <title>Authorization Successful</title>
    <style>
        body {
            font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
            display: flex;
            justify-content: center;
            align-items: center;
            height: 100vh;
            margin: 0;
            background: linear-gradient(135deg, #1a1a2e 0%, #16213e 100%);
            color: #fff;
        }
        .container {
            text-align: center;
            padding: 40px;
            background: rgba(255,255,255,0.1);
            border-radius: 16px;
            backdrop-filter: blur(10px);
        }
        .success-icon {
            font-size: 64px;
            margin-bottom: 20px;
        }
        h1 { margin: 0 0 10px 0; }
        p { opacity: 0.8; }
    </style>
</head>
<body>
    <div class="container">
        <div class="success-icon">✓</div>
        <h1>Authorization Successful</h1>
        <p>You can close this window and return to your application.</p>
    </div>
</body>
</html>
"""
        self._send_html(200, page)

    def _send_error_response(self, error: str) -> None:
        """Send the error HTML page (status 400).

        Args:
            error: Message to display. It may originate from the query string
                (attacker-controllable), so it is HTML-escaped before rendering.
        """
        safe_error = html.escape(error)
        page = f"""
<!DOCTYPE html>
<html>
<head>
    <title>Authorization Failed</title>
    <style>
        body {{
            font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
            display: flex;
            justify-content: center;
            align-items: center;
            height: 100vh;
            margin: 0;
            background: linear-gradient(135deg, #2e1a1a 0%, #3e1616 100%);
            color: #fff;
        }}
        .container {{
            text-align: center;
            padding: 40px;
            background: rgba(255,255,255,0.1);
            border-radius: 16px;
            backdrop-filter: blur(10px);
        }}
        .error-icon {{
            font-size: 64px;
            margin-bottom: 20px;
        }}
        h1 {{ margin: 0 0 10px 0; }}
        p {{ opacity: 0.8; }}
    </style>
</head>
<body>
    <div class="container">
        <div class="error-icon">✗</div>
        <h1>Authorization Failed</h1>
        <p>{safe_error}</p>
    </div>
</body>
</html>
"""
        self._send_html(400, page)
class OAuthCallbackServer:
    """Local server for handling the OAuth2 authorization-code callback.

    The HTTP server only listens while :meth:`run_authorization_flow` is
    waiting for the callback. NOTE(review): ``config.redirect_uri`` is
    presumably ``http://<host>:<port>/callback`` for this server; this class
    does not verify that — confirm with callers.
    """

    def __init__(self, host: str = "localhost", port: int = 8787):
        """Initialize OAuth callback server.

        Args:
            host: Interface to bind the callback server to.
            port: Port to bind the callback server to.
        """
        self.host = host
        self.port = port
        self._server: Optional[HTTPServer] = None
        self._thread: Optional[Thread] = None

    def _start_server(self) -> None:
        """Bind the HTTP server and serve requests on a daemon thread.

        Raises:
            OSError: If the address cannot be bound (e.g. port in use).
        """
        self._server = HTTPServer((self.host, self.port), OAuthCallbackHandler)
        self._thread = Thread(target=self._server.serve_forever, daemon=True)
        self._thread.start()

    def _stop_server(self) -> None:
        """Stop serving, wait for the serving thread, and close the socket."""
        server, self._server = self._server, None
        thread, self._thread = self._thread, None
        if server is not None:
            # Blocks until serve_forever() has returned.
            server.shutdown()
        if thread is not None:
            thread.join(timeout=2)
        if server is not None:
            # shutdown() only stops the loop; this releases the listening socket.
            server.server_close()

    def generate_authorization_url(self, config: OAuthConfig) -> tuple[str, str]:
        """Generate an OAuth2 authorization URL with a fresh random state.

        Args:
            config: OAuth configuration.

        Returns:
            Tuple of (authorization_url, state).
        """
        state = secrets.token_urlsafe(32)
        query = urllib.parse.urlencode(
            {
                "client_id": config.client_id,
                "redirect_uri": config.redirect_uri,
                "response_type": "code",
                "scope": " ".join(config.scopes),
                "state": state,
            }
        )
        # Keep any query string already present on the configured endpoint.
        separator = "&" if "?" in config.authorization_url else "?"
        return f"{config.authorization_url}{separator}{query}", state

    async def _request_token(
        self, config: OAuthConfig, data: dict[str, str], action: str
    ) -> TokenData:
        """POST a form to the token endpoint and parse the token response.

        Args:
            config: OAuth configuration (``token_url`` is used).
            data: Form fields for the token request.
            action: Label used in the error message (e.g. "Token exchange").

        Returns:
            TokenData built from the JSON response body.

        Raises:
            Exception: If the endpoint responds with a status other than 200.
        """
        async with httpx.AsyncClient() as client:
            response = await client.post(
                config.token_url,
                data=data,
                headers={
                    "Content-Type": "application/x-www-form-urlencoded",
                    # Some providers (e.g. GitHub) answer form-encoded unless
                    # JSON is requested explicitly; response.json() needs JSON.
                    "Accept": "application/json",
                },
            )
            if response.status_code != 200:
                raise Exception(
                    f"{action} failed: {response.status_code} - {response.text}"
                )
            return TokenData.from_oauth_response(response.json())

    async def exchange_code_for_token(
        self, config: OAuthConfig, code: str
    ) -> TokenData:
        """Exchange an authorization code for tokens.

        Args:
            config: OAuth configuration.
            code: Authorization code from callback.

        Returns:
            TokenData with access and refresh tokens.

        Raises:
            Exception: If token exchange fails.
        """
        return await self._request_token(
            config,
            {
                "grant_type": "authorization_code",
                "client_id": config.client_id,
                "client_secret": config.client_secret,
                "redirect_uri": config.redirect_uri,
                "code": code,
            },
            "Token exchange",
        )

    async def refresh_token(self, config: OAuthConfig, refresh_token: str) -> TokenData:
        """Refresh an expired access token.

        NOTE(review): some providers omit ``refresh_token`` from refresh
        responses; whether the returned TokenData keeps the old one depends on
        ``TokenData.from_oauth_response`` — verify.

        Args:
            config: OAuth configuration.
            refresh_token: Refresh token.

        Returns:
            New TokenData with fresh tokens.

        Raises:
            Exception: If refresh fails.
        """
        return await self._request_token(
            config,
            {
                "grant_type": "refresh_token",
                "client_id": config.client_id,
                "client_secret": config.client_secret,
                "refresh_token": refresh_token,
            },
            "Token refresh",
        )

    @staticmethod
    async def _wait_for_code(timeout: float) -> str:
        """Poll the callback handler's shared state until it reports an outcome.

        Args:
            timeout: Maximum number of seconds to wait.

        Returns:
            The authorization code received by the callback.

        Raises:
            Exception: If the callback recorded an error.
            TimeoutError: If nothing arrived within ``timeout`` seconds.
        """
        loop = asyncio.get_running_loop()
        deadline = loop.time() + timeout
        while True:
            result = OAuthCallbackHandler.callback_result
            if result is not None:
                return result["code"]
            error = OAuthCallbackHandler.callback_error
            if error is not None:
                raise Exception(error)
            if loop.time() > deadline:
                raise TimeoutError("Authorization timed out")
            await asyncio.sleep(0.5)

    async def run_authorization_flow(
        self,
        config: OAuthConfig,
        timeout: int = 300,
        open_browser: bool = True,
        on_authorization_url: Optional[Callable[[str], None]] = None,
    ) -> TokenData:
        """Run the full OAuth2 authorization-code flow.

        Args:
            config: OAuth configuration.
            timeout: Timeout in seconds waiting for callback.
            open_browser: Whether to automatically open the browser.
            on_authorization_url: Optional callback invoked with the generated
                authorization URL (e.g. to print it). Without it, the URL is
                not exposed to callers, so it is the only way to complete the
                flow when ``open_browser`` is False.

        Returns:
            TokenData with access and refresh tokens.

        Raises:
            OSError: If the callback server cannot bind its address.
            TimeoutError: If user doesn't complete auth within timeout.
            Exception: If authorization or token exchange fails.
        """
        # Reset shared handler state left over from a previous flow.
        OAuthCallbackHandler.callback_result = None
        OAuthCallbackHandler.callback_error = None

        auth_url, state = self.generate_authorization_url(config)
        OAuthCallbackHandler.expected_state = state

        self._start_server()
        try:
            if on_authorization_url is not None:
                on_authorization_url(auth_url)
            if open_browser:
                webbrowser.open(auth_url)
            code = await self._wait_for_code(timeout)
            return await self.exchange_code_for_token(config, code)
        finally:
            self._stop_server()
            # The server is stopped; no further callback can use this state.
            OAuthCallbackHandler.expected_state = None

    def get_authorization_url_sync(self, config: OAuthConfig) -> tuple[str, str]:
        """Synchronous wrapper for generating authorization URL.

        Args:
            config: OAuth configuration.

        Returns:
            Tuple of (authorization_url, state).
        """
        return self.generate_authorization_url(config)