oauth2.py • 21.1 kB
#!/usr/bin/env python3
"""
OAuth2 Authentication
Implementation of OAuth2-based authentication for third-party providers.
"""
import secrets
import base64
import hashlib
from typing import Any, Dict, List, Optional, Union
from datetime import datetime, timedelta
from urllib.parse import urlencode, parse_qs
import aiohttp
import structlog
from .base import (
BaseAuthenticator,
BaseAuthorizer,
User,
AuthToken,
Permission,
UserRole,
AuthenticationError,
InvalidTokenError,
TokenExpiredError,
)
# Module-level structlog logger; the authenticator and authorizer classes
# in this module emit their structured events (provider, user_id, error, ...)
# through it.
logger = structlog.get_logger(__name__)
class OAuth2Provider:
    """OAuth2 provider configuration.

    Holds the client credentials and endpoint URLs for a single third-party
    authorization server.
    """

    def __init__(
        self,
        name: str,
        client_id: str,
        client_secret: str,
        authorization_url: str,
        token_url: str,
        user_info_url: str,
        scopes: Optional[List[str]] = None,
        redirect_uri: Optional[str] = None,
    ):
        """Create a provider configuration.

        Args:
            name: Key under which the provider is registered and looked up.
            client_id: OAuth2 client identifier issued by the provider.
            client_secret: OAuth2 client secret sent to the token endpoint.
            authorization_url: Endpoint the user is redirected to for consent.
            token_url: Endpoint used for code exchange and token refresh.
            user_info_url: Endpoint queried for the authenticated user's profile.
            scopes: Default scopes requested when the caller passes none.
            redirect_uri: Provider-specific redirect URI; ``None`` defers to
                the caller or the authenticator's default.
        """
        self.name = name
        self.client_id = client_id
        self.client_secret = client_secret
        self.authorization_url = authorization_url
        self.token_url = token_url
        self.user_info_url = user_info_url
        # Copy so later mutation of the caller's list cannot change this config.
        self.scopes: List[str] = list(scopes) if scopes else []
        self.redirect_uri = redirect_uri
class OAuth2State:
    """Pending OAuth2 authorization request, keyed by its ``state`` value.

    Records what is needed to finish the flow when the provider redirects
    back: the provider name, redirect URI, requested scopes and, when PKCE is
    used, the code verifier.
    """

    # Lifetime applied when ``expires_at`` is not given explicitly.
    DEFAULT_TTL = timedelta(minutes=10)

    def __init__(
        self,
        state: str,
        provider: str,
        redirect_uri: str,
        scopes: Optional[List[str]] = None,
        code_verifier: Optional[str] = None,
        created_at: Optional[datetime] = None,
        expires_at: Optional[datetime] = None,
    ):
        """Create a pending-state record.

        Args:
            state: Random state value sent to the provider.
            provider: Name of the provider the flow was started against.
            redirect_uri: Redirect URI used in the authorization request.
            scopes: Scopes requested in the authorization request.
            code_verifier: PKCE code verifier, or ``None`` if PKCE is unused.
            created_at: Creation time; defaults to ``datetime.utcnow()``.
            expires_at: Expiry time; defaults to ``created_at + DEFAULT_TTL``.
        """
        self.state = state
        self.provider = provider
        self.redirect_uri = redirect_uri
        self.scopes = scopes or []
        self.code_verifier = code_verifier
        # NOTE(review): naive UTC datetimes; callers passing explicit
        # timestamps presumably use naive UTC as well — confirm.
        self.created_at = created_at or datetime.utcnow()
        # Derive the expiry from created_at rather than a second utcnow()
        # call, so a back-dated created_at yields a matching expiry.
        self.expires_at = expires_at or (self.created_at + self.DEFAULT_TTL)

    def is_expired(self) -> bool:
        """Return True once the current UTC time reaches ``expires_at``."""
        return datetime.utcnow() >= self.expires_at

    def is_valid(self) -> bool:
        """Return True while the state has not expired."""
        return not self.is_expired()
