# jwt_auth.py • 16.3 kB
#!/usr/bin/env python3
"""
JWT Authentication
Implementation of JWT-based authentication and authorization.
"""
from jose import jwt
import secrets
from typing import Any, Dict, List, Optional
from datetime import datetime, timedelta
import structlog
from .base import (
BaseAuthenticator,
BaseAuthorizer,
User,
AuthToken,
Permission,
UserRole,
AuthenticationError,
InvalidTokenError,
TokenExpiredError,
)
# Module-level structlog logger, named after this module via __name__.
logger = structlog.get_logger(__name__)
class JWTAuthenticator(BaseAuthenticator):
    """JWT-based authenticator backed by in-memory user and token stores.

    Users, issued refresh tokens and revoked token IDs are kept in plain
    dictionaries/sets on the instance, so all state is lost when the
    instance goes away. All timestamps handled by this class are naive
    datetimes expressed in UTC.
    """

    def __init__(
        self,
        secret_key: str,
        algorithm: str = "HS256",
        access_token_expire_minutes: int = 30,
        refresh_token_expire_days: int = 7,
        issuer: str = "anydocs-mcp",
    ):
        """Initialize JWT authenticator.

        Args:
            secret_key: Secret key for JWT signing
            algorithm: JWT signing algorithm
            access_token_expire_minutes: Access token expiry in minutes
            refresh_token_expire_days: Refresh token expiry in days
            issuer: JWT issuer claim
        """
        self.secret_key = secret_key
        self.algorithm = algorithm
        self.access_token_expire_minutes = access_token_expire_minutes
        self.refresh_token_expire_days = refresh_token_expire_days
        self.issuer = issuer

        # In-memory storage (in production, use database)
        self._users: Dict[str, User] = {}  # user_id -> User
        self._refresh_tokens: Dict[str, Dict[str, Any]] = {}  # token_id -> token_info
        self._revoked_tokens: set = set()  # Set of revoked token JTIs

        logger.info(
            "JWT authenticator initialized",
            algorithm=algorithm,
            access_token_expire_minutes=access_token_expire_minutes,
            refresh_token_expire_days=refresh_token_expire_days,
        )

    @staticmethod
    def _to_utc_datetime(value: Any) -> datetime:
        """Convert an ``exp``-style claim to a naive UTC datetime.

        Args:
            value: Either a datetime (returned unchanged) or a number of
                seconds since the Unix epoch.

        Returns:
            Naive datetime in UTC, comparable with ``datetime.utcnow()``.
        """
        if isinstance(value, datetime):
            return value
        # Epoch arithmetic instead of datetime.fromtimestamp(): the latter
        # yields *local* time, which must not be compared with utcnow().
        return datetime(1970, 1, 1) + timedelta(seconds=float(value))

    @staticmethod
    def _passwords_match(stored: Any, supplied: Any) -> bool:
        """Compare two passwords in constant time.

        Args:
            stored: Password kept for the user (may be missing/None)
            supplied: Password presented by the caller

        Returns:
            True only if both are non-empty strings with identical content
        """
        if not isinstance(stored, str) or not isinstance(supplied, str):
            return False
        if not stored:
            return False
        # compare_digest rejects non-ASCII str, so compare the UTF-8 bytes.
        return secrets.compare_digest(
            stored.encode("utf-8"), supplied.encode("utf-8")
        )

    def _generate_jti(self) -> str:
        """Generate a unique JWT ID.

        Returns:
            Unique JWT ID
        """
        return secrets.token_urlsafe(16)

    def _create_jwt_payload(
        self,
        user: User,
        token_type: str = "access",
        expires_delta: Optional[timedelta] = None,
        scopes: Optional[List[str]] = None,
    ) -> Dict[str, Any]:
        """Create JWT payload.

        Args:
            user: User to create token for
            token_type: Type of token (access, refresh)
            expires_delta: Custom expiry delta
            scopes: Token scopes

        Returns:
            JWT payload dictionary. ``iat`` and ``exp`` are naive UTC
            datetimes at this point.
        """
        now = datetime.utcnow()

        # Calculate expiry
        if expires_delta:
            expire = now + expires_delta
        elif token_type == "refresh":
            expire = now + timedelta(days=self.refresh_token_expire_days)
        else:
            expire = now + timedelta(minutes=self.access_token_expire_minutes)

        payload: Dict[str, Any] = {
            "sub": user.id,  # Subject (user ID)
            "username": user.username,
            "email": user.email,
            "role": user.role.value,
            "permissions": [p.value for p in user.permissions],
            "iss": self.issuer,  # Issuer
            "iat": now,  # Issued at
            "exp": expire,  # Expiry
            "jti": self._generate_jti(),  # JWT ID
            "type": token_type,  # Token type
        }

        # Add scopes if provided
        if scopes:
            payload["scopes"] = scopes

        # Add user metadata, minus the password that add_user() stores there:
        # a JWT body is only signed, not encrypted, so anything placed in it
        # is readable by whoever holds the token.
        if user.metadata:
            public_metadata = {
                key: value
                for key, value in user.metadata.items()
                if key != "password"
            }
            if public_metadata:
                payload["metadata"] = public_metadata

        return payload

    def _encode_jwt(self, payload: Dict[str, Any]) -> str:
        """Encode JWT payload.

        Args:
            payload: JWT payload

        Returns:
            Encoded JWT token
        """
        return jwt.encode(payload, self.secret_key, algorithm=self.algorithm)

    def _decode_jwt(self, token: str) -> Dict[str, Any]:
        """Decode JWT token.

        Args:
            token: JWT token to decode

        Returns:
            Decoded JWT payload

        Raises:
            InvalidTokenError: If token is invalid or has been revoked
            TokenExpiredError: If token has expired
        """
        try:
            payload = jwt.decode(
                token,
                self.secret_key,
                algorithms=[self.algorithm],
                issuer=self.issuer,
            )
        except jwt.ExpiredSignatureError as exc:
            # Must precede JWTError: the expiry error is a subclass of it.
            raise TokenExpiredError("Token has expired") from exc
        except jwt.JWTError as exc:
            raise InvalidTokenError(f"Invalid token: {str(exc)}") from exc

        # Check if token is revoked
        jti = payload.get("jti")
        if jti and jti in self._revoked_tokens:
            raise InvalidTokenError("Token has been revoked")

        return payload

    async def authenticate(self, credentials: Dict[str, Any]) -> Optional[User]:
        """Authenticate using username/password credentials.

        Args:
            credentials: Dict with 'username' and 'password'

        Returns:
            User object if authentication successful

        Raises:
            AuthenticationError: If credentials are missing or wrong, or the
                account is inactive
        """
        username = credentials.get("username")
        password = credentials.get("password")
        if not username or not password:
            raise AuthenticationError("Username and password are required")

        # Find user by username
        user = next(
            (u for u in self._users.values() if u.username == username),
            None,
        )
        if not user:
            logger.warning("User not found", username=username)
            raise AuthenticationError("Invalid credentials")

        # In a real implementation, verify password hash.
        # For demo purposes the plaintext value kept in metadata is compared.
        stored_password = user.metadata.get("password")
        if not self._passwords_match(stored_password, password):
            logger.warning("Password verification failed", username=username)
            raise AuthenticationError("Invalid credentials")

        # Check if user is active
        if not user.is_active:
            logger.warning("Inactive user attempted login", username=username)
            raise AuthenticationError("User account is inactive")

        # Update last login
        user.update_last_login()
        logger.info(
            "User authentication successful",
            user_id=user.id,
            username=user.username,
        )
        return user

    async def validate_token(self, token: str) -> Optional[AuthToken]:
        """Validate JWT token.

        Args:
            token: JWT token to validate

        Returns:
            AuthToken object if valid

        Raises:
            InvalidTokenError: If token is invalid or has been revoked
            TokenExpiredError: If token has expired
        """
        payload = self._decode_jwt(token)

        # Extract token information
        token_type = payload.get("type", "access")
        return AuthToken(
            token=token,
            token_type=f"jwt_{token_type}",
            expires_at=self._to_utc_datetime(payload.get("exp", 0)),
            user_id=payload.get("sub"),
            scopes=payload.get("scopes", []),
            metadata={
                "jti": payload.get("jti"),
                "username": payload.get("username"),
                "role": payload.get("role"),
                "permissions": payload.get("permissions", []),
            },
        )

    async def create_token(self, user: User, **kwargs) -> AuthToken:
        """Create JWT access token for user.

        Args:
            user: User to create token for
            **kwargs: Additional parameters (scopes, expires_delta)

        Returns:
            AuthToken object
        """
        scopes = kwargs.get("scopes", [])
        expires_delta = kwargs.get("expires_delta")
        payload = self._create_jwt_payload(
            user=user,
            token_type="access",
            expires_delta=expires_delta,
            scopes=scopes,
        )
        token = self._encode_jwt(payload)

        # NOTE(review): python-jose appears to rewrite datetime claims to
        # integer timestamps in place while encoding; _to_utc_datetime copes
        # with either representation.
        return AuthToken(
            token=token,
            token_type="jwt_access",
            expires_at=self._to_utc_datetime(payload["exp"]),
            user_id=user.id,
            scopes=scopes,
            metadata={
                "jti": payload["jti"],
                "username": user.username,
                "role": user.role.value,
            },
        )

    async def create_refresh_token(self, user: User, **kwargs) -> AuthToken:
        """Create JWT refresh token for user.

        The token's JTI is recorded in the in-memory refresh-token store so
        that refresh_token() can later confirm it was issued here.

        Args:
            user: User to create token for
            **kwargs: Additional parameters (currently unused)

        Returns:
            AuthToken object
        """
        payload = self._create_jwt_payload(
            user=user,
            token_type="refresh",
        )
        token = self._encode_jwt(payload)
        jti = payload["jti"]
        expires_at = self._to_utc_datetime(payload["exp"])

        # Store refresh token info
        self._refresh_tokens[jti] = {
            "user_id": user.id,
            "created_at": datetime.utcnow(),
            "expires_at": expires_at,
        }

        return AuthToken(
            token=token,
            token_type="jwt_refresh",
            expires_at=expires_at,
            user_id=user.id,
            metadata={
                "jti": jti,
                "username": user.username,
            },
        )

    async def create_token_pair(self, user: User, **kwargs) -> Dict[str, AuthToken]:
        """Create access and refresh token pair.

        Args:
            user: User to create tokens for
            **kwargs: Additional parameters, forwarded to both token factories

        Returns:
            Dictionary with 'access' and 'refresh' tokens
        """
        access_token = await self.create_token(user, **kwargs)
        refresh_token = await self.create_refresh_token(user, **kwargs)
        return {
            "access": access_token,
            "refresh": refresh_token,
        }

    async def refresh_token(self, refresh_token: str) -> Optional[AuthToken]:
        """Refresh access token using refresh token.

        Args:
            refresh_token: Refresh token

        Returns:
            New access token if refresh successful, otherwise None
        """
        # Validate refresh token
        try:
            payload = self._decode_jwt(refresh_token)
        except (InvalidTokenError, TokenExpiredError) as e:
            logger.warning("Refresh token validation failed", error=str(e))
            return None

        # Check token type
        if payload.get("type") != "refresh":
            logger.warning("Invalid token type for refresh", type=payload.get("type"))
            return None

        # Check if refresh token exists
        jti = payload.get("jti")
        if not jti or jti not in self._refresh_tokens:
            logger.warning("Refresh token not found", jti=jti)
            return None

        # Get user
        user_id = payload.get("sub")
        user = await self.get_user(user_id)
        if not user:
            logger.warning("User not found for refresh token", user_id=user_id)
            return None

        # Mirror authenticate(): a deactivated account must not be able to
        # keep minting access tokens from a still-valid refresh token.
        if not user.is_active:
            logger.warning("Inactive user attempted token refresh", user_id=user_id)
            return None

        # Create new access token
        access_token = await self.create_token(user)
        logger.info(
            "Token refreshed successfully",
            user_id=user_id,
            username=user.username,
        )
        return access_token

    async def revoke_token(self, token: str) -> bool:
        """Revoke JWT token.

        Args:
            token: JWT token to revoke

        Returns:
            True if the token was revoked or was already unusable; False if
            the token decoded successfully but carries no JTI to revoke
        """
        try:
            payload = self._decode_jwt(token)
        except (InvalidTokenError, TokenExpiredError):
            # Token is already invalid, consider it revoked
            return True

        jti = payload.get("jti")
        if not jti:
            return False

        self._revoked_tokens.add(jti)
        # If it's a refresh token, remove from storage
        if payload.get("type") == "refresh":
            self._refresh_tokens.pop(jti, None)
        logger.info("Token revoked", jti=jti, type=payload.get("type"))
        return True

    async def revoke_all_user_tokens(self, user_id: str) -> int:
        """Revoke all tracked tokens for a user.

        Only refresh tokens are tracked per user by this class, so access
        tokens that were already issued are not affected by this call.

        Args:
            user_id: User identifier

        Returns:
            Number of tokens revoked
        """
        # Collect first: the store must not be mutated while iterating it.
        user_jtis = [
            jti
            for jti, token_info in self._refresh_tokens.items()
            if token_info["user_id"] == user_id
        ]

        for jti in user_jtis:
            self._revoked_tokens.add(jti)
            del self._refresh_tokens[jti]

        revoked_count = len(user_jtis)
        logger.info("All user tokens revoked", user_id=user_id, count=revoked_count)
        return revoked_count

    async def get_user(self, user_id: str) -> Optional[User]:
        """Get user by ID.

        Args:
            user_id: User identifier

        Returns:
            User object if found
        """
        return self._users.get(user_id)

    async def get_user_by_token(self, token: str) -> Optional[User]:
        """Get user from JWT token.

        Args:
            token: JWT token

        Returns:
            User object if token is valid and user exists
        """
        try:
            payload = self._decode_jwt(token)
        except (InvalidTokenError, TokenExpiredError):
            return None

        user_id = payload.get("sub")
        if not user_id:
            return None
        return await self.get_user(user_id)

    async def add_user(self, user: User, password: str) -> None:
        """Add a user to the authenticator.

        Args:
            user: User to add
            password: User password (stored in metadata for demo)
        """
        # In production, hash the password properly
        user.metadata["password"] = password
        self._users[user.id] = user
        logger.info(
            "User added to JWT authenticator",
            user_id=user.id,
            username=user.username,
        )

    async def cleanup_expired_tokens(self) -> int:
        """Clean up expired refresh tokens.

        Returns:
            Number of tokens cleaned up
        """
        now = datetime.utcnow()
        expired_tokens = [
            jti
            for jti, token_info in self._refresh_tokens.items()
            if token_info["expires_at"] < now
        ]

        # Remove expired tokens; their JTIs no longer need to stay on the
        # revocation list because expiry alone already invalidates them.
        for jti in expired_tokens:
            del self._refresh_tokens[jti]
            self._revoked_tokens.discard(jti)

        logger.info("Expired tokens cleaned up", count=len(expired_tokens))
        return len(expired_tokens)
class JWTAuthorizer(BaseAuthorizer):
    """Authorizer that answers permission checks from the user's own grants."""

    async def check_permission(
        self,
        user: User,
        permission: Permission,
        resource: Optional[str] = None,
    ) -> bool:
        """Report whether the user holds the requested permission.

        Args:
            user: User whose permissions are consulted
            permission: Permission being asked for
            resource: Optional resource identifier; accepted but not
                consulted by this implementation

        Returns:
            Whatever ``user.has_permission(permission)`` reports
        """
        is_granted = user.has_permission(permission)
        return is_granted