"""OAuth U2M authentication for Databricks."""
import base64
import hashlib
import secrets
import string
import sys
import webbrowser
from http.server import BaseHTTPRequestHandler, HTTPServer
from urllib.parse import parse_qs, urlencode, urlparse
import requests
# OAuth client ID presented to the Databricks OIDC authorize/token endpoints.
CLIENT_ID: str = "databricks-cli"
# Space-separated scopes requested when the caller does not pass its own.
DEFAULT_SCOPES: str = "all-apis offline_access"
# Local callback URL the browser is redirected to after authorization.
DEFAULT_REDIRECT_URI: str = "http://localhost:8020"
class OAuthCallbackHandler(BaseHTTPRequestHandler):
    """HTTP handler that captures the OAuth redirect callback.

    Results are stored on the *class*, not the instance, because the server
    builds a new handler instance for every request. The caller reads them
    back from the class once the request has been handled.
    """

    # ``code`` query parameter from the most recent callback, or None.
    authorization_code = None
    # ``state`` query parameter from the most recent callback, or None.
    state_value = None
    # ``error`` query parameter from the most recent callback, or None.
    error = None

    _SUCCESS_PAGE = b"""
<html>
<body style="font-family: system-ui; display: flex; justify-content: center; align-items: center; height: 100vh; margin: 0; background: #f5f5f5;">
<div style="text-align: center; padding: 40px; background: white; border-radius: 8px; box-shadow: 0 2px 10px rgba(0,0,0,0.1);">
<h2 style="color: #22c55e;">Authorization Successful!</h2>
<p style="color: #666;">You can close this window.</p>
</div>
</body>
</html>
"""

    # Generic failure page; the untrusted ``error`` value is deliberately not
    # echoed back into the HTML.
    _FAILURE_PAGE = b"""
<html>
<body style="font-family: system-ui; display: flex; justify-content: center; align-items: center; height: 100vh; margin: 0; background: #f5f5f5;">
<div style="text-align: center; padding: 40px; background: white; border-radius: 8px; box-shadow: 0 2px 10px rgba(0,0,0,0.1);">
<h2 style="color: #ef4444;">Authorization Failed</h2>
<p style="color: #666;">Return to the terminal for details. You can close this window.</p>
</div>
</body>
</html>
"""

    def do_GET(self):
        """Record ``code``/``state``/``error`` from the query and show a result page.

        Responds 200 with a success page only when a code was received and no
        ``error`` parameter is present. Otherwise it responds 400 with a
        failure page.
        """
        query = parse_qs(urlparse(self.path).query)
        OAuthCallbackHandler.authorization_code = query.get("code", [None])[0]
        OAuthCallbackHandler.state_value = query.get("state", [None])[0]
        OAuthCallbackHandler.error = query.get("error", [None])[0]

        succeeded = bool(OAuthCallbackHandler.authorization_code) and not OAuthCallbackHandler.error
        self.send_response(200 if succeeded else 400)
        self.send_header("Content-type", "text/html")
        self.end_headers()
        self.wfile.write(self._SUCCESS_PAGE if succeeded else self._FAILURE_PAGE)

    def log_message(self, format, *args):
        """Silence the base class's per-request logging."""
        pass
def generate_pkce_pair(length: int = 64):
    """Generate a PKCE ``(code_verifier, code_challenge)`` pair using the S256 method.

    Args:
        length: Number of characters in the code verifier. RFC 7636 allows
            43 to 128; defaults to 64.

    Returns:
        A ``(code_verifier, code_challenge)`` tuple of strings. The challenge
        is SHA-256 of the verifier, base64url-encoded without ``=`` padding.

    Raises:
        ValueError: If ``length`` is outside the 43..128 range.
    """
    if not 43 <= length <= 128:
        raise ValueError(f"PKCE code verifier length must be between 43 and 128, got {length}")
    # Unreserved characters permitted in a code verifier.
    allowed_chars = string.ascii_letters + string.digits + "-._~"
    code_verifier = "".join(secrets.choice(allowed_chars) for _ in range(length))
    digest = hashlib.sha256(code_verifier.encode("ascii")).digest()
    # S256 challenge: base64url with trailing "=" padding stripped.
    code_challenge = base64.urlsafe_b64encode(digest).decode("ascii").rstrip("=")
    return code_verifier, code_challenge
