"""
JWT token management module using python-jose.
"""
import logging
import os
from datetime import datetime, timedelta, timezone
from typing import Optional, Dict, Any

from jose import JWTError, jwt
from passlib.context import CryptContext
logger = logging.getLogger(__name__)

# JWT configuration.
# The signing secret is read from the JWT_SECRET_KEY environment variable;
# the literal fallback is an insecure placeholder and must not be used in
# production.
SECRET_KEY = os.environ.get("JWT_SECRET_KEY", "your-super-secret-key-change-in-production")
ALGORITHM = "HS256"  # HMAC-SHA256, symmetric signing with SECRET_KEY
ACCESS_TOKEN_EXPIRE_MINUTES = 30  # default access-token lifetime
REFRESH_TOKEN_EXPIRE_DAYS = 7  # refresh-token lifetime
# Shared passlib context used for hashing and verifying passwords (bcrypt only).
pwd_context = CryptContext(deprecated="auto", schemes=["bcrypt"])
class JWTHandler:
    """JWT token creation and verification handler.

    Tokens are signed and verified with the module-level ``SECRET_KEY`` and
    ``ALGORITHM``; passwords are hashed/verified with the module-level
    ``pwd_context``.
    """

    def __init__(self):
        # Warn (but do not fail) when the placeholder secret is in use, so a
        # misconfigured deployment is visible in the logs.
        if not SECRET_KEY or SECRET_KEY == "your-super-secret-key-change-in-production":
            logger.warning("Using default JWT secret key. Set JWT_SECRET_KEY environment variable!")

    def verify_password(self, plain_password: str, hashed_password: str) -> bool:
        """Return True if ``plain_password`` matches ``hashed_password``."""
        return pwd_context.verify(plain_password, hashed_password)

    def get_password_hash(self, password: str) -> str:
        """Return the hash of ``password`` produced by ``pwd_context``."""
        return pwd_context.hash(password)

    def create_access_token(
        self,
        data: Dict[str, Any],
        expires_delta: Optional[timedelta] = None
    ) -> str:
        """
        Create JWT access token.

        Args:
            data: Claims to encode. The caller's dict is not modified; a copy
                gets an ``exp`` claim and ``type="access"`` (overwriting any
                existing ``type``).
            expires_delta: Token lifetime. ``None`` means
                ``ACCESS_TOKEN_EXPIRE_MINUTES``; any other value, including
                ``timedelta(0)``, is used as given.

        Returns:
            Encoded JWT token string

        Raises:
            JWTError: a ``JWTError`` from ``jwt.encode`` is logged and re-raised.
        """
        # Explicit None test: ``timedelta(0)`` is falsy, and a truthiness test
        # would silently replace it with the default lifetime.
        if expires_delta is None:
            expires_delta = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
        # Timezone-aware UTC; python-jose converts datetime claims to epoch seconds.
        expire = datetime.now(timezone.utc) + expires_delta

        to_encode = data.copy()
        to_encode.update({"exp": expire, "type": "access"})
        try:
            encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
        except JWTError as e:
            logger.error("Failed to create access token: %s", e)
            raise
        logger.info("Access token created for user: %s", data.get("sub", "unknown"))
        return encoded_jwt

    def create_refresh_token(self, data: Dict[str, Any]) -> str:
        """
        Create JWT refresh token.

        Args:
            data: Claims to encode. The caller's dict is not modified; a copy
                gets ``exp`` (now + ``REFRESH_TOKEN_EXPIRE_DAYS``) and
                ``type="refresh"``.

        Returns:
            Encoded JWT refresh token string

        Raises:
            JWTError: a ``JWTError`` from ``jwt.encode`` is logged and re-raised.
        """
        expire = datetime.now(timezone.utc) + timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS)

        to_encode = data.copy()
        to_encode.update({"exp": expire, "type": "refresh"})
        try:
            encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
        except JWTError as e:
            logger.error("Failed to create refresh token: %s", e)
            raise
        logger.info("Refresh token created for user: %s", data.get("sub", "unknown"))
        return encoded_jwt

    def verify_access_token(
        self,
        token: str,
        token_type: str = "access"
    ) -> Optional[Dict[str, Any]]:
        """
        Verify and decode a JWT token.

        Args:
            token: JWT token to verify
            token_type: Expected value of the ``type`` claim ("access" or
                "refresh").

        Returns:
            The decoded payload, or None if decoding/signature verification
            fails, the ``type`` claim does not match ``token_type``, the
            ``exp`` claim is missing or in the past, or ``sub`` is missing.
        """
        try:
            payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
        except JWTError as e:
            logger.error("JWT verification failed: %s", e)
            return None

        # Reject tokens of the wrong kind (e.g. a refresh token used as access).
        token_type_in_payload = payload.get("type")
        if token_type_in_payload != token_type:
            logger.warning(
                "Token type mismatch. Expected: %s, Got: %s", token_type, token_type_in_payload
            )
            return None

        # jwt.decode only validates ``exp`` when the claim is present, so a
        # token without it is rejected here. The comparison is done in epoch
        # seconds (UTC). Comparing naive utcnow() against naive local
        # fromtimestamp() would be off by the host's UTC offset.
        exp = payload.get("exp")
        if exp is None or datetime.now(timezone.utc).timestamp() > float(exp):
            logger.warning("Token has expired")
            return None

        user_id: Optional[str] = payload.get("sub")
        if user_id is None:
            logger.warning("Token missing subject")
            return None

        logger.info("Token verified successfully for user: %s", user_id)
        return payload

    def create_token_pair(self, user_data: Dict[str, Any]) -> Dict[str, str]:
        """
        Create both access and refresh tokens.

        Args:
            user_data: User information; must contain ``"id"`` and ``"email"``.

        Returns:
            Dictionary with ``access_token``, ``refresh_token``,
            ``token_type`` ("bearer") and ``expires_in`` (access-token
            lifetime in seconds).

        Raises:
            KeyError: if ``"id"`` or ``"email"`` is missing from ``user_data``.
        """
        user_id = user_data["id"]
        # python-jose's decode rejects a non-string ``sub`` claim, so int/UUID
        # ids must be stringified or the tokens could never be verified.
        # ``None`` is left as-is rather than becoming the literal "None".
        subject = None if user_id is None else str(user_id)
        # The ``type`` claim is set by the create_* methods themselves.
        claims = {"sub": subject, "email": user_data["email"]}

        access_token = self.create_access_token(claims)
        refresh_token = self.create_refresh_token(claims)
        return {
            "access_token": access_token,
            "refresh_token": refresh_token,
            "token_type": "bearer",
            "expires_in": ACCESS_TOKEN_EXPIRE_MINUTES * 60
        }

    def get_token_payload(self, token: str) -> Optional[Dict[str, Any]]:
        """
        Decode a token for logging/debugging without enforcing expiry.

        The signature is still verified with ``SECRET_KEY``; only the ``exp``
        check is disabled, so expired tokens can still be inspected.

        Args:
            token: JWT token

        Returns:
            Decoded payload, or None if decoding fails.
        """
        try:
            return jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM], options={"verify_exp": False})
        except JWTError:
            return None
# Single shared handler instance, used by the module-level convenience functions.
jwt_handler: JWTHandler = JWTHandler()
def create_access_token(data: Dict[str, Any], expires_delta: Optional[timedelta] = None) -> str:
    """Create an access token using the shared module-level ``jwt_handler``.

    Both arguments are forwarded unchanged to
    ``JWTHandler.create_access_token``.
    """
    return jwt_handler.create_access_token(data, expires_delta=expires_delta)
def verify_access_token(token: str) -> Optional[Dict[str, Any]]:
    """Verify an access token using the shared module-level ``jwt_handler``.

    Delegates to ``JWTHandler.verify_access_token`` with its default
    expected type; returns the decoded payload or None.
    """
    return jwt_handler.verify_access_token(token=token)
def create_token_pair(user_data: Dict[str, Any]) -> Dict[str, str]:
    """Create an access/refresh token pair using the shared ``jwt_handler``.

    Delegates to ``JWTHandler.create_token_pair`` and returns its result.
    """
    return jwt_handler.create_token_pair(user_data=user_data)