#!/usr/bin/env python3
"""
Performance Optimization Module for AnyDocs MCP Server
Comprehensive caching strategies, connection pooling, and performance enhancements.
"""
import asyncio
import json
import time
from typing import Any, Dict, Optional, Union, Callable
from functools import wraps, lru_cache
from datetime import datetime, timedelta
import hashlib

try:
    import redis.asyncio as redis
    REDIS_AVAILABLE = True
except ImportError:
    REDIS_AVAILABLE = False
    redis = None

try:
    import aiofiles
    AIOFILES_AVAILABLE = True
except ImportError:
    AIOFILES_AVAILABLE = False
    aiofiles = None

from ..utils.logging import get_logger
logger = get_logger(__name__)

class CacheManager:
    """Comprehensive cache management with multiple backends."""

    def __init__(self, redis_url: Optional[str] = None, default_ttl: int = 3600):
        """Initialize cache manager."""
        self.default_ttl = default_ttl
        self.redis_client = None
        self.memory_cache = {}
        self.cache_stats = {
            'hits': 0,
            'misses': 0,
            'sets': 0,
            'deletes': 0
        }
        # Initialize Redis if available and configured
        if redis_url and REDIS_AVAILABLE:
            try:
                self.redis_client = redis.from_url(redis_url, decode_responses=True)
                logger.info(f"Redis cache initialized: {redis_url}")
            except Exception as e:
                logger.warning(f"Failed to initialize Redis cache: {e}")

    async def get(self, key: str, default: Any = None) -> Any:
        """Get value from cache."""
        try:
            # Try Redis first
            if self.redis_client:
                try:
                    value = await self.redis_client.get(key)
                    if value is not None:
                        self.cache_stats['hits'] += 1
                        return json.loads(value)
                except Exception as e:
                    logger.warning(f"Redis get failed for key {key}: {e}")
            # Fall back to the in-memory cache
            if key in self.memory_cache:
                cache_entry = self.memory_cache[key]
                if cache_entry['expires_at'] > time.time():
                    self.cache_stats['hits'] += 1
                    return cache_entry['value']
                else:
                    # Expired, remove from cache
                    del self.memory_cache[key]
            self.cache_stats['misses'] += 1
            return default
        except Exception as e:
            logger.error(f"Cache get error for key {key}: {e}")
            self.cache_stats['misses'] += 1
            return default

    async def set(self, key: str, value: Any, ttl: Optional[int] = None) -> bool:
        """Set value in cache."""
        try:
            ttl = ttl or self.default_ttl
            # Try Redis first
            if self.redis_client:
                try:
                    serialized_value = json.dumps(value, default=str)
                    await self.redis_client.setex(key, ttl, serialized_value)
                    self.cache_stats['sets'] += 1
                    return True
                except Exception as e:
                    logger.warning(f"Redis set failed for key {key}: {e}")
            # Fall back to the in-memory cache
            self.memory_cache[key] = {
                'value': value,
                'expires_at': time.time() + ttl
            }
            self.cache_stats['sets'] += 1
            # Clean up expired entries periodically
            if len(self.memory_cache) % 100 == 0:
                await self._cleanup_memory_cache()
            return True
        except Exception as e:
            logger.error(f"Cache set error for key {key}: {e}")
            return False

    async def delete(self, key: str) -> bool:
        """Delete value from cache."""
        try:
            deleted = False
            # Delete from Redis
            if self.redis_client:
                try:
                    result = await self.redis_client.delete(key)
                    deleted = deleted or bool(result)
                except Exception as e:
                    logger.warning(f"Redis delete failed for key {key}: {e}")
            # Delete from memory cache
            if key in self.memory_cache:
                del self.memory_cache[key]
                deleted = True
            if deleted:
                self.cache_stats['deletes'] += 1
            return deleted
        except Exception as e:
            logger.error(f"Cache delete error for key {key}: {e}")
            return False

    async def clear(self) -> bool:
        """Clear all cache entries."""
        try:
            # Clear Redis
            if self.redis_client:
                try:
                    await self.redis_client.flushdb()
                except Exception as e:
                    logger.warning(f"Redis clear failed: {e}")
            # Clear memory cache
            self.memory_cache.clear()
            logger.info("Cache cleared successfully")
            return True
        except Exception as e:
            logger.error(f"Cache clear error: {e}")
            return False

    async def _cleanup_memory_cache(self):
        """Clean up expired entries from memory cache."""
        current_time = time.time()
        expired_keys = [
            key for key, entry in self.memory_cache.items()
            if entry['expires_at'] <= current_time
        ]
        for key in expired_keys:
            del self.memory_cache[key]
        if expired_keys:
            logger.debug(f"Cleaned up {len(expired_keys)} expired cache entries")

    def get_stats(self) -> Dict[str, Any]:
        """Get cache statistics."""
        # Hit rate is hits over lookups (hits + misses); sets and deletes
        # are not lookups and would skew the ratio.
        lookups = self.cache_stats['hits'] + self.cache_stats['misses']
        hit_rate = (self.cache_stats['hits'] / lookups * 100) if lookups > 0 else 0
        return {
            **self.cache_stats,
            'hit_rate': round(hit_rate, 2),
            'memory_cache_size': len(self.memory_cache),
            'redis_connected': self.redis_client is not None
        }
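
# Usage sketch (illustrative, not part of the module itself): the manager
# transparently falls back to the in-memory dict when Redis is missing or
# unreachable, so the calls below behave the same in either configuration.
# The Redis URL is an assumed example value.
#
#     cache = CacheManager(redis_url="redis://localhost:6379/0", default_ttl=600)
#     await cache.set("doc:123", {"title": "Intro"})
#     value = await cache.get("doc:123")   # -> {"title": "Intro"}
#     await cache.delete("doc:123")
#     print(cache.get_stats())             # hits, misses, hit_rate, ...
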

class PerformanceOptimizer:
    """Performance optimization utilities."""

    def __init__(self, cache_manager: Optional[CacheManager] = None):
        """Initialize performance optimizer."""
        self.cache_manager = cache_manager or CacheManager()
        self.query_cache = {}
        self.connection_pools = {}

    def cached(self, ttl: int = 3600, key_prefix: str = ""):
        """Decorator for caching function results."""
        def decorator(func: Callable):
            @wraps(func)
            async def async_wrapper(*args, **kwargs):
                # Generate cache key
                cache_key = self._generate_cache_key(func, args, kwargs, key_prefix)
                # Try to get from cache (note: None results are never cached,
                # so they are recomputed on every call)
                cached_result = await self.cache_manager.get(cache_key)
                if cached_result is not None:
                    return cached_result
                # Execute function and cache result
                result = await func(*args, **kwargs)
                await self.cache_manager.set(cache_key, result, ttl)
                return result

            @wraps(func)
            def sync_wrapper(*args, **kwargs):
                # For sync functions, use memory cache only
                cache_key = self._generate_cache_key(func, args, kwargs, key_prefix)
                if cache_key in self.query_cache:
                    cache_entry = self.query_cache[cache_key]
                    if cache_entry['expires_at'] > time.time():
                        return cache_entry['value']
                    else:
                        del self.query_cache[cache_key]
                result = func(*args, **kwargs)
                self.query_cache[cache_key] = {
                    'value': result,
                    'expires_at': time.time() + ttl
                }
                return result

            # Return appropriate wrapper
            if asyncio.iscoroutinefunction(func):
                return async_wrapper
            else:
                return sync_wrapper
        return decorator

    def _generate_cache_key(self, func: Callable, args: tuple, kwargs: dict, prefix: str) -> str:
        """Generate cache key for function call."""
        # Create a hash from function name and arguments (MD5 is used here
        # only for key derivation, not for security)
        func_name = f"{func.__module__}.{func.__name__}"
        args_str = str(args) + str(sorted(kwargs.items()))
        hash_obj = hashlib.md5((func_name + args_str).encode())
        hash_key = hash_obj.hexdigest()
        return f"{prefix}:{func_name}:{hash_key}" if prefix else f"{func_name}:{hash_key}"

    # NOTE: lru_cache on an instance method would cache `self` in every key
    # and keep instances alive; a staticmethod keyed only on the content
    # avoids that leak. Callers can still invoke it via the instance.
    @staticmethod
    @lru_cache(maxsize=1000)
    def compute_document_hash(content: str) -> str:
        """Compute document content hash for caching."""
        return hashlib.sha256(content.encode()).hexdigest()

    async def batch_process(self, items: list, func: Callable, batch_size: int = 100, delay: float = 0.1):
        """Process items in batches to avoid overwhelming the system."""
        results = []
        for i in range(0, len(items), batch_size):
            batch = items[i:i + batch_size]
            # Process batch
            if asyncio.iscoroutinefunction(func):
                batch_results = await asyncio.gather(*[func(item) for item in batch])
            else:
                batch_results = [func(item) for item in batch]
            results.extend(batch_results)
            # Add delay between batches
            if delay > 0 and i + batch_size < len(items):
                await asyncio.sleep(delay)
        return results
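
# Usage sketch for the decorator (illustrative; `load_page` and its caching
# policy are hypothetical). Async functions go through CacheManager, sync
# functions through the in-process query cache:
#
#     optimizer = get_performance_optimizer()
#
#     @optimizer.cached(ttl=300, key_prefix="pages")
#     async def load_page(url: str) -> str:
#         ...  # expensive fetch; the result is cached for 5 minutes
#
#     # Batching: process items 50 at a time, with a short pause between
#     # batches so downstream services are not flooded.
#     # results = await optimizer.batch_process(items, load_page, batch_size=50)
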

class ConnectionPoolManager:
    """Manage database and HTTP connection pools."""

    def __init__(self):
        """Initialize connection pool manager."""
        self.pools = {}
        self.pool_configs = {}

    def configure_pool(self, name: str, **config):
        """Configure a connection pool."""
        self.pool_configs[name] = config

    async def get_pool(self, name: str):
        """Get or create a connection pool."""
        if name not in self.pools:
            config = self.pool_configs.get(name, {})
            if name == 'database':
                # Database connection pool would be handled by SQLAlchemy
                pass
            elif name == 'http':
                # HTTP client pool using aiohttp
                try:
                    import aiohttp
                    connector = aiohttp.TCPConnector(
                        limit=config.get('max_connections', 100),
                        limit_per_host=config.get('max_connections_per_host', 30),
                        ttl_dns_cache=config.get('dns_cache_ttl', 300),
                        use_dns_cache=True,
                    )
                    timeout = aiohttp.ClientTimeout(
                        total=config.get('timeout', 30),
                        connect=config.get('connect_timeout', 10)
                    )
                    session = aiohttp.ClientSession(
                        connector=connector,
                        timeout=timeout
                    )
                    self.pools[name] = session
                except ImportError:
                    logger.warning("aiohttp not available for HTTP connection pooling")
        return self.pools.get(name)

    async def close_all_pools(self):
        """Close all connection pools."""
        for name, pool in self.pools.items():
            try:
                if hasattr(pool, 'close'):
                    await pool.close()
                logger.info(f"Closed connection pool: {name}")
            except Exception as e:
                logger.error(f"Error closing pool {name}: {e}")
        self.pools.clear()
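
# Usage sketch (illustrative): configure once at startup, reuse the shared
# aiohttp session everywhere, and close pools on shutdown. The config keys
# shown are the ones read by get_pool() above.
#
#     pools = get_connection_pool_manager()
#     pools.configure_pool('http', max_connections=50, timeout=15)
#     session = await pools.get_pool('http')      # aiohttp.ClientSession
#     ...
#     await pools.close_all_pools()               # e.g. in a shutdown hook
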

class ResponseCompressor:
    """Compress responses to reduce bandwidth."""

    @staticmethod
    def compress_text(text: str, method: str = 'gzip') -> bytes:
        """Compress text using specified method."""
        import gzip
        import zlib
        if method == 'gzip':
            return gzip.compress(text.encode('utf-8'))
        elif method == 'deflate':
            return zlib.compress(text.encode('utf-8'))
        else:
            return text.encode('utf-8')

    @staticmethod
    def decompress_text(data: bytes, method: str = 'gzip') -> str:
        """Decompress data using specified method."""
        import gzip
        import zlib
        if method == 'gzip':
            return gzip.decompress(data).decode('utf-8')
        elif method == 'deflate':
            return zlib.decompress(data).decode('utf-8')
        else:
            return data.decode('utf-8')
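
# Round-trip sketch: gzip typically shrinks repetitive JSON/text payloads
# substantially; an unknown method falls through to plain UTF-8 bytes.
#
#     raw = json.dumps({"docs": ["lorem ipsum"] * 100})
#     packed = ResponseCompressor.compress_text(raw, method='gzip')
#     assert ResponseCompressor.decompress_text(packed, method='gzip') == raw
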

# Global instances
_cache_manager: Optional[CacheManager] = None
_performance_optimizer: Optional[PerformanceOptimizer] = None
_connection_pool_manager: Optional[ConnectionPoolManager] = None


def get_cache_manager(redis_url: Optional[str] = None) -> CacheManager:
    """Get global cache manager instance (redis_url takes effect only on the first call)."""
    global _cache_manager
    if _cache_manager is None:
        _cache_manager = CacheManager(redis_url)
    return _cache_manager


def get_performance_optimizer() -> PerformanceOptimizer:
    """Get global performance optimizer instance."""
    global _performance_optimizer
    if _performance_optimizer is None:
        _performance_optimizer = PerformanceOptimizer(get_cache_manager())
    return _performance_optimizer


def get_connection_pool_manager() -> ConnectionPoolManager:
    """Get global connection pool manager instance."""
    global _connection_pool_manager
    if _connection_pool_manager is None:
        _connection_pool_manager = ConnectionPoolManager()
    return _connection_pool_manager


# Convenience wrappers around the global optimizer
def cached(ttl: int = 3600, key_prefix: str = ""):
    """Convenience decorator for caching."""
    return get_performance_optimizer().cached(ttl, key_prefix)


async def batch_process(items: list, func: Callable, batch_size: int = 100, delay: float = 0.1):
    """Convenience function for batch processing."""
    return await get_performance_optimizer().batch_process(items, func, batch_size, delay)
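

# Minimal self-check sketch (an addition for illustration, not part of the
# original module): it exercises only the in-memory fallback path, so no
# Redis server is needed. The relative import of get_logger above means this
# module must run in its package context (e.g. `python -m <package>.performance`).
async def _demo() -> None:
    cache = get_cache_manager()  # no redis_url: memory-backed cache only
    await cache.set("example", {"n": 1}, ttl=60)
    assert await cache.get("example") == {"n": 1}
    doubled = await batch_process(list(range(10)), lambda x: x * 2, batch_size=4)
    assert doubled == [x * 2 for x in range(10)]
    logger.info(f"Demo cache stats: {cache.get_stats()}")


if __name__ == "__main__":
    asyncio.run(_demo())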