Skip to main content
Glama

Mutation Clinical Trial Matching MCP

by pickleton89
cache_strategies.py (16.1 kB)
""" Cache warming strategies and smart invalidation for distributed cache. """ import asyncio import logging import time from collections.abc import Callable from dataclasses import dataclass from typing import Any from clinicaltrials.async_query import query_clinical_trials_async from utils.distributed_cache import get_cache logger = logging.getLogger(__name__) @dataclass class CacheWarmingStrategy: """Configuration for cache warming strategy.""" name: str mutations: list[str] priority: int = 1 schedule: str = "startup" # startup, periodic, on_demand max_concurrent: int = 5 ttl: int | None = None class CacheWarmer: """ Cache warming system for preloading frequently accessed data. """ def __init__(self): self.cache = get_cache() self.strategies: dict[str, CacheWarmingStrategy] = {} self.warming_stats = { "total_warmed": 0, "successful": 0, "failed": 0, "last_warming_time": None, "warming_duration": 0, } def add_strategy(self, strategy: CacheWarmingStrategy): """Add a cache warming strategy.""" self.strategies[strategy.name] = strategy logger.info(f"Added cache warming strategy: {strategy.name}") def remove_strategy(self, name: str): """Remove a cache warming strategy.""" if name in self.strategies: del self.strategies[name] logger.info(f"Removed cache warming strategy: {name}") async def warm_common_mutations(self) -> int: """ Warm cache with common mutations. Returns: Number of items successfully warmed """ common_mutations = [ "EGFR L858R", "EGFR exon 19 deletion", "KRAS G12C", "BRAF V600E", "ALK EML4", "ROS1 CD74", "MET exon 14 skipping", "NTRK fusion", "RET fusion", "ERBB2 amplification", ] strategy = CacheWarmingStrategy( name="common_mutations", mutations=common_mutations, priority=1, max_concurrent=5, ttl=7200, # 2 hours for common mutations ) return await self.execute_strategy(strategy) async def warm_trending_mutations(self) -> int: """ Warm cache with trending mutations based on recent queries. 
Returns: Number of items successfully warmed """ # In a real implementation, this would analyze recent query patterns # For now, using a predefined list of trending mutations trending_mutations = [ "PIK3CA H1047R", "TP53 R273H", "PTEN loss", "CDKN2A deletion", "FGFR2 fusion", ] strategy = CacheWarmingStrategy( name="trending_mutations", mutations=trending_mutations, priority=2, max_concurrent=3, ttl=3600, # 1 hour for trending mutations ) return await self.execute_strategy(strategy) async def execute_strategy(self, strategy: CacheWarmingStrategy) -> int: """ Execute a cache warming strategy. Args: strategy: The strategy to execute Returns: Number of items successfully warmed """ start_time = time.time() logger.info(f"Starting cache warming strategy: {strategy.name}") # Use semaphore to limit concurrent requests semaphore = asyncio.Semaphore(strategy.max_concurrent) async def warm_mutation(mutation: str) -> bool: async with semaphore: try: # Generate cache key cache_key = f"query:{mutation}:1:10" # Check if already cached cached_result = await self.cache.get_async(cache_key) if cached_result is not None: logger.debug(f"Mutation {mutation} already cached") return True # Query and cache result result = await query_clinical_trials_async(mutation) # Cache the result ttl = strategy.ttl if strategy.ttl else self.cache.default_ttl success = await self.cache.set_async(cache_key, result, ttl) if success: logger.debug(f"Successfully warmed cache for mutation: {mutation}") return True else: logger.warning(f"Failed to cache mutation: {mutation}") return False except Exception as e: logger.error(f"Error warming cache for mutation {mutation}: {e}") return False # Execute warming tasks concurrently tasks = [warm_mutation(mutation) for mutation in strategy.mutations] results = await asyncio.gather(*tasks, return_exceptions=True) # Count successful warmings successful = sum(1 for result in results if result is True) failed = len(results) - successful # Update statistics duration = 
time.time() - start_time self.warming_stats["total_warmed"] += len(strategy.mutations) self.warming_stats["successful"] += successful self.warming_stats["failed"] += failed self.warming_stats["last_warming_time"] = time.time() self.warming_stats["warming_duration"] = duration logger.info( f"Cache warming strategy '{strategy.name}' completed: " f"{successful}/{len(strategy.mutations)} successful in {duration:.2f}s" ) return successful async def warm_all_strategies(self) -> dict[str, int]: """ Execute all registered cache warming strategies. Returns: Dictionary mapping strategy names to success counts """ results = {} # Sort strategies by priority sorted_strategies = sorted(self.strategies.values(), key=lambda s: s.priority) for strategy in sorted_strategies: results[strategy.name] = await self.execute_strategy(strategy) return results def get_warming_stats(self) -> dict[str, Any]: """Get cache warming statistics.""" return self.warming_stats.copy() class SmartInvalidator: """ Smart cache invalidation system that tracks dependencies and invalidates related data. """ def __init__(self): self.cache = get_cache() self.invalidation_rules: dict[str, list[Callable]] = {} self.invalidation_stats = { "total_invalidations": 0, "pattern_invalidations": 0, "dependency_invalidations": 0, "last_invalidation_time": None, } def add_invalidation_rule(self, trigger: str, rule: Callable[[str], list[str]]): """ Add a smart invalidation rule. Args: trigger: The trigger pattern (e.g., "mutation_update") rule: Function that returns list of cache keys to invalidate """ if trigger not in self.invalidation_rules: self.invalidation_rules[trigger] = [] self.invalidation_rules[trigger].append(rule) logger.info(f"Added invalidation rule for trigger: {trigger}") async def invalidate_mutation_data(self, mutation: str) -> int: """ Invalidate all cache entries related to a specific mutation. 
Args: mutation: The mutation to invalidate Returns: Number of cache entries invalidated """ patterns = [f"query:{mutation}:*", f"summary:{mutation}:*", f"batch:*{mutation}*"] total_invalidated = 0 for pattern in patterns: invalidated = await self.cache.invalidate_pattern_async(pattern) total_invalidated += invalidated logger.info(f"Invalidated {invalidated} entries for pattern: {pattern}") self.invalidation_stats["pattern_invalidations"] += total_invalidated self.invalidation_stats["total_invalidations"] += total_invalidated self.invalidation_stats["last_invalidation_time"] = time.time() return total_invalidated async def invalidate_by_age(self, max_age: int) -> int: """ Invalidate cache entries older than max_age seconds. Args: max_age: Maximum age in seconds Returns: Number of cache entries invalidated """ # This would require custom Redis logic or metadata tracking # For now, we'll use a simple pattern-based approach logger.info(f"Age-based invalidation not fully implemented (max_age: {max_age}s)") return 0 async def invalidate_low_hit_entries(self, min_hit_count: int = 2) -> int: """ Invalidate cache entries with low hit counts. Args: min_hit_count: Minimum hit count to keep Returns: Number of cache entries invalidated """ # This would require analyzing hit count metadata # For now, we'll use a simple pattern-based approach logger.info(f"Hit-based invalidation not fully implemented (min_hits: {min_hit_count})") return 0 async def trigger_invalidation(self, trigger: str, context: str) -> int: """ Trigger smart invalidation based on a trigger and context. 
Args: trigger: The trigger type context: Context information for the trigger Returns: Number of cache entries invalidated """ total_invalidated = 0 if trigger in self.invalidation_rules: for rule in self.invalidation_rules[trigger]: try: keys_to_invalidate = rule(context) for key in keys_to_invalidate: success = await self.cache.delete_async(key) if success: total_invalidated += 1 logger.info( f"Invalidated {len(keys_to_invalidate)} entries for trigger: {trigger}" ) except Exception as e: logger.error(f"Error in invalidation rule for trigger {trigger}: {e}") self.invalidation_stats["dependency_invalidations"] += total_invalidated self.invalidation_stats["total_invalidations"] += total_invalidated self.invalidation_stats["last_invalidation_time"] = time.time() return total_invalidated async def invalidate_pattern_async(self, pattern: str) -> int: """ Invalidate cache entries matching a pattern. Args: pattern: Pattern to match cache keys ("*" for all) Returns: Number of entries invalidated """ # Use the cache's built-in pattern invalidation total_invalidated = await self.cache.invalidate_pattern_async(pattern) self.invalidation_stats["pattern_invalidations"] += total_invalidated self.invalidation_stats["total_invalidations"] += total_invalidated self.invalidation_stats["last_invalidation_time"] = time.time() return total_invalidated def get_invalidation_stats(self) -> dict[str, Any]: """Get invalidation statistics.""" return self.invalidation_stats.copy() class CacheAnalytics: """ Cache analytics system for monitoring and optimizing cache performance. """ def __init__(self): self.cache = get_cache() self.warmer = CacheWarmer() self.invalidator = SmartInvalidator() async def get_comprehensive_stats(self) -> dict[str, Any]: """ Get comprehensive cache statistics. 
Returns: Dictionary with all cache analytics """ cache_stats = self.cache.get_stats() warming_stats = self.warmer.get_warming_stats() invalidation_stats = self.invalidator.get_invalidation_stats() return { "cache": cache_stats, "warming": warming_stats, "invalidation": invalidation_stats, "timestamp": time.time(), } async def analyze_cache_efficiency(self) -> dict[str, Any]: """ Analyze cache efficiency and provide recommendations. Returns: Dictionary with efficiency analysis and recommendations """ stats = await self.get_comprehensive_stats() cache_stats = stats["cache"] # Calculate efficiency metrics hit_rate = cache_stats.get("hit_rate", 0) error_rate = cache_stats.get("errors", 0) / max(cache_stats.get("total_requests", 1), 1) # Generate recommendations recommendations = [] if hit_rate < 0.6: recommendations.append("Consider increasing cache TTL or implementing cache warming") if error_rate > 0.05: recommendations.append("High error rate detected, check Redis connectivity") if cache_stats.get("total_requests", 0) < 100: recommendations.append("Low cache usage, consider promoting cache usage") return { "hit_rate": hit_rate, "error_rate": error_rate, "recommendations": recommendations, "efficiency_score": (hit_rate * 100) - (error_rate * 100), } async def generate_cache_report(self) -> str: """ Generate a formatted cache performance report. 
Returns: Formatted report string """ stats = await self.get_comprehensive_stats() efficiency = await self.analyze_cache_efficiency() report = f""" # Cache Performance Report ## Cache Statistics - Hit Rate: {stats["cache"]["hit_rate"]:.2%} - Total Requests: {stats["cache"]["total_requests"]:,} - Cache Hits: {stats["cache"]["hits"]:,} - Cache Misses: {stats["cache"]["misses"]:,} - Cache Sets: {stats["cache"]["sets"]:,} - Errors: {stats["cache"]["errors"]:,} ## Cache Warming - Total Warmed: {stats["warming"]["total_warmed"]:,} - Successful: {stats["warming"]["successful"]:,} - Failed: {stats["warming"]["failed"]:,} - Last Warming: {stats["warming"]["last_warming_time"]} ## Cache Invalidation - Total Invalidations: {stats["invalidation"]["total_invalidations"]:,} - Pattern Invalidations: {stats["invalidation"]["pattern_invalidations"]:,} - Dependency Invalidations: {stats["invalidation"]["dependency_invalidations"]:,} ## Efficiency Analysis - Efficiency Score: {efficiency["efficiency_score"]:.1f} - Error Rate: {efficiency["error_rate"]:.2%} ## Recommendations """ for rec in efficiency["recommendations"]: report += f"- {rec}\n" return report # Global instances _cache_warmer: CacheWarmer | None = None _smart_invalidator: SmartInvalidator | None = None _cache_analytics: CacheAnalytics | None = None def get_cache_warmer() -> CacheWarmer: """Get global cache warmer instance.""" global _cache_warmer if _cache_warmer is None: _cache_warmer = CacheWarmer() # Type narrowing by creating local variable warmer = _cache_warmer assert warmer is not None return warmer def get_smart_invalidator() -> SmartInvalidator: """Get global smart invalidator instance.""" global _smart_invalidator if _smart_invalidator is None: _smart_invalidator = SmartInvalidator() # Type narrowing by creating local variable invalidator = _smart_invalidator assert invalidator is not None return invalidator def get_cache_analytics() -> CacheAnalytics: """Get global cache analytics instance.""" global 
_cache_analytics if _cache_analytics is None: _cache_analytics = CacheAnalytics() # Type narrowing by creating local variable analytics = _cache_analytics assert analytics is not None return analytics

MCP directory API

We provide all the information about MCP servers via our MCP API.

curl -X GET 'https://glama.ai/api/mcp/v1/servers/pickleton89/mutation-clinical-trial-matching-mcp'

If you have feedback or need assistance with the MCP directory API, please join our Discord server.