jwt_manager.py•2.76 kB
"""JWT token management."""
from datetime import datetime, timedelta, timezone
from typing import Any, Dict, Optional

import jwt

from src.utils.logger import get_logger
# Module-level logger. Call sites in this module pass an event name plus
# keyword context fields (e.g. expires_in=..., error=...), so get_logger is
# presumably a structured-logging (structlog-style) factory — confirm in
# src.utils.logger.
logger = get_logger(__name__)
class AuthenticationError(Exception):
    """Raised when a JWT cannot be authenticated (malformed, badly signed, or expired)."""
class JWTManager:
    """
    Manages JWT token generation, validation, and refresh.

    Tokens are signed and verified with a single shared ``secret_key`` and a
    single ``algorithm``; verification accepts only that one algorithm.
    """

    def __init__(self, secret_key: str, algorithm: str = "HS256"):
        """
        Initialize JWT manager.

        Args:
            secret_key: Secret key for signing and verifying tokens
            algorithm: JWT algorithm (default: HS256)
        """
        self.secret_key = secret_key
        self.algorithm = algorithm

    def generate_token(self, payload: Dict[str, Any], expires_in: int = 3600) -> str:
        """
        Generate a new JWT token.

        The caller's ``payload`` is not mutated: ``exp`` and ``iat`` claims are
        written into a shallow copy, overwriting any values already present.

        Args:
            payload: Token payload data
            expires_in: Expiration time in seconds, counted from now

        Returns:
            Encoded JWT token
        """
        # Read the clock once so ``iat`` and ``exp`` are exactly
        # ``expires_in`` seconds apart, and use an aware UTC datetime:
        # ``datetime.utcnow()`` returns a naive value and is deprecated
        # since Python 3.12.
        issued_at = datetime.now(timezone.utc)
        payload_copy = payload.copy()
        payload_copy["exp"] = issued_at + timedelta(seconds=expires_in)
        payload_copy["iat"] = issued_at
        token = jwt.encode(payload_copy, self.secret_key, algorithm=self.algorithm)
        logger.debug("jwt_token_generated", expires_in=expires_in)
        return token

    def validate_token(self, token: str) -> Dict[str, Any]:
        """
        Validate and decode JWT token.

        Only ``self.algorithm`` is accepted during verification, so a token
        cannot select a different algorithm via its header. An expired token
        is reported with a distinct message from other validation failures.

        Args:
            token: JWT token to validate

        Returns:
            Decoded token payload

        Raises:
            AuthenticationError: If token is invalid or expired. The underlying
                PyJWT exception is chained as ``__cause__`` for debugging.
        """
        try:
            payload = jwt.decode(token, self.secret_key, algorithms=[self.algorithm])
        except jwt.ExpiredSignatureError as exc:
            # Must be handled before InvalidTokenError, which it subclasses.
            logger.warning("jwt_token_expired")
            raise AuthenticationError("Token has expired") from exc
        except jwt.InvalidTokenError as exc:
            logger.warning("jwt_token_invalid", error=str(exc))
            raise AuthenticationError("Invalid token") from exc
        logger.debug("jwt_token_validated")
        return payload

    def refresh_token(self, token: str, expires_in: int = 3600) -> str:
        """
        Refresh an existing token.

        The token must still pass ``validate_token`` (so an already-expired
        token cannot be refreshed). Its claims are re-issued with fresh
        ``exp`` and ``iat`` values.

        Args:
            token: Current JWT token
            expires_in: New expiration time in seconds

        Returns:
            New JWT token

        Raises:
            AuthenticationError: If token is invalid or expired
        """
        payload = self.validate_token(token)
        # Drop the old timing claims; generate_token stamps new ones.
        payload.pop("exp", None)
        payload.pop("iat", None)
        logger.debug("jwt_token_refreshed")
        return self.generate_token(payload, expires_in)