"""
Redis-based token blacklist module for managing revoked tokens.
"""
import redis.asyncio as redis
import json
import logging
from typing import Optional, Set
import os
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
class TokenBlacklist:
    """Redis-based token blacklist manager with in-memory fallback.

    Revoked tokens are stored in Redis under ``blacklist:<token>`` with a TTL.
    Tokens issued to a user can additionally be registered under
    ``user_tokens:<user_id>:<token>`` so that they can be revoked in bulk.

    When Redis cannot be reached, entries are kept in process-local sets
    instead. Those entries never expire and are lost on restart. Because a
    token may have been revoked while Redis was unreachable, lookups and
    removals always consult the in-memory set in addition to Redis.
    """

    # Lifetime (seconds) applied to entries when the caller gives no expiry.
    DEFAULT_TTL_SECONDS = 24 * 60 * 60

    def __init__(self):
        """Read the Redis URL from ``REDIS_URL`` and set up empty fallback stores.

        No connection is opened here; it is established lazily on first use.
        """
        self.redis_client: Optional[redis.Redis] = None
        self.redis_url = os.getenv("REDIS_URL", "redis://localhost:6379")
        # Fallback in-memory storage when Redis is not available
        self.memory_blacklist: Set[str] = set()
        self.memory_user_tokens: dict = {}  # user_id -> set of tokens

    @staticmethod
    async def _close_client(client) -> None:
        """Close a Redis client, preferring ``aclose()`` where it exists."""
        # redis-py >= 5.0.1 deprecates close() in favour of aclose(); older
        # releases only provide close().
        closer = getattr(client, "aclose", None) or client.close
        await closer()

    async def _get_client(self) -> Optional[redis.Redis]:
        """Return the active client, attempting a connection if there is none."""
        if self.redis_client is None:
            await self.connect()
        return self.redis_client

    @staticmethod
    async def _scan_keys(client, pattern: str) -> set:
        """Collect the distinct keys matching ``pattern`` using incremental SCAN."""
        # SCAN may report a key more than once, hence the set.
        return {key async for key in client.scan_iter(match=pattern)}

    async def connect(self) -> Optional[redis.Redis]:
        """Connect to the Redis server and verify the connection with PING.

        Returns:
            The connected client, or None if the connection attempt failed
            (in which case the in-memory fallback is used).
        """
        client = None
        try:
            client = redis.from_url(
                self.redis_url,
                encoding="utf-8",
                decode_responses=True,
                socket_timeout=5,
                socket_connect_timeout=5,
                health_check_interval=30,
            )
            # Test connection
            await client.ping()
        except Exception as e:
            logger.warning(f"Redis connection failed, using in-memory storage: {e}")
            logger.warning("Token blacklist will be non-persistent without Redis")
            self.redis_client = None
            if client is not None:
                # Release the unusable client instead of just dropping it.
                try:
                    await self._close_client(client)
                except Exception:
                    logger.debug("Ignoring error while closing failed Redis client", exc_info=True)
            return None

        # Publish the client only after it has answered the PING.
        self.redis_client = client
        logger.info("Connected to Redis successfully")
        return client

    async def disconnect(self):
        """Disconnect from the Redis server, if connected.

        The stored client is cleared so that a later operation reconnects
        instead of reusing a closed client.
        """
        client, self.redis_client = self.redis_client, None
        if client is not None:
            await self._close_client(client)
            logger.info("Disconnected from Redis")

    async def add_to_blacklist(
        self,
        token: str,
        expires_in: Optional[int] = None,
        user_id: Optional[str] = None,
    ) -> bool:
        """
        Add token to blacklist.

        Args:
            token: JWT token to blacklist
            expires_in: Lifetime of the blacklist entry in seconds
                (defaults to 24 hours if None). Ignored by the in-memory
                fallback, whose entries never expire.
            user_id: Accepted for API compatibility; currently not used.

        Returns:
            True if successful, False otherwise
        """
        client = await self._get_client()
        if expires_in is None:
            expires_in = self.DEFAULT_TTL_SECONDS

        try:
            if client is not None:
                # Use Redis if available
                await client.setex(f"blacklist:{token}", expires_in, "revoked")
                logger.info(f"Token added to Redis blacklist: {token[:20]}...")
            else:
                # Fallback to in-memory storage
                self.memory_blacklist.add(token)
                logger.warning(f"Token added to memory blacklist (non-persistent): {token[:20]}...")
            return True
        except Exception as e:
            logger.error(f"Failed to add token to blacklist: {e}")
            return False

    async def is_in_blacklist(self, token: str) -> bool:
        """
        Check if token is in blacklist.

        The in-memory set is checked first, so tokens revoked while Redis
        was unreachable stay revoked after Redis comes back.

        Args:
            token: JWT token to check

        Returns:
            True if token is blacklisted, False otherwise. If the Redis
            lookup itself fails, the error is logged and False is returned.
        """
        client = await self._get_client()
        try:
            is_blacklisted = token in self.memory_blacklist
            if not is_blacklisted and client is not None:
                result = await client.get(f"blacklist:{token}")
                is_blacklisted = result is not None

            if is_blacklisted:
                logger.warning(f"Blacklisted token attempted: {token[:20]}...")
            else:
                logger.debug(f"Token not in blacklist: {token[:20]}...")
            return is_blacklisted
        except Exception as e:
            # NOTE(review): this is fail-open, kept from the original
            # behaviour; confirm that is the intended policy for callers.
            logger.error(f"Failed to check blacklist status: {e}")
            return False

    async def remove_from_blacklist(self, token: str) -> bool:
        """
        Remove token from blacklist (both the Redis and in-memory stores).

        Args:
            token: JWT token to remove from blacklist

        Returns:
            True if the token was found and removed, False otherwise
        """
        client = await self._get_client()
        try:
            if client is not None:
                # Delete from Redis first: if that fails, nothing is changed.
                deleted = await client.delete(f"blacklist:{token}")
                in_memory = token in self.memory_blacklist
                self.memory_blacklist.discard(token)
                if deleted > 0 or in_memory:
                    logger.info(f"Token removed from Redis blacklist: {token[:20]}...")
                    return True
                logger.warning(f"Token not found in Redis blacklist: {token[:20]}...")
                return False

            # Use memory storage if Redis not available
            if token in self.memory_blacklist:
                self.memory_blacklist.remove(token)
                logger.warning(f"Token removed from memory blacklist: {token[:20]}...")
                return True
            logger.warning(f"Token not found in memory blacklist: {token[:20]}...")
            return False
        except Exception as e:
            logger.error(f"Failed to remove token from blacklist: {e}")
            return False

    async def blacklist_all_user_tokens(self, user_id: str) -> bool:
        """
        Blacklist all tokens registered for a specific user.

        Tokens are gathered from the ``user_tokens:<user_id>:*`` keys in
        Redis (when connected) and from the in-memory per-user registry.
        Each one is blacklisted for 24 hours in Redis, or added to the
        in-memory blacklist when Redis is unavailable.

        Args:
            user_id: User identifier

        Returns:
            True if successful (including when the user has no registered
            tokens), False otherwise
        """
        client = await self._get_client()
        try:
            # Tokens registered while Redis was down only exist in memory.
            tokens = set(self.memory_user_tokens.get(user_id, ()))

            if client is not None:
                prefix = f"user_tokens:{user_id}:"
                for key in await self._scan_keys(client, f"{prefix}*"):
                    # Keys are str because of decode_responses=True, but
                    # tolerate bytes from a differently configured client.
                    if isinstance(key, bytes):
                        key = key.decode("utf-8")
                    if key.startswith(prefix):
                        # Slice the prefix off; str.replace could also alter
                        # the token itself.
                        tokens.add(key[len(prefix):])

                if tokens:
                    pipeline = client.pipeline()
                    for token in tokens:
                        pipeline.setex(f"blacklist:{token}", self.DEFAULT_TTL_SECONDS, "revoked")
                    await pipeline.execute()
                    logger.info(f"Blacklisted {len(tokens)} tokens for user: {user_id}")
                return True

            self.memory_blacklist.update(tokens)
            logger.warning(
                f"Blacklisted {len(tokens)} tokens in memory (non-persistent) for user: {user_id}"
            )
            return True
        except Exception as e:
            logger.error(f"Failed to blacklist user tokens: {e}")
            return False

    async def add_user_token(self, user_id: str, token: str, expires_in: Optional[int] = None) -> bool:
        """
        Add token to user's token list (for batch blacklisting).

        Args:
            user_id: User identifier
            token: JWT token to store
            expires_in: Lifetime of the registration in seconds (defaults to
                24 hours if None). Ignored by the in-memory fallback.

        Returns:
            True if successful, False otherwise
        """
        client = await self._get_client()
        if expires_in is None:
            expires_in = self.DEFAULT_TTL_SECONDS

        try:
            if client is not None:
                # Store token in user's token list using Redis
                await client.setex(f"user_tokens:{user_id}:{token}", expires_in, "active")
                logger.debug(f"Token added to user Redis list: {token[:20]}...")
            else:
                # Store token in memory storage if Redis not available
                self.memory_user_tokens.setdefault(user_id, set()).add(token)
                logger.warning(f"Token added to user memory list (non-persistent): {token[:20]}...")
            return True
        except Exception as e:
            logger.error(f"Failed to add user token: {e}")
            return False

    async def get_blacklist_stats(self) -> Optional[dict]:
        """
        Get blacklist statistics.

        Returns:
            Dictionary with the keys ``blacklisted_tokens``,
            ``active_user_tokens``, ``redis_info`` and ``storage_type``
            (plus ``note`` in memory mode), or None if gathering failed.
            Both counters count tokens in either storage mode.
        """
        client = await self._get_client()
        try:
            if client is not None:
                # SCAN-based counting avoids blocking the server like KEYS.
                blacklist_keys = await self._scan_keys(client, "blacklist:*")
                user_token_keys = await self._scan_keys(client, "user_tokens:*")
                return {
                    "blacklisted_tokens": len(blacklist_keys),
                    "active_user_tokens": len(user_token_keys),
                    "redis_info": await client.info(),
                    "storage_type": "redis",
                }

            # Return memory storage stats if Redis not available
            return {
                "blacklisted_tokens": len(self.memory_blacklist),
                # Count tokens, not users, to match the Redis figure.
                "active_user_tokens": sum(
                    len(tokens) for tokens in self.memory_user_tokens.values()
                ),
                "redis_info": None,
                "storage_type": "memory",
                "note": "Using in-memory storage - data will be lost on restart",
            }
        except Exception as e:
            logger.error(f"Failed to get blacklist stats: {e}")
            return None

    async def cleanup_expired_entries(self) -> int:
        """
        Clean up expired blacklist entries.

        Redis already expires entries written with SETEX on its own, so this
        is only a manual sweep. Entries without any expiry are deliberately
        left in place: deleting them would un-revoke those tokens.

        Returns:
            Number of cleaned entries (always 0 in memory mode or on error)
        """
        client = await self._get_client()
        try:
            if client is None:
                # Manual cleanup not needed for memory storage - tokens are kept until removed explicitly
                logger.debug("Manual cleanup not needed for memory storage")
                return 0

            expired_count = 0
            for key in await self._scan_keys(client, "blacklist:*"):
                ttl = await client.ttl(key)
                # TTL reports -1 for a key with no expiry and -2 for a key
                # that no longer exists; only the latter (or 0) is expired.
                if ttl == -1:
                    continue
                if ttl <= 0:
                    await client.delete(key)
                    expired_count += 1

            if expired_count > 0:
                logger.info(f"Cleaned up {expired_count} expired Redis blacklist entries")
            return expired_count
        except Exception as e:
            logger.error(f"Failed to cleanup expired entries: {e}")
            return 0

    # Compatibility methods for router.py
    async def add_token(self, token: str, user_id: Optional[str] = None) -> bool:
        """
        Compatibility method for router.py - adds token to blacklist.

        Args:
            token: JWT token to blacklist
            user_id: Optional user ID, forwarded to add_to_blacklist

        Returns:
            True if successful, False otherwise
        """
        return await self.add_to_blacklist(token, user_id=user_id)

    async def check_token(self, token: str) -> bool:
        """
        Compatibility method for router.py - checks if token is blacklisted.

        Args:
            token: JWT token to check

        Returns:
            True if token is blacklisted, False otherwise
        """
        return await self.is_in_blacklist(token)

    async def remove_token(self, token: str) -> bool:
        """
        Compatibility method for router.py - removes token from blacklist.

        Args:
            token: JWT token to remove from blacklist

        Returns:
            True if successful, False otherwise
        """
        return await self.remove_from_blacklist(token)