def start_oauth_flow(
    host: str,
    scopes: str = DEFAULT_SCOPES,
    redirect_uri: str = DEFAULT_REDIRECT_URI,
    *,
    callback_timeout: float = 300,
    request_timeout: float = 30,
) -> str:
    """
    Run the OAuth U2M authorization-code + PKCE flow and return an access token.

    The function works in three steps:

    1. Open the user's browser at ``{host}/oidc/v1/authorize``.
    2. Wait for a single redirect to ``redirect_uri`` on a temporary local
       HTTP server.
    3. Exchange the returned code at ``{host}/oidc/v1/token``.

    Args:
        host: Workspace base URL; a trailing slash is ignored.
        scopes: Space-separated OAuth scopes to request.
        redirect_uri: Local callback URL. Its host (default "localhost") and
            port (default 8020) decide where the callback server listens.
        callback_timeout: Seconds to wait for the browser callback.
        request_timeout: Seconds before the token-exchange HTTP request
            times out.

    Returns:
        The ``access_token`` field of the token endpoint's JSON response.

    Raises:
        ValueError: If any of the following happens:
            - no callback with a code or state arrives;
            - the returned state does not match;
            - no authorization code is returned;
            - the token exchange fails or returns no access token.
    """
    host = host.rstrip("/")
    state = secrets.token_urlsafe(32)
    code_verifier, code_challenge = generate_pkce_pair()
    auth_params = {
        "client_id": CLIENT_ID,
        "redirect_uri": redirect_uri,
        "response_type": "code",
        "state": state,
        "code_challenge": code_challenge,
        "code_challenge_method": "S256",
        "scope": scopes,
    }
    auth_url = f"{host}/oidc/v1/authorize?{urlencode(auth_params)}"

    # Clear results left on the handler class by any previous flow.
    OAuthCallbackHandler.authorization_code = None
    OAuthCallbackHandler.state_value = None

    parsed_redirect = urlparse(redirect_uri)
    redirect_host = parsed_redirect.hostname or "localhost"
    redirect_port = parsed_redirect.port or 8020
    server = HTTPServer((redirect_host, redirect_port), OAuthCallbackHandler)
    server.timeout = callback_timeout
    try:
        print("Opening browser for authorization...", file=sys.stderr)
        if not webbrowser.open(auth_url):
            # No usable browser (e.g. a headless session): let the user open it.
            print(f"Could not open a browser. Open this URL manually:\n{auth_url}", file=sys.stderr)
        print(f"Waiting for authorization callback on {redirect_uri}...", file=sys.stderr)
        # Handles at most one request; returns after server.timeout seconds if none arrives.
        server.handle_request()
    finally:
        # Always release the listening socket, even if the wait was interrupted.
        server.server_close()

    authorization_code = OAuthCallbackHandler.authorization_code
    returned_state = OAuthCallbackHandler.state_value
    if authorization_code is None and returned_state is None:
        raise ValueError(
            f"No authorization callback received (timed out after {callback_timeout} seconds "
            "or the callback had no parameters)."
        )
    if returned_state != state:
        raise ValueError("State mismatch! Possible CSRF attack.")
    if not authorization_code:
        raise ValueError("No authorization code received.")

    print("Exchanging code for token...", file=sys.stderr)
    token_response = requests.post(
        f"{host}/oidc/v1/token",
        data={
            "client_id": CLIENT_ID,
            "grant_type": "authorization_code",
            "scope": scopes,
            "redirect_uri": redirect_uri,
            "code_verifier": code_verifier,
            "code": authorization_code,
        },
        timeout=request_timeout,
    )
    if token_response.status_code != 200:
        raise ValueError(f"Token exchange failed: {token_response.text}")
    access_token = token_response.json().get("access_token")
    if not access_token:
        # Don't echo the response body here; it may contain other credentials.
        raise ValueError("Token exchange response did not contain an access_token.")
    print("Token obtained successfully!", file=sys.stderr)
    return access_token