"""
Redis-based token blacklist module for managing revoked tokens.
"""
import redis.asyncio as redis
import json
import logging
from typing import Optional, Set
import os
logger = logging.getLogger(__name__)
class TokenBlacklist:
    """Redis-backed manager for revoked (blacklisted) tokens.

    Key layout used by this class:

    * ``blacklist:<token>`` -> ``"revoked"``: the token has been revoked.
    * ``user_tokens:<user_id>:<token>`` -> ``"active"``: the token was issued
      to ``user_id``; consumed by :meth:`blacklist_all_user_tokens`.

    Every key this class writes is created with ``SETEX``, so Redis expires
    it on its own. The client is created lazily: each public method calls
    :meth:`connect` first if no client exists yet. If that lazy connect
    fails, the exception from :meth:`connect` propagates to the caller; the
    ``False`` / ``None`` / ``0`` error returns documented on the methods only
    cover failures of the Redis commands themselves.
    """

    # Lifetime (seconds) of entries written without an explicit TTL: 24 h.
    DEFAULT_TTL_SECONDS = 24 * 60 * 60

    # COUNT hint passed to SCAN: roughly how many keys to examine per round trip.
    _SCAN_BATCH = 1000

    def __init__(self):
        # Created lazily by connect(); None means "not connected".
        self.redis_client: Optional[redis.Redis] = None
        self.redis_url = os.getenv("REDIS_URL", "redis://localhost:6379")

    async def connect(self) -> redis.Redis:
        """Create the Redis client from ``self.redis_url`` and verify it with PING.

        Returns:
            The connected client, which is also stored on ``self.redis_client``.

        Raises:
            Whatever ``redis.from_url`` or ``ping()`` raises. On failure
            ``self.redis_client`` is reset to ``None`` so the next call retries.
        """
        try:
            client = redis.from_url(
                self.redis_url,
                encoding="utf-8",
                decode_responses=True,
                socket_timeout=5,
                socket_connect_timeout=5,
                health_check_interval=30,
            )
            await client.ping()
        except Exception as e:
            # Don't keep an unverified client: the other methods only call
            # connect() when self.redis_client is None.
            self.redis_client = None
            logger.error("Failed to connect to Redis: %s", e)
            raise
        self.redis_client = client
        logger.info("Connected to Redis successfully")
        return client

    async def disconnect(self):
        """Close the Redis client, if one exists, and forget it.

        Clearing ``self.redis_client`` makes a later call on this instance
        reconnect via :meth:`connect` instead of reusing a closed client.
        """
        client = self.redis_client
        if client is None:
            return
        self.redis_client = None
        # redis-py 5.0.1+ deprecates close() in favour of aclose(); prefer the
        # new name when it exists and fall back for older versions.
        close = getattr(client, "aclose", None) or client.close
        await close()
        logger.info("Disconnected from Redis")

    async def _get_client(self) -> redis.Redis:
        """Return the active client, connecting first if there is none."""
        if self.redis_client is None:
            await self.connect()
        return self.redis_client

    @classmethod
    async def _scan_keys(cls, client, pattern: str) -> set:
        """Return the set of keys matching the glob *pattern*.

        Uses incremental SCAN instead of KEYS, which blocks the Redis server
        for the whole keyspace walk. SCAN may yield a key more than once,
        hence the set.
        """
        return {
            key async for key in client.scan_iter(match=pattern, count=cls._SCAN_BATCH)
        }

    @staticmethod
    def _escape_glob(text: str) -> str:
        """Backslash-escape Redis glob metacharacters so *text* matches literally."""
        return "".join("\\" + ch if ch in "*?[]\\" else ch for ch in text)

    @staticmethod
    def _as_str(value) -> str:
        """Decode *value* if it is bytes; return it unchanged otherwise."""
        return value.decode() if isinstance(value, bytes) else value

    async def add_to_blacklist(self, token: str, expires_in: Optional[int] = None) -> bool:
        """Add *token* to the blacklist.

        Args:
            token: Token to revoke.
            expires_in: Lifetime of the blacklist entry in seconds. ``None``
                uses ``DEFAULT_TTL_SECONDS`` (24 h). Redis rejects
                non-positive values, which makes this method return False.

        Returns:
            True if the entry was written, False if the Redis command failed.
        """
        client = await self._get_client()
        if expires_in is None:
            expires_in = self.DEFAULT_TTL_SECONDS
        try:
            await client.setex(f"blacklist:{token}", expires_in, "revoked")
        except Exception as e:
            logger.error("Failed to add token to blacklist: %s", e)
            return False
        logger.info("Token added to blacklist: %s...", token[:20])
        return True

    async def is_in_blacklist(self, token: str, *, fail_closed: bool = False) -> bool:
        """Check whether *token* is blacklisted.

        Args:
            token: Token to check.
            fail_closed: Value to return when the Redis lookup fails. The
                default (False) treats the token as *not* revoked. That keeps
                auth available during a Redis outage but lets revoked tokens
                through. Pass True to reject tokens instead.

        Returns:
            True if a ``blacklist:<token>`` entry exists, False if not, and
            *fail_closed* if the lookup raised.
        """
        client = await self._get_client()
        try:
            result = await client.get(f"blacklist:{token}")
        except Exception as e:
            logger.error("Failed to check blacklist status: %s", e)
            return fail_closed
        is_blacklisted = result is not None
        # %-style args: this runs on every request, so skip formatting the
        # debug message when debug logging is disabled.
        if is_blacklisted:
            logger.warning("Blacklisted token attempted: %s...", token[:20])
        else:
            logger.debug("Token not in blacklist: %s...", token[:20])
        return is_blacklisted

    async def remove_from_blacklist(self, token: str) -> bool:
        """Remove *token* from the blacklist.

        Args:
            token: Token to un-revoke.

        Returns:
            True if an entry was deleted. False if there was no entry or the
            Redis command failed.
        """
        client = await self._get_client()
        try:
            deleted = await client.delete(f"blacklist:{token}")
        except Exception as e:
            logger.error("Failed to remove token from blacklist: %s", e)
            return False
        if deleted > 0:
            logger.info("Token removed from blacklist: %s...", token[:20])
            return True
        logger.warning("Token not found in blacklist: %s...", token[:20])
        return False

    async def blacklist_all_user_tokens(self, user_id: str) -> bool:
        """Blacklist every token registered for *user_id* via :meth:`add_user_token`.

        Each ``user_tokens:<user_id>:<token>`` key produces a
        ``blacklist:<token>`` entry. The entry's TTL is the larger of the
        tracking key's remaining TTL and ``DEFAULT_TTL_SECONDS``. Tokens
        registered with a lifetime above 24 h therefore stay revoked for that
        whole lifetime.

        Args:
            user_id: User identifier. Glob metacharacters in it are escaped,
                so e.g. ``"*"`` cannot match other users' tokens.

        Returns:
            True on success, including when the user has no tracked tokens.
            False if a Redis command failed.
        """
        client = await self._get_client()
        prefix = f"user_tokens:{user_id}:"
        # NOTE(review): with this key layout a user_id containing ':' is
        # ambiguous ("alice" also matches keys of user "alice:x"); confirm
        # user ids never contain ':'.
        try:
            keys = list(await self._scan_keys(client, self._escape_glob(prefix) + "*"))
            if keys:
                ttl_pipe = client.pipeline()
                for key in keys:
                    ttl_pipe.ttl(key)
                ttls = await ttl_pipe.execute()

                pipeline = client.pipeline()
                for key, ttl in zip(keys, ttls):
                    # Keys are str with decode_responses=True; slice off only
                    # the leading prefix (str.replace would strip every copy).
                    token = self._as_str(key)[len(prefix):]
                    # TTL is -1 (no expiry) or -2 (key already gone) in edge
                    # cases; max() then falls back to the default lifetime.
                    pipeline.setex(
                        f"blacklist:{token}",
                        max(ttl, self.DEFAULT_TTL_SECONDS),
                        "revoked",
                    )
                await pipeline.execute()
                logger.info("Blacklisted %d tokens for user: %s", len(keys), user_id)
            return True
        except Exception as e:
            logger.error("Failed to blacklist user tokens: %s", e)
            return False

    async def add_user_token(self, user_id: str, token: str, expires_in: Optional[int] = None) -> bool:
        """Register *token* as belonging to *user_id*, for later batch blacklisting.

        Args:
            user_id: User identifier.
            token: Token to track.
            expires_in: Lifetime of the tracking entry in seconds. ``None``
                uses ``DEFAULT_TTL_SECONDS`` (24 h).

        Returns:
            True if the entry was written, False if the Redis command failed.
        """
        client = await self._get_client()
        if expires_in is None:
            expires_in = self.DEFAULT_TTL_SECONDS
        try:
            await client.setex(f"user_tokens:{user_id}:{token}", expires_in, "active")
        except Exception as e:
            logger.error("Failed to add user token: %s", e)
            return False
        logger.debug("Token added to user list: %s...", token[:20])
        return True

    async def get_blacklist_stats(self) -> Optional[dict]:
        """Collect key counts and server info.

        Returns:
            A dict with ``blacklisted_tokens`` (number of ``blacklist:*``
            keys), ``active_user_tokens`` (number of ``user_tokens:*`` keys)
            and ``redis_info`` (the result of ``INFO``). Returns None if a
            Redis command failed.
        """
        client = await self._get_client()
        try:
            blacklist_keys = await self._scan_keys(client, "blacklist:*")
            user_token_keys = await self._scan_keys(client, "user_tokens:*")
            return {
                "blacklisted_tokens": len(blacklist_keys),
                "active_user_tokens": len(user_token_keys),
                "redis_info": await client.info(),
            }
        except Exception as e:
            logger.error("Failed to get blacklist stats: %s", e)
            return None

    async def cleanup_expired_entries(self) -> int:
        """Delete blacklist entries whose TTL reports them as expired.

        Redis already removes ``SETEX`` keys when their TTL runs out, so this
        is normally a no-op and exists for manual cleanup. Keys with no expiry
        at all (TTL -1) are left alone. They are permanent revocations, and
        deleting them would un-revoke the token.

        Returns:
            Number of entries counted as expired (and deleted). Returns 0 if
            a Redis command failed.
        """
        client = await self._get_client()
        try:
            expired_count = 0
            for key in await self._scan_keys(client, "blacklist:*"):
                ttl = await client.ttl(key)
                # TTL: -2 = key already gone, 0 = expires in under a second,
                # -1 = no expiry (permanent) -> must not be deleted.
                if ttl in (0, -2):
                    await client.delete(key)
                    expired_count += 1
            if expired_count > 0:
                logger.info("Cleaned up %d expired blacklist entries", expired_count)
            return expired_count
        except Exception as e:
            logger.error("Failed to cleanup expired entries: %s", e)
            return 0
