PubMed MCP Server

by chrismannina
utils.py (9.22 kB)
""" Utility functions and classes for the PubMed MCP Server. This module contains common utilities including caching, rate limiting, and data formatting functions. """ import asyncio import hashlib import logging import re import time from functools import wraps from typing import Any, Callable, Dict, List, Optional # Type: ignore for cachetools since types-cachetools isn't available from cachetools import TTLCache # type: ignore from dateutil import parser # type: ignore logger = logging.getLogger(__name__) class CacheManager: """Enhanced cache manager with TTL and size limits.""" def __init__(self, max_size: int = 1000, ttl: int = 300) -> None: """ Initialize cache manager. Args: max_size: Maximum number of items to store ttl: Time to live in seconds """ self.cache = TTLCache(maxsize=max_size, ttl=ttl) self.stats = {"hits": 0, "misses": 0, "sets": 0} def get(self, key: str) -> Optional[Any]: """Get item from cache.""" try: value = self.cache.get(key) if value is not None: self.stats["hits"] += 1 logger.debug(f"Cache hit for key: {key}") return value else: self.stats["misses"] += 1 logger.debug(f"Cache miss for key: {key}") return None except Exception as e: logger.error(f"Error getting from cache: {e}") return None def set(self, key: str, value: Any) -> None: """Set item in cache.""" try: self.cache[key] = value self.stats["sets"] += 1 logger.debug(f"Cached item with key: {key}") except Exception as e: logger.error(f"Error setting cache: {e}") def generate_key(self, *args: Any, **kwargs: Any) -> str: """Generate a cache key from prefix and parameters.""" key_parts = [str(arg) for arg in args] for k, v in sorted(kwargs.items()): if v is not None: if isinstance(v, (list, dict)): # Convert complex types to string for consistent hashing v = str(sorted(v.items()) if isinstance(v, dict) else sorted(v)) key_parts.append(f"{k}={v}") key_string = ":".join(key_parts) # Use hash for very long keys to avoid key length issues if len(key_string) > 200: return f"{key_string}:{hashlib.md5(key_string.encode()).hexdigest()}" return key_string def clear(self) -> None: """Clear all cached items.""" self.cache.clear() logger.debug("Cache cleared") def get_stats(self) -> Dict[str, Any]: """Get cache statistics.""" total_requests = self.stats["hits"] + self.stats["misses"] hit_rate = self.stats["hits"] / total_requests if total_requests > 0 else 0 return { "size": len(self.cache), "max_size": self.cache.maxsize, "hits": self.stats["hits"], "misses": self.stats["misses"], "hit_rate": round(hit_rate, 3), "sets": self.stats["sets"], } class RateLimiter: """Simple rate limiter using token bucket algorithm.""" def __init__(self, rate: float = 3.0) -> None: """ Initialize rate limiter. 
Args: rate: Maximum requests per second """ self.rate = rate self.tokens = rate self.last_update = time.time() async def acquire(self) -> None: """Acquire a token (wait if necessary).""" current_time = time.time() elapsed = current_time - self.last_update # Add tokens based on elapsed time self.tokens = min(self.rate, self.tokens + elapsed * self.rate) self.last_update = current_time if self.tokens >= 1: self.tokens -= 1 else: # Calculate wait time for next token wait_time = (1 - self.tokens) / self.rate logger.debug(f"Rate limit reached, waiting {wait_time:.2f}s") await asyncio.sleep(wait_time) self.tokens = 0 def rate_limited(limiter: "RateLimiter") -> Callable[[Callable], Callable]: """Decorator to apply rate limiting to a function.""" def decorator(func: Callable) -> Callable: @wraps(func) async def wrapper(*args, **kwargs): await limiter.acquire() return await func(*args, **kwargs) return wrapper return decorator def format_authors(authors: List[str]) -> str: """Format author list for display.""" if not authors: return "Unknown authors" if len(authors) == 1: return authors[0] elif len(authors) <= 3: return ", ".join(authors[:-1]) + f" and {authors[-1]}" else: return f"{authors[0]} et al." def format_date(date_str: Optional[str]) -> str: """Format publication date for display.""" if not date_str: return "Unknown date" # Handle various date formats from PubMed try: parsed_date = parser.parse(date_str) return parsed_date.strftime("%Y %b %d") except Exception: return date_str def truncate_text(text: str, max_length: int = 300, suffix: str = "...") -> str: """Truncate text to specified length.""" if len(text) <= max_length: return text return text[: max_length - len(suffix)] + suffix def format_mesh_terms(mesh_terms: List[Dict[str, Any]]) -> str: """Format MeSH terms for display.""" if not mesh_terms: return "No MeSH terms" major_terms = [] other_terms = [] for term in mesh_terms: # Handle both dict and MeSHTerm object formats if hasattr(term, "major_topic"): # MeSHTerm object is_major = getattr(term, "major_topic", False) descriptor = getattr(term, "descriptor_name", "") else: # Dictionary format is_major = term.get("major_topic", False) descriptor = term.get("descriptor_name", "") if is_major: major_terms.append(descriptor) else: other_terms.append(descriptor) formatted = [] if major_terms: formatted.append("Major: " + ", ".join(major_terms[:3])) if other_terms: formatted.append("Other: " + ", ".join(other_terms[:5])) return "; ".join(formatted) def build_search_query( base_query: str, authors: Optional[List[str]] = None, journals: Optional[List[str]] = None, mesh_terms: Optional[List[str]] = None, article_types: Optional[List[str]] = None, date_from: Optional[str] = None, date_to: Optional[str] = None, language: Optional[str] = None, has_abstract: Optional[bool] = None, has_full_text: Optional[bool] = None, humans_only: Optional[bool] = None, ) -> str: """Build a complex PubMed search query with filters.""" query_parts = [f"({base_query})"] if authors: author_queries = [f'"{author}"[Author]' for author in authors] query_parts.append(f"({' OR '.join(author_queries)})") if journals: journal_queries = [f'"{journal}"[Journal]' for journal in journals] query_parts.append(f"({' OR '.join(journal_queries)})") if mesh_terms: mesh_queries = [f'"{term}"[MeSH Terms]' for term in mesh_terms] query_parts.append(f"({' OR '.join(mesh_queries)})") if article_types: type_queries = [f'"{article_type}"[Publication Type]' for article_type in article_types] query_parts.append(f"({' OR 
'.join(type_queries)})") if date_from or date_to: date_query = "" if date_from and date_to: date_query = f'("{date_from}"[Date - Publication] : "{date_to}"[Date - Publication])' elif date_from: date_query = f'"{date_from}"[Date - Publication] : "3000"[Date - Publication]' elif date_to: date_query = f'"1800"[Date - Publication] : "{date_to}"[Date - Publication]' if date_query: query_parts.append(date_query) if language: query_parts.append(f'"{language}"[Language]') if has_abstract: query_parts.append("hasabstract[text word]") if has_full_text: query_parts.append("free full text[sb]") if humans_only: query_parts.append("humans[MeSH Terms]") return " AND ".join(query_parts) def extract_pmids_from_text(text: str) -> List[str]: """Extract PMIDs from text using regex.""" pmid_pattern = r"\b\d{8,9}\b" # PMIDs are typically 8-9 digits pmids = re.findall(pmid_pattern, text) return [pmid for pmid in pmids if validate_pmid(pmid)] def validate_pmid(pmid: str) -> bool: """Validate PMID format.""" if not pmid or not pmid.isdigit(): return False return 7 <= len(pmid) <= 9 # PMIDs are typically 7-9 digits def format_file_size(size_bytes: int) -> str: """Format file size in human readable format.""" for unit in ["B", "KB", "MB", "GB"]: if size_bytes < 1024.0: return f"{size_bytes:.1f} {unit}" size_bytes /= 1024.0 return f"{size_bytes:.1f} TB"
