"""
Database connection management module using asyncpg for PostgreSQL.
"""
import asyncio
import json
import logging
import os
from typing import Optional, Dict, Any

import asyncpg
from asyncpg import Pool
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
class DatabaseManager:
    """Manage a lazily created asyncpg connection pool for PostgreSQL.

    The pool is opened on the first call to ``create_pool`` (made directly or
    via ``execute_query`` / ``execute_many``) and released by ``close_pool``;
    after a close, the next query opens a fresh pool.
    """

    def __init__(
        self,
        database_url: Optional[str] = None,
        *,
        min_size: int = 5,
        max_size: int = 20,
        command_timeout: float = 60,
    ):
        """Store connection settings; no connection is opened here.

        Args:
            database_url: PostgreSQL connection string. When omitted, the
                ``DATABASE_URL`` environment variable is read instead.
            min_size: ``min_size`` passed to ``asyncpg.create_pool``.
            max_size: ``max_size`` passed to ``asyncpg.create_pool``.
            command_timeout: ``command_timeout`` (seconds) passed to
                ``asyncpg.create_pool``.
        """
        self.pool: Optional[Pool] = None
        self.database_url = (
            database_url if database_url is not None else os.getenv("DATABASE_URL")
        )
        self.min_size = min_size
        self.max_size = max_size
        self.command_timeout = command_timeout
        # Created on first use rather than here: an instance may be built at
        # import time, before any event loop runs, and on older Python
        # versions an asyncio.Lock binds to the loop current at construction.
        self._pool_lock: Optional[asyncio.Lock] = None

    def _get_pool_lock(self) -> asyncio.Lock:
        """Return the lock that serializes pool creation/closing."""
        # No await between the check and the assignment, so this cannot
        # interleave with another coroutine on the same loop.
        if self._pool_lock is None:
            self._pool_lock = asyncio.Lock()
        return self._pool_lock

    async def create_pool(self) -> Pool:
        """Return the connection pool, creating it if none is open.

        Creation is serialized with a lock and re-checked inside it. Without
        that, concurrent first callers would each open their own pool and
        leak every pool except the last one assigned to ``self.pool``.

        Returns:
            The open asyncpg pool.

        Raises:
            ValueError: If no database URL is configured.
            Exception: Any error from ``asyncpg.create_pool`` is logged and
                re-raised unchanged.
        """
        if not self.database_url:
            raise ValueError("DATABASE_URL environment variable is not set")
        async with self._get_pool_lock():
            if self.pool is not None:
                # Another coroutine finished creating it while we waited.
                return self.pool
            try:
                self.pool = await asyncpg.create_pool(
                    self.database_url,
                    min_size=self.min_size,
                    max_size=self.max_size,
                    command_timeout=self.command_timeout,
                    server_settings={
                        'application_name': 'auth_service'
                    },
                )
            except Exception as e:
                logger.error("Failed to create database pool: %s", e)
                raise
            logger.info("Database connection pool created successfully")
            return self.pool

    async def close_pool(self):
        """Close the pool, if one is open, and forget it.

        Clearing ``self.pool`` lets a later query open a fresh pool instead
        of trying to acquire from a closed one.
        """
        async with self._get_pool_lock():
            pool = self.pool
            if pool is None:
                return
            await pool.close()
            self.pool = None
            logger.info("Database connection pool closed")

    async def _ensure_pool(self) -> Pool:
        """Return the current pool, creating one first if none is open."""
        # Capture in a local so a concurrent close_pool() resetting
        # self.pool cannot leave us calling .acquire() on None.
        pool = self.pool
        if pool is None:
            pool = await self.create_pool()
        return pool

    async def execute_query(self, query: str, *args) -> Optional[Dict[str, Any]]:
        """Run ``query`` and return its first result row as a dict.

        Uses ``fetchrow``, so statements that produce no rows (an INSERT,
        UPDATE or DELETE without RETURNING) yield None.

        Args:
            query: SQL text with ``$1``-style placeholders.
            *args: Values bound to the placeholders, in order.

        Returns:
            The first row as a dict, or None when there is no row.

        Raises:
            Exception: Any database error is logged and re-raised.
        """
        pool = await self._ensure_pool()
        async with pool.acquire() as connection:
            try:
                row = await connection.fetchrow(query, *args)
            except Exception as e:
                logger.error("Query execution failed: %s", e)
                raise
        return dict(row) if row is not None else None

    async def execute_many(self, query: str, args_list) -> str:
        """Run ``query`` once for each argument sequence in ``args_list``.

        Args:
            query: SQL text with ``$1``-style placeholders.
            args_list: Iterable of argument sequences, one per execution.

        Returns:
            The literal string ``"Success"`` once ``executemany`` completes.

        Raises:
            Exception: Any database error is logged and re-raised.
        """
        pool = await self._ensure_pool()
        async with pool.acquire() as connection:
            try:
                await connection.executemany(query, args_list)
            except Exception as e:
                logger.error("Batch execution failed: %s", e)
                raise
        return "Success"