# Global blacklist instance, created at import time and shared by the
# module-level helper functions defined below.
blacklist = TokenBlacklist()
async def add_to_blacklist(token: str, expires_in: int = None, user_id: str = None) -> bool:
    """Add a token to the blacklist via the module-level ``blacklist`` instance.

    Args:
        token: JWT token to blacklist
        expires_in: Expiration in seconds, forwarded unchanged
        user_id: Optional user ID, forwarded unchanged

    Returns:
        The result of ``blacklist.add_to_blacklist``
    """
    outcome = await blacklist.add_to_blacklist(
        token,
        expires_in=expires_in,
        user_id=user_id,
    )
    return outcome
async def is_in_blacklist(token: str) -> bool:
    """Check a token against the module-level ``blacklist`` instance.

    Args:
        token: JWT token to check

    Returns:
        The result of ``blacklist.is_in_blacklist``
    """
    revoked = await blacklist.is_in_blacklist(token=token)
    return revoked
async def remove_from_blacklist(token: str) -> bool:
    """Remove a token from the blacklist via the module-level ``blacklist`` instance.

    Args:
        token: JWT token to remove from the blacklist

    Returns:
        The result of ``blacklist.remove_from_blacklist``
    """
    removed = await blacklist.remove_from_blacklist(token=token)
    return removed