# Process-wide shared TokenBlacklist used by the module-level helper
# functions below. Constructing it performs no I/O; its methods connect to
# Redis on first use when no client exists yet.
blacklist: TokenBlacklist = TokenBlacklist()
async def add_to_blacklist(token: str, expires_in: Optional[int] = None) -> bool:
    """Revoke *token* using the shared module-level ``blacklist`` instance.

    Args:
        token: Token to revoke.
        expires_in: Lifetime of the blacklist entry in seconds. ``None`` lets
            ``TokenBlacklist.add_to_blacklist`` apply its default.

    Returns:
        The result of ``TokenBlacklist.add_to_blacklist``.
    """
    return await blacklist.add_to_blacklist(token, expires_in)
async def is_in_blacklist(token: str) -> bool:
    """Check *token* against the shared module-level ``blacklist`` instance.

    Delegates to ``TokenBlacklist.is_in_blacklist`` and returns its answer.
    """
    check = blacklist.is_in_blacklist
    return await check(token)
async def remove_from_blacklist(token: str) -> bool:
    """Un-revoke *token* in the shared module-level ``blacklist`` instance.

    Returns whatever ``TokenBlacklist.remove_from_blacklist`` reports.
    """
    removed = await blacklist.remove_from_blacklist(token)
    return removed