# Shared DatabaseManager used by the module-level helpers below. Built at
# import time: DATABASE_URL is read now, but no connection is opened until
# the first pool request.
db_manager = DatabaseManager()
async def get_db_pool() -> Pool:
    """Return the shared manager's pool, creating it on first request."""
    if db_manager.pool:
        return db_manager.pool
    await db_manager.create_pool()
    return db_manager.pool
async def fetch_credentials(user_id: str) -> Optional[Dict[str, Any]]:
    """
    Load the stored credential fields for one active user.

    Args:
        user_id: Value compared against ``users.id``.

    Returns:
        A dict of the selected ``users`` columns (id, email,
        encrypted_password, salt, is_active, created_at, updated_at) for the
        matching row with ``is_active = true``, or None when there is no such
        row or the lookup raised.

    NOTE(review): any database error is logged and mapped to None, so callers
    cannot tell an outage apart from "user not found".
    """
    sql = """
SELECT
id,
email,
encrypted_password,
salt,
is_active,
created_at,
updated_at
FROM users
WHERE id = $1 AND is_active = true
"""
    try:
        return await db_manager.execute_query(sql, user_id)
    except Exception as exc:
        logger.error(f"Failed to fetch credentials for user {user_id}: {exc}")
        return None
async def store_user_session(user_id: str, session_data: Dict[str, Any]) -> bool:
    """
    Insert or refresh the session row for a user.

    ``user_sessions`` is upserted on ``user_id``. On insert, ``created_at`` is
    set to now. On both insert and update, ``expires_at`` is set to 24 hours
    from now; on update, ``updated_at`` is also set.

    Args:
        user_id: User identifier; the upsert conflict target.
        session_data: Session information to store. A dict is serialized to a
            JSON string before binding, because asyncpg's default codecs
            encode json/jsonb/text parameters from ``str`` and reject a raw
            dict. A value that is already a ``str`` is bound unchanged.
            NOTE(review): assumes ``session_data`` is a json/jsonb/text
            column — confirm against the schema.

    Returns:
        True if the statement ran, False if serialization or the query
        failed (the error is logged).
    """
    query = """
INSERT INTO user_sessions (user_id, session_data, created_at, expires_at)
VALUES ($1, $2, NOW(), NOW() + INTERVAL '24 hours')
ON CONFLICT (user_id)
DO UPDATE SET
session_data = $2,
updated_at = NOW(),
expires_at = NOW() + INTERVAL '24 hours'
"""
    try:
        # Serialize inside the try so unserializable data reports False
        # like any other storage failure instead of raising.
        payload = (
            session_data
            if isinstance(session_data, str)
            else json.dumps(session_data)
        )
        await db_manager.execute_query(query, user_id, payload)
        return True
    except Exception as e:
        logger.error(f"Failed to store session for user {user_id}: {e}")
        return False
async def cleanup_expired_sessions():
    """Delete all ``user_sessions`` rows whose ``expires_at`` has passed.

    Best-effort: any exception is logged and not re-raised.
    """
    delete_sql = """
DELETE FROM user_sessions
WHERE expires_at < NOW()
"""
    try:
        await db_manager.execute_query(delete_sql)
        logger.info("Expired sessions cleaned up")
    except Exception as exc:
        logger.error(f"Failed to cleanup expired sessions: {exc}")