class OAuth2Token:
    """OAuth2 token information.

    Wraps the fields of a provider token response and tracks when the access
    token expires.
    """

    def __init__(
        self,
        access_token: str,
        token_type: str = "Bearer",
        expires_in: Optional[Union[int, float, str]] = None,
        refresh_token: Optional[str] = None,
        scope: Optional[str] = None,
        user_id: Optional[str] = None,
        provider: Optional[str] = None,
        created_at: Optional[datetime] = None,
    ):
        """Create a token record.

        Args:
            access_token: The access token string.
            token_type: Authorization scheme used with the token.
            expires_in: Lifetime in seconds from ``created_at``; ``None`` means
                no known expiry. Numeric strings are accepted.
            refresh_token: Refresh token, if the provider issued one.
            scope: Granted scope string as returned by the provider.
            user_id: Local user id the token is bound to.
            provider: Name of the issuing provider.
            created_at: Issue time; defaults to ``datetime.utcnow()``.
        """
        self.access_token = access_token
        self.token_type = token_type
        self.expires_in = expires_in
        self.refresh_token = refresh_token
        self.scope = scope
        self.user_id = user_id
        self.provider = provider
        self.created_at = created_at or datetime.utcnow()
        self.expires_at: Optional[datetime]
        # `is not None` (not truthiness) so expires_in=0 means "already
        # expired" instead of "never expires"; float() also accepts the
        # numeric strings that timedelta() would reject.
        if expires_in is not None:
            self.expires_at = self.created_at + timedelta(seconds=float(expires_in))
        else:
            self.expires_at = None

    def is_expired(self) -> bool:
        """Return True once ``expires_at`` is reached; never if it is unknown."""
        if self.expires_at is None:
            return False
        return datetime.utcnow() >= self.expires_at

    def is_valid(self) -> bool:
        """Return True if the access token is non-empty and not expired."""
        return bool(self.access_token) and not self.is_expired()
class OAuth2Authenticator(BaseAuthenticator):
    """OAuth2-based authenticator.

    Drives the authorization-code flow (optionally with PKCE) against
    registered providers. Pending states, issued tokens and known users are
    kept in in-process dictionaries.
    """

    def __init__(self, default_redirect_uri: str = "http://localhost:8000/auth/callback"):
        """Initialize OAuth2 authenticator.

        Args:
            default_redirect_uri: Redirect URI used when neither the caller nor
                the provider configuration supplies one.
        """
        self.default_redirect_uri = default_redirect_uri
        self._providers: Dict[str, OAuth2Provider] = {}
        self._states: Dict[str, OAuth2State] = {}  # state -> OAuth2State
        self._tokens: Dict[str, OAuth2Token] = {}  # user_id -> OAuth2Token
        self._users: Dict[str, User] = {}  # user_id -> User
        logger.info("OAuth2 authenticator initialized",
                    default_redirect_uri=default_redirect_uri)

    def add_provider(self, provider: OAuth2Provider) -> None:
        """Register (or replace) an OAuth2 provider under ``provider.name``.

        Args:
            provider: OAuth2 provider configuration
        """
        self._providers[provider.name] = provider
        logger.info("OAuth2 provider added", provider=provider.name)

    def _generate_state(self) -> str:
        """Return a URL-safe random state value from 32 random bytes."""
        return secrets.token_urlsafe(32)

    def _generate_code_verifier(self) -> str:
        """Return a PKCE code verifier.

        32 random bytes, base64url-encoded without padding (43 characters).
        """
        return base64.urlsafe_b64encode(secrets.token_bytes(32)).decode('utf-8').rstrip('=')

    def _generate_code_challenge(self, code_verifier: str) -> str:
        """Return the S256 PKCE challenge for ``code_verifier``.

        The challenge is base64url(SHA-256(verifier)) without padding.
        """
        digest = hashlib.sha256(code_verifier.encode('utf-8')).digest()
        return base64.urlsafe_b64encode(digest).decode('utf-8').rstrip('=')

    def _purge_expired_states(self) -> None:
        """Drop pending states whose lifetime has elapsed.

        States from abandoned flows are otherwise never removed, so the table
        would grow with every authorization URL issued.
        """
        expired = [key for key, pending in self._states.items() if pending.is_expired()]
        for key in expired:
            del self._states[key]

    def get_authorization_url(
        self,
        provider_name: str,
        redirect_uri: Optional[str] = None,
        scopes: Optional[List[str]] = None,
        use_pkce: bool = True,
    ) -> tuple[str, str]:
        """Get OAuth2 authorization URL and register a pending state.

        Args:
            provider_name: Name of the OAuth2 provider
            redirect_uri: Custom redirect URI
            scopes: Custom scopes
            use_pkce: Whether to use PKCE (S256)

        Returns:
            Tuple of (authorization_url, state)

        Raises:
            AuthenticationError: If provider not found
        """
        if provider_name not in self._providers:
            raise AuthenticationError(f"OAuth2 provider '{provider_name}' not found")
        provider = self._providers[provider_name]

        self._purge_expired_states()
        state = self._generate_state()
        redirect_uri = redirect_uri or provider.redirect_uri or self.default_redirect_uri
        scopes = scopes or provider.scopes

        code_verifier = None
        code_challenge = None
        if use_pkce:
            code_verifier = self._generate_code_verifier()
            code_challenge = self._generate_code_challenge(code_verifier)

        self._states[state] = OAuth2State(
            state=state,
            provider=provider_name,
            redirect_uri=redirect_uri,
            scopes=scopes,
            code_verifier=code_verifier,
        )

        params = {
            'client_id': provider.client_id,
            'response_type': 'code',
            'redirect_uri': redirect_uri,
            'state': state,
        }
        if scopes:
            params['scope'] = ' '.join(scopes)
        if code_challenge:
            params['code_challenge'] = code_challenge
            params['code_challenge_method'] = 'S256'

        # The configured endpoint may already carry a query string; append
        # with '&' in that case instead of producing a second '?'.
        base_url = provider.authorization_url
        if '?' not in base_url:
            separator = '?'
        elif base_url.endswith(('?', '&')):
            separator = ''
        else:
            separator = '&'
        auth_url = f"{base_url}{separator}{urlencode(params)}"

        logger.info("Generated OAuth2 authorization URL",
                    provider=provider_name, state=state, use_pkce=use_pkce)
        return auth_url, state

    async def _request_json(
        self,
        method: str,
        url: str,
        *,
        error_prefix: str,
        log_event: str,
        provider_name: Optional[str],
        **request_kwargs: Any,
    ) -> Dict[str, Any]:
        """Perform an HTTP request and return its JSON object body.

        Every failure is logged under ``log_event`` and surfaced as a single
        AuthenticationError whose message starts with ``error_prefix``. That
        covers transport errors, non-200 statuses, and undecodable or
        non-object bodies. An AuthenticationError raised here is not wrapped
        a second time.

        Args:
            method: HTTP method, e.g. ``'GET'`` or ``'POST'``.
            url: Target URL.
            error_prefix: Leading text of any raised error message.
            log_event: Structured-log event name used on failure.
            provider_name: Provider name attached to failure logs.
            **request_kwargs: Passed through to ``ClientSession.request``.

        Returns:
            The decoded JSON object.

        Raises:
            AuthenticationError: On any failure described above.
        """
        try:
            async with aiohttp.ClientSession() as session:
                async with session.request(method, url, **request_kwargs) as response:
                    if response.status != 200:
                        error_text = await response.text()
                        raise AuthenticationError(
                            f"{error_prefix}: {response.status} - {error_text}"
                        )
                    payload = await response.json()
        except AuthenticationError as e:
            logger.error(log_event, provider=provider_name, error=str(e))
            raise
        except Exception as e:
            logger.error(log_event, provider=provider_name, error=str(e))
            raise AuthenticationError(f"{error_prefix}: {e}") from e

        if not isinstance(payload, dict):
            detail = f"expected a JSON object, got {type(payload).__name__}"
            logger.error(log_event, provider=provider_name, error=detail)
            raise AuthenticationError(f"{error_prefix}: {detail}")
        return payload

    def _require_access_token(
        self,
        token_data: Dict[str, Any],
        *,
        error_prefix: str,
        log_event: str,
        provider_name: Optional[str],
    ) -> str:
        """Return ``access_token`` from a token response, or raise.

        A 200 status alone is not trusted: a provider may answer with an
        OAuth ``error``/``error_description`` body instead of a token.

        Raises:
            AuthenticationError: If the response carries no access token.
        """
        access_token = token_data.get('access_token')
        if access_token:
            return access_token
        detail = token_data.get('error') or 'response did not contain an access_token'
        description = token_data.get('error_description')
        if description:
            detail = f"{detail} ({description})"
        logger.error(log_event, provider=provider_name, error=detail)
        raise AuthenticationError(f"{error_prefix}: {detail}")

    async def exchange_code(
        self,
        code: str,
        state: str,
        redirect_uri: Optional[str] = None,
    ) -> OAuth2Token:
        """Exchange authorization code for access token.

        The state is consumed whether or not the exchange succeeds.

        Args:
            code: Authorization code
            state: State parameter
            redirect_uri: Redirect URI used in authorization

        Returns:
            OAuth2 token

        Raises:
            AuthenticationError: If the state is unknown/expired or the
                exchange fails
        """
        # Pop rather than read: the state is single-use. Removing it before
        # the awaited HTTP call means a concurrent second callback with the
        # same state is rejected here. Previously that second callback passed
        # validation and then crashed with KeyError on a later ``del``.
        oauth_state = self._states.pop(state, None)
        if oauth_state is None:
            raise AuthenticationError("Invalid or expired OAuth2 state")
        if not oauth_state.is_valid():
            raise AuthenticationError("OAuth2 state has expired")

        provider = self._providers[oauth_state.provider]
        redirect_uri = redirect_uri or oauth_state.redirect_uri

        data = {
            'grant_type': 'authorization_code',
            'client_id': provider.client_id,
            'client_secret': provider.client_secret,
            'code': code,
            'redirect_uri': redirect_uri,
        }
        if oauth_state.code_verifier:
            data['code_verifier'] = oauth_state.code_verifier

        token_data = await self._request_json(
            'POST',
            provider.token_url,
            error_prefix="Token exchange failed",
            log_event="OAuth2 token exchange failed",
            provider_name=oauth_state.provider,
            data=data,
            headers={'Accept': 'application/json'},
        )
        access_token = self._require_access_token(
            token_data,
            error_prefix="Token exchange failed",
            log_event="OAuth2 token exchange failed",
            provider_name=oauth_state.provider,
        )

        oauth_token = OAuth2Token(
            access_token=access_token,
            # `or` so an explicit null from the provider still falls back.
            token_type=token_data.get('token_type') or 'Bearer',
            expires_in=token_data.get('expires_in'),
            refresh_token=token_data.get('refresh_token'),
            scope=token_data.get('scope'),
            provider=oauth_state.provider,
        )
        logger.info("OAuth2 token exchange successful",
                    provider=oauth_state.provider)
        return oauth_token

    async def get_user_info(self, oauth_token: OAuth2Token) -> Dict[str, Any]:
        """Get user information using OAuth2 token.

        Args:
            oauth_token: OAuth2 access token

        Returns:
            User information dictionary

        Raises:
            AuthenticationError: If the token is invalid, its provider is
                unknown, or user info retrieval fails
        """
        if not oauth_token.is_valid():
            raise AuthenticationError("OAuth2 token is invalid or expired")

        provider = self._providers.get(oauth_token.provider)
        if provider is None:
            raise AuthenticationError(f"OAuth2 provider '{oauth_token.provider}' not found")

        headers = {
            'Authorization': f'{oauth_token.token_type} {oauth_token.access_token}',
            'Accept': 'application/json',
        }
        user_info = await self._request_json(
            'GET',
            provider.user_info_url,
            error_prefix="User info retrieval failed",
            log_event="OAuth2 user info retrieval failed",
            provider_name=oauth_token.provider,
            headers=headers,
        )
        logger.info("OAuth2 user info retrieved",
                    provider=oauth_token.provider)
        return user_info

    async def authenticate(self, credentials: Dict[str, Any]) -> User:
        """Authenticate user with OAuth2 token.

        Args:
            credentials: Dictionary containing 'oauth_token' key

        Returns:
            Authenticated user (created or replaced in the local store)

        Raises:
            AuthenticationError: If authentication fails
        """
        oauth_token = credentials.get('oauth_token')
        if not isinstance(oauth_token, OAuth2Token):
            raise AuthenticationError("Invalid OAuth2 token provided")
        if not oauth_token.is_valid():
            raise AuthenticationError("OAuth2 token is invalid or expired")

        user_info = await self.get_user_info(oauth_token)

        # Explicit None/'' checks instead of truthiness, so a numeric id of 0
        # is not rejected. 'sub' is the OpenID Connect subject claim.
        provider_user_id = None
        for key in ('id', 'sub'):
            candidate = user_info.get(key)
            if candidate is not None and candidate != '':
                provider_user_id = candidate
                break
        if provider_user_id is None:
            raise AuthenticationError("Unable to extract user ID from OAuth2 response")

        # Namespace by provider so equal ids from different providers differ.
        user_id = f"{oauth_token.provider}:{provider_user_id}"

        user = User(
            id=user_id,
            username=user_info.get('login') or user_info.get('preferred_username') or user_id,
            email=user_info.get('email'),
            full_name=user_info.get('name'),
            role=UserRole.USER,  # Default role
            is_active=True,
            metadata={
                'provider': oauth_token.provider,
                'oauth_user_info': user_info,
            },
        )

        self._users[user_id] = user
        oauth_token.user_id = user_id
        self._tokens[user_id] = oauth_token

        logger.info("OAuth2 authentication successful",
                    user_id=user_id, provider=oauth_token.provider)
        return user

    async def validate_token(self, token: str) -> User:
        """Validate OAuth2 access token against the local token store.

        Args:
            token: OAuth2 access token

        Returns:
            User associated with the token

        Raises:
            InvalidTokenError: If token or its user is unknown
            TokenExpiredError: If token is expired
        """
        oauth_token = next(
            (stored for stored in self._tokens.values() if stored.access_token == token),
            None,
        )
        if oauth_token is None:
            raise InvalidTokenError("OAuth2 token not found")
        if oauth_token.is_expired():
            raise TokenExpiredError("OAuth2 token has expired")

        user = self._users.get(oauth_token.user_id)
        if not user:
            raise InvalidTokenError("User not found for OAuth2 token")
        return user

    async def create_token(self, user: User, **kwargs) -> AuthToken:
        """Wrap the user's stored OAuth2 token in an AuthToken.

        Args:
            user: User to create token for
            **kwargs: Accepted for interface compatibility; unused here

        Returns:
            Authentication token

        Raises:
            AuthenticationError: If no OAuth2 token is stored for the user
        """
        oauth_token = self._tokens.get(user.id)
        if not oauth_token:
            raise AuthenticationError("No OAuth2 token found for user")

        return AuthToken(
            token=oauth_token.access_token,
            user_id=user.id,
            expires_at=oauth_token.expires_at,
            token_type="oauth2",
            metadata={
                'provider': oauth_token.provider,
                'scope': oauth_token.scope,
            },
        )

    async def refresh_token(self, refresh_token: str) -> OAuth2Token:
        """Refresh OAuth2 access token in place.

        The stored token is only modified after the provider's response has
        been validated, so a failed refresh leaves it untouched.

        Args:
            refresh_token: OAuth2 refresh token

        Returns:
            The updated OAuth2 token (same object as stored)

        Raises:
            AuthenticationError: If the refresh token is unknown or refresh fails
        """
        # Guard against None/'' matching a stored token that has no refresh token.
        if not refresh_token:
            raise AuthenticationError("Refresh token not found")
        oauth_token = next(
            (stored for stored in self._tokens.values() if stored.refresh_token == refresh_token),
            None,
        )
        if oauth_token is None:
            raise AuthenticationError("Refresh token not found")

        provider = self._providers[oauth_token.provider]
        data = {
            'grant_type': 'refresh_token',
            'client_id': provider.client_id,
            'client_secret': provider.client_secret,
            'refresh_token': refresh_token,
        }

        token_data = await self._request_json(
            'POST',
            provider.token_url,
            error_prefix="Token refresh failed",
            log_event="OAuth2 token refresh failed",
            provider_name=oauth_token.provider,
            data=data,
            headers={'Accept': 'application/json'},
        )
        access_token = self._require_access_token(
            token_data,
            error_prefix="Token refresh failed",
            log_event="OAuth2 token refresh failed",
            provider_name=oauth_token.provider,
        )

        expires_in = token_data.get('expires_in')
        oauth_token.access_token = access_token
        oauth_token.token_type = token_data.get('token_type') or 'Bearer'
        oauth_token.expires_in = expires_in
        # Keep the old refresh token / scope when the provider omits them or
        # sends null (refresh-token rotation is optional).
        oauth_token.refresh_token = token_data.get('refresh_token') or refresh_token
        oauth_token.scope = token_data.get('scope') or oauth_token.scope
        oauth_token.created_at = datetime.utcnow()
        # Always recompute. If expires_in is absent, the new token has no
        # known expiry; keeping the old expires_at would mark it expired.
        oauth_token.expires_at = (
            oauth_token.created_at + timedelta(seconds=float(expires_in))
            if expires_in is not None
            else None
        )

        logger.info("OAuth2 token refresh successful",
                    provider=oauth_token.provider)
        return oauth_token

    async def revoke_token(self, token: str) -> bool:
        """Forget an OAuth2 access token locally.

        Only the local token store is updated; no revocation request is sent
        to the provider, and the associated user record is kept.

        Args:
            token: Token to revoke

        Returns:
            True if the token was found and removed
        """
        user_id_to_remove = next(
            (uid for uid, stored in self._tokens.items() if stored.access_token == token),
            None,
        )
        if user_id_to_remove is None:
            return False
        del self._tokens[user_id_to_remove]
        logger.info("OAuth2 token revoked", user_id=user_id_to_remove)
        return True
class OAuth2Authorizer(BaseAuthorizer):
    """Permission-list authorizer for OAuth2-authenticated users.

    Users with role ``UserRole.ADMIN`` are granted every permission. Every
    other user is granted exactly the permissions registered for their id via
    ``set_user_permissions``.
    """

    def __init__(self):
        """Initialize OAuth2 authorizer with no per-user permissions."""
        self._user_permissions: Dict[str, List[Permission]] = {}  # user_id -> permissions
        logger.info("OAuth2 authorizer initialized")

    def set_user_permissions(self, user_id: str, permissions: List[Permission]) -> None:
        """Replace the permissions granted to a user.

        Args:
            user_id: User ID
            permissions: List of permissions
        """
        # Store a copy so later mutation of the caller's list cannot
        # silently change this user's grants.
        self._user_permissions[user_id] = list(permissions)
        logger.info("User permissions updated", user_id=user_id, permissions=len(permissions))

    async def check_permission(
        self,
        user: User,
        permission: Permission,
        resource: Optional[str] = None,
    ) -> bool:
        """Check if user has permission.

        Args:
            user: User to check
            permission: Permission to check
            resource: Optional resource identifier. Not consulted by this
                implementation; grants are not scoped per resource.

        Returns:
            True if user is an admin or was explicitly granted ``permission``
        """
        if user.role == UserRole.ADMIN:
            return True
        return permission in self._user_permissions.get(user.id, [])

    async def require_permission(
        self,
        user: User,
        permission: Permission,
        resource: Optional[str] = None,
    ) -> None:
        """Require user to have permission.

        Args:
            user: User to check
            permission: Required permission
            resource: Optional resource identifier (only used in the message)

        Raises:
            AuthorizationError: If user lacks permission
        """
        if await self.check_permission(user, permission, resource):
            return
        # AuthorizationError is not in this module's top-level ``.base``
        # import, so the bare name used to raise NameError instead of the
        # documented exception. Imported here to keep the fix self-contained.
        # NOTE(review): assumes ``.base`` defines AuthorizationError next to
        # AuthenticationError — confirm.
        from .base import AuthorizationError
        raise AuthorizationError(
            f"User '{user.id}' lacks permission '{permission.value}' for resource '{resource}'"
        )