Skip to main content
Glama
security-error-handling-architecture.md97.2 kB
# Security & Error Handling Architecture ## Reddit MCP Server - Production-Grade Design **Version**: 1.0 **Last Updated**: 2025-11-05 **Status**: Architecture Specification --- ## Table of Contents 1. [Security Architecture](#1-security-architecture) 2. [Error Handling Architecture](#2-error-handling-architecture) 3. [Rate Limit Management](#3-rate-limit-management) 4. [Monitoring & Observability](#4-monitoring--observability) 5. [Fault Tolerance & Recovery](#5-fault-tolerance--recovery) 6. [Implementation Checklist](#6-implementation-checklist) --- ## 1. Security Architecture ### 1.1 Authentication & Authorization #### OAuth2 Flow for Reddit API **Implementation Strategy:** ```python # src/auth/reddit_oauth.py from typing import Optional from datetime import datetime, timedelta import httpx from pydantic import BaseModel, SecretStr class RedditOAuthConfig(BaseModel): """Secure configuration for Reddit OAuth.""" client_id: str client_secret: SecretStr # Never logged or printed redirect_uri: str user_agent: str class OAuthToken(BaseModel): """OAuth token with metadata.""" access_token: SecretStr refresh_token: Optional[SecretStr] = None expires_at: datetime token_type: str = "bearer" scope: str class RedditAuthManager: """Handles Reddit OAuth2 authentication lifecycle.""" def __init__(self, config: RedditOAuthConfig, kv_store): self.config = config self.kv_store = kv_store # Apify KV Store for encrypted storage self._current_token: Optional[OAuthToken] = None async def get_authorization_url(self, state: str) -> str: """Generate OAuth2 authorization URL for user consent.""" params = { "client_id": self.config.client_id, "response_type": "code", "state": state, # CSRF protection "redirect_uri": self.config.redirect_uri, "duration": "permanent", # Get refresh token "scope": "read identity mysubreddits" } return f"https://www.reddit.com/api/v1/authorize?{urlencode(params)}" async def exchange_code_for_token(self, code: str) -> OAuthToken: """Exchange authorization code for access token.""" async with httpx.AsyncClient() as client: response = await client.post( "https://www.reddit.com/api/v1/access_token", auth=(self.config.client_id, self.config.client_secret.get_secret_value()), data={ "grant_type": "authorization_code", "code": code, "redirect_uri": self.config.redirect_uri }, headers={"User-Agent": self.config.user_agent} ) response.raise_for_status() data = response.json() token = OAuthToken( access_token=SecretStr(data["access_token"]), refresh_token=SecretStr(data.get("refresh_token")), expires_at=datetime.utcnow() + timedelta(seconds=data["expires_in"]), scope=data["scope"] ) # Store encrypted in Apify KV Store await self._store_token(token) return token async def get_valid_token(self) -> str: """Get valid access token, refreshing if needed.""" if not self._current_token: self._current_token = await self._load_token() # Check if token expires in next 5 minutes if datetime.utcnow() >= (self._current_token.expires_at - timedelta(minutes=5)): self._current_token = await self._refresh_token() return self._current_token.access_token.get_secret_value() async def _refresh_token(self) -> OAuthToken: """Refresh expired OAuth token.""" if not self._current_token or not self._current_token.refresh_token: raise AuthenticationError("No refresh token available") async with httpx.AsyncClient() as client: response = await client.post( "https://www.reddit.com/api/v1/access_token", auth=(self.config.client_id, self.config.client_secret.get_secret_value()), data={ "grant_type": "refresh_token", "refresh_token": self._current_token.refresh_token.get_secret_value() }, headers={"User-Agent": self.config.user_agent} ) response.raise_for_status() data = response.json() new_token = OAuthToken( access_token=SecretStr(data["access_token"]), refresh_token=self._current_token.refresh_token, # Reuse existing expires_at=datetime.utcnow() + timedelta(seconds=data["expires_in"]), scope=data["scope"] ) await self._store_token(new_token) return new_token async def _store_token(self, token: OAuthToken) -> None: """Store token encrypted in Apify KV Store.""" # Apify KV Store handles encryption at rest await self.kv_store.set_value( "reddit_oauth_token", { "access_token": token.access_token.get_secret_value(), "refresh_token": token.refresh_token.get_secret_value() if token.refresh_token else None, "expires_at": token.expires_at.isoformat(), "scope": token.scope } ) async def _load_token(self) -> OAuthToken: """Load token from KV Store.""" data = await self.kv_store.get_value("reddit_oauth_token") if not data: raise AuthenticationError("No stored token found") return OAuthToken( access_token=SecretStr(data["access_token"]), refresh_token=SecretStr(data["refresh_token"]) if data.get("refresh_token") else None, expires_at=datetime.fromisoformat(data["expires_at"]), scope=data["scope"] ) ``` #### MCP Client Authentication **Bearer Token Authentication for MCP Clients:** ```python # src/auth/mcp_auth.py from typing import Optional import hashlib import secrets from fastapi import HTTPException, Security from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials security = HTTPBearer() class MCPAuthManager: """Authenticates MCP clients connecting to the server.""" def __init__(self, kv_store): self.kv_store = kv_store self.valid_tokens: dict[str, dict] = {} # In-memory cache async def generate_api_key(self, user_id: str, name: str) -> str: """Generate a new API key for a user.""" # Generate cryptographically secure random token token = secrets.token_urlsafe(32) token_hash = hashlib.sha256(token.encode()).hexdigest() # Store hash (never store plaintext!) await self.kv_store.set_value( f"api_key:{token_hash}", { "user_id": user_id, "name": name, "created_at": datetime.utcnow().isoformat(), "last_used": None, "rate_limit_tier": "free" # or "paid" } ) return token # Return once, never stored async def validate_token( self, credentials: HTTPAuthorizationCredentials = Security(security) ) -> dict: """Validate bearer token and return user info.""" token = credentials.credentials token_hash = hashlib.sha256(token.encode()).hexdigest() # Check cache first if token_hash in self.valid_tokens: return self.valid_tokens[token_hash] # Load from KV Store token_data = await self.kv_store.get_value(f"api_key:{token_hash}") if not token_data: raise HTTPException( status_code=401, detail="Invalid API key", headers={"WWW-Authenticate": "Bearer"} ) # Update last used timestamp token_data["last_used"] = datetime.utcnow().isoformat() await self.kv_store.set_value(f"api_key:{token_hash}", token_data) # Cache for 5 minutes self.valid_tokens[token_hash] = token_data return token_data ``` #### Rate Limiting Per User **Per-User Rate Limiting:** ```python # src/auth/rate_limiter.py from datetime import datetime, timedelta from collections import defaultdict from typing import Optional import asyncio class UserRateLimiter: """Rate limiter with per-user quotas.""" def __init__(self, kv_store): self.kv_store = kv_store self.quotas = { "free": {"requests_per_minute": 10, "requests_per_hour": 300}, "paid": {"requests_per_minute": 50, "requests_per_hour": 2000} } async def check_rate_limit(self, user_id: str, tier: str) -> dict: """Check if user has exceeded rate limit.""" now = datetime.utcnow() key = f"rate_limit:{user_id}" # Get current usage usage = await self.kv_store.get_value(key) or { "minute_requests": [], "hour_requests": [] } # Clean old entries minute_ago = now - timedelta(minutes=1) hour_ago = now - timedelta(hours=1) usage["minute_requests"] = [ req for req in usage["minute_requests"] if datetime.fromisoformat(req) > minute_ago ] usage["hour_requests"] = [ req for req in usage["hour_requests"] if datetime.fromisoformat(req) > hour_ago ] # Check limits quota = self.quotas[tier] minute_count = len(usage["minute_requests"]) hour_count = len(usage["hour_requests"]) if minute_count >= quota["requests_per_minute"]: return { "allowed": False, "reason": "Rate limit exceeded (per minute)", "retry_after_seconds": 60, "quota_remaining": 0, "quota_limit": quota["requests_per_minute"] } if hour_count >= quota["requests_per_hour"]: return { "allowed": False, "reason": "Rate limit exceeded (per hour)", "retry_after_seconds": 3600, "quota_remaining": 0, "quota_limit": quota["requests_per_hour"] } # Record this request usage["minute_requests"].append(now.isoformat()) usage["hour_requests"].append(now.isoformat()) await self.kv_store.set_value(key, usage, expiration_ttl=3600) return { "allowed": True, "quota_remaining": quota["requests_per_minute"] - minute_count - 1, "quota_limit": quota["requests_per_minute"] } ``` --- ### 1.2 Secrets Management #### Apify Console Configuration **Environment Variable Strategy:** ```yaml # .actor/actor.json { "name": "reddit-mcp-server", "version": "1.0.0", "environmentVariables": { "REDDIT_CLIENT_ID": { "type": "string", "description": "Reddit OAuth2 Client ID", "isSecret": true }, "REDDIT_CLIENT_SECRET": { "type": "string", "description": "Reddit OAuth2 Client Secret", "isSecret": true }, "MCP_MASTER_KEY": { "type": "string", "description": "Master key for encrypting API keys", "isSecret": true }, "REDIS_URL": { "type": "string", "description": "Redis connection URL with credentials", "isSecret": true }, "SENTRY_DSN": { "type": "string", "description": "Sentry error tracking DSN", "isSecret": true, "optional": true } } } ``` **Configuration Loader:** ```python # src/config/settings.py from pydantic import BaseSettings, SecretStr, validator from typing import Optional class Settings(BaseSettings): """Application settings with secure defaults.""" # Reddit OAuth reddit_client_id: str reddit_client_secret: SecretStr reddit_redirect_uri: str = "http://localhost:3000/auth/callback" reddit_user_agent: str = "RedditMCPServer/1.0" # MCP Server mcp_master_key: SecretStr mcp_server_host: str = "0.0.0.0" mcp_server_port: int = 8080 # Redis Cache redis_url: SecretStr redis_key_prefix: str = "reddit_mcp:" # Monitoring sentry_dsn: Optional[SecretStr] = None log_level: str = "INFO" environment: str = "production" # Security cors_origins: list[str] = ["http://localhost:3000"] max_request_size: int = 1_000_000 # 1MB @validator("environment") def validate_environment(cls, v): allowed = ["development", "staging", "production"] if v not in allowed: raise ValueError(f"Environment must be one of {allowed}") return v class Config: env_file = ".env" env_file_encoding = "utf-8" # Prevent secrets from appearing in repr() json_encoders = { SecretStr: lambda v: "***REDACTED***" if v else None } # Global settings instance settings = Settings() # Validate on startup def validate_secrets(): """Ensure all required secrets are present.""" required = [ "reddit_client_id", "reddit_client_secret", "mcp_master_key", "redis_url" ] missing = [key for key in required if not getattr(settings, key)] if missing: raise ValueError(f"Missing required secrets: {missing}") ``` #### Secret Rotation Strategy **Monthly Secret Rotation Plan:** ```python # src/auth/rotation.py from datetime import datetime, timedelta from typing import Optional class SecretRotationManager: """Manages secret rotation lifecycle.""" def __init__(self, kv_store): self.kv_store = kv_store self.rotation_period = timedelta(days=90) # 90 days async def check_rotation_needed(self, secret_name: str) -> bool: """Check if secret needs rotation.""" metadata = await self.kv_store.get_value(f"secret_metadata:{secret_name}") if not metadata: return False last_rotated = datetime.fromisoformat(metadata["last_rotated"]) return datetime.utcnow() - last_rotated > self.rotation_period async def rotate_api_keys(self) -> dict: """Rotate all API keys gracefully.""" # Step 1: Generate new keys new_keys = await self._generate_new_keys() # Step 2: Dual-run period (both old and new valid for 7 days) await self._enable_dual_mode(new_keys) # Step 3: Notify users await self._notify_users_of_rotation() # Step 4: Schedule old key deactivation deactivation_date = datetime.utcnow() + timedelta(days=7) await self._schedule_deactivation(deactivation_date) return { "status": "rotation_initiated", "dual_mode_until": deactivation_date.isoformat(), "affected_users": len(new_keys) } # Automated rotation check (runs daily) async def rotation_check_task(): """Background task to check for needed rotations.""" manager = SecretRotationManager(kv_store) secrets_to_check = [ "reddit_oauth_token", "mcp_master_key", "redis_password" ] for secret in secrets_to_check: if await manager.check_rotation_needed(secret): logger.warning( f"Secret rotation needed: {secret}", extra={"secret_name": secret} ) # Trigger alert in monitoring system ``` --- ### 1.3 Data Privacy & Compliance #### GDPR Compliance **Data Subject Rights Implementation:** ```python # src/privacy/gdpr.py from typing import Optional from datetime import datetime, timedelta class GDPRManager: """Handles GDPR compliance operations.""" def __init__(self, kv_store, datasets): self.kv_store = kv_store self.datasets = datasets async def export_user_data(self, user_id: str) -> dict: """Export all data for a user (Right to Data Portability).""" return { "user_id": user_id, "exported_at": datetime.utcnow().isoformat(), "data": { "profile": await self._get_user_profile(user_id), "api_keys": await self._get_api_keys(user_id), "request_history": await self._get_request_history(user_id), "cached_data": await self._get_cached_data(user_id) } } async def delete_user_data(self, user_id: str) -> dict: """Delete all user data (Right to Erasure).""" deleted = { "profile": await self._delete_user_profile(user_id), "api_keys": await self._revoke_all_keys(user_id), "request_history": await self._delete_request_history(user_id), "cached_data": await self._delete_cached_data(user_id) } # Log deletion for audit trail (keep for 1 year) await self.kv_store.set_value( f"gdpr_deletion:{user_id}", { "deleted_at": datetime.utcnow().isoformat(), "items_deleted": deleted }, expiration_ttl=31536000 # 1 year ) return deleted async def anonymize_logs(self, user_id: str) -> int: """Anonymize user data in logs (Right to Erasure).""" # Replace user_id with anonymous identifier in logs anon_id = hashlib.sha256(f"{user_id}:salt".encode()).hexdigest()[:16] # This would integrate with your logging system # Example: Update Elasticsearch/CloudWatch logs count = await self._update_log_entries(user_id, anon_id) return count ``` #### Data Retention Policies **Automatic Data Cleanup:** ```python # src/privacy/retention.py from datetime import datetime, timedelta class RetentionPolicy: """Enforces data retention policies.""" RETENTION_PERIODS = { "request_logs": timedelta(days=30), "error_logs": timedelta(days=90), "cached_reddit_data": timedelta(hours=24), "user_activity": timedelta(days=365), "audit_logs": timedelta(days=2555), # 7 years (compliance) } def __init__(self, kv_store, datasets): self.kv_store = kv_store self.datasets = datasets async def cleanup_expired_data(self) -> dict: """Remove data past retention period.""" now = datetime.utcnow() results = {} for data_type, retention_period in self.RETENTION_PERIODS.items(): cutoff_date = now - retention_period deleted = await self._delete_before_date(data_type, cutoff_date) results[data_type] = deleted return results async def _delete_before_date(self, data_type: str, cutoff: datetime) -> int: """Delete records older than cutoff date.""" # Implementation depends on storage backend if data_type == "cached_reddit_data": # Redis handles this with TTL automatically return 0 # For Apify datasets query = f"created_at < '{cutoff.isoformat()}'" deleted_count = await self.datasets.delete_items(query) return deleted_count # Daily cleanup task async def retention_cleanup_task(): """Background task for data cleanup.""" policy = RetentionPolicy(kv_store, datasets) results = await policy.cleanup_expired_data() logger.info( "Data retention cleanup completed", extra={"deleted_counts": results} ) ``` #### Logging Restrictions (No PII) **PII-Safe Logging:** ```python # src/logging/safe_logger.py import logging import re from typing import Any import json class PIISafeLogger: """Logger that automatically redacts PII.""" PII_PATTERNS = { "email": re.compile(r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b'), "reddit_username": re.compile(r'(?:^|/)u/([A-Za-z0-9_-]+)'), "api_key": re.compile(r'Bearer\s+([A-Za-z0-9_-]{20,})'), "token": re.compile(r'token["\']?\s*[:=]\s*["\']?([A-Za-z0-9_-]{20,})'), "password": re.compile(r'password["\']?\s*[:=]\s*["\']?([^\s"\']+)'), } def __init__(self, name: str): self.logger = logging.getLogger(name) def _redact_pii(self, message: str) -> str: """Redact PII from log messages.""" for pii_type, pattern in self.PII_PATTERNS.items(): message = pattern.sub(f"[REDACTED:{pii_type.upper()}]", message) return message def _redact_dict(self, data: dict) -> dict: """Recursively redact PII from dictionaries.""" redacted = {} sensitive_keys = {"password", "token", "api_key", "secret", "authorization"} for key, value in data.items(): if any(sensitive in key.lower() for sensitive in sensitive_keys): redacted[key] = "***REDACTED***" elif isinstance(value, dict): redacted[key] = self._redact_dict(value) elif isinstance(value, str): redacted[key] = self._redact_pii(value) else: redacted[key] = value return redacted def info(self, message: str, extra: dict = None): """Log info with PII redaction.""" safe_message = self._redact_pii(message) safe_extra = self._redact_dict(extra) if extra else None self.logger.info(safe_message, extra=safe_extra) def error(self, message: str, exc_info: bool = False, extra: dict = None): """Log error with PII redaction.""" safe_message = self._redact_pii(message) safe_extra = self._redact_dict(extra) if extra else None self.logger.error(safe_message, exc_info=exc_info, extra=safe_extra) # Usage logger = PIISafeLogger(__name__) logger.info( "User request processed", extra={ "user_id": user_id, # Hash or UUID (not username) "tool": "search_reddit", "query": "AI trends", # Safe to log "api_key": api_key # Will be redacted } ) ``` --- ### 1.4 Input Validation #### Pydantic Models for All Inputs **Comprehensive Validation Models:** ```python # src/validation/models.py from pydantic import BaseModel, Field, validator, constr from typing import Optional, Literal from datetime import datetime class SearchRedditInput(BaseModel): """Validated input for search_reddit tool.""" query: constr(min_length=1, max_length=500) = Field( ..., description="Search query or keywords", example="artificial intelligence" ) subreddit: Optional[constr(regex=r'^[A-Za-z0-9_]+$', max_length=21)] = Field( None, description="Limit search to specific subreddit", example="technology" ) time_filter: Literal["hour", "day", "week", "month", "year", "all"] = Field( "week", description="Time range for search results" ) sort: Literal["relevance", "hot", "top", "new", "comments"] = Field( "relevance", description="Sort order for results" ) limit: int = Field( 25, ge=1, le=100, description="Maximum number of results" ) @validator("query") def sanitize_query(cls, v): """Sanitize query for SQL injection and XSS.""" # Remove potentially dangerous characters dangerous = ["<", ">", "'", '"', ";", "--", "/*", "*/"] for char in dangerous: if char in v: raise ValueError(f"Query contains forbidden character: {char}") # Trim whitespace v = v.strip() if not v: raise ValueError("Query cannot be empty after sanitization") return v @validator("subreddit") def validate_subreddit(cls, v): """Validate subreddit name format.""" if v is None: return v # Reddit subreddit rules if len(v) < 3: raise ValueError("Subreddit name must be at least 3 characters") if v.startswith("_") or v.endswith("_"): raise ValueError("Subreddit name cannot start or end with underscore") return v class GetPostCommentsInput(BaseModel): """Validated input for get_post_comments tool.""" post_id: constr(regex=r'^(t3_)?[a-z0-9]{5,7}$') = Field( ..., description="Reddit post ID (with or without t3_ prefix)", example="t3_abc123" ) sort: Literal["best", "top", "new", "controversial", "old"] = Field( "best", description="Comment sort order" ) max_depth: int = Field( 0, ge=0, le=10, description="Maximum comment nesting depth (0 = unlimited)" ) @validator("post_id") def normalize_post_id(cls, v): """Ensure post_id has t3_ prefix.""" if not v.startswith("t3_"): return f"t3_{v}" return v class WatchKeywordsInput(BaseModel): """Validated input for watch_keywords tool.""" keywords: list[constr(min_length=1, max_length=100)] = Field( ..., description="Keywords to monitor", min_items=1, max_items=10 ) subreddits: Optional[list[constr(regex=r'^[A-Za-z0-9_]+$')]] = Field( None, description="Subreddits to monitor (all if empty)", max_items=20 ) alert_threshold: Literal["any", "high_engagement", "viral"] = Field( "high_engagement", description="When to trigger alerts" ) check_interval: int = Field( 15, ge=5, le=1440, description="Minutes between checks" ) @validator("keywords") def validate_keywords(cls, v): """Validate and sanitize keywords.""" sanitized = [] for keyword in v: # Remove special regex characters clean = re.sub(r'[^\w\s-]', '', keyword) if clean: sanitized.append(clean.lower()) if not sanitized: raise ValueError("No valid keywords after sanitization") # Remove duplicates return list(set(sanitized)) # Validation middleware async def validate_input(tool_name: str, raw_input: dict) -> BaseModel: """Validate input based on tool name.""" validation_map = { "search_reddit": SearchRedditInput, "get_post_comments": GetPostCommentsInput, "watch_keywords": WatchKeywordsInput, # ... add all tools } validator_class = validation_map.get(tool_name) if not validator_class: raise ValueError(f"No validator for tool: {tool_name}") try: return validator_class(**raw_input) except ValidationError as e: # Transform Pydantic error to MCP error format raise MCPValidationError( code=-32602, message="Invalid parameters", data={"errors": e.errors()} ) ``` #### Length Limits & Pattern Validation **Security Limits Configuration:** ```python # src/validation/limits.py from dataclasses import dataclass @dataclass class SecurityLimits: """Security limits for inputs.""" # String lengths MAX_QUERY_LENGTH: int = 500 MAX_SUBREDDIT_NAME: int = 21 # Reddit's limit MAX_USERNAME_LENGTH: int = 20 # Reddit's limit MAX_POST_ID_LENGTH: int = 10 MAX_KEYWORDS_PER_WATCH: int = 10 MAX_SUBREDDITS_PER_WATCH: int = 20 # Array sizes MAX_SEARCH_RESULTS: int = 100 MAX_COMMENTS_DEPTH: int = 10 MAX_BATCH_REQUESTS: int = 5 # Rate limits MAX_REQUESTS_PER_MINUTE: int = 10 # Free tier MAX_REQUESTS_PER_HOUR: int = 300 # Payload sizes MAX_REQUEST_BODY_SIZE: int = 1_000_000 # 1MB MAX_CACHE_VALUE_SIZE: int = 10_000_000 # 10MB # Timeouts REQUEST_TIMEOUT_SECONDS: int = 30 CACHE_TTL_MAX: int = 86400 # 24 hours LIMITS = SecurityLimits() # Validation decorator def enforce_limits(func): """Decorator to enforce security limits.""" @wraps(func) async def wrapper(*args, **kwargs): # Check request size request_size = len(json.dumps(kwargs).encode()) if request_size > LIMITS.MAX_REQUEST_BODY_SIZE: raise RequestTooLargeError( f"Request body exceeds {LIMITS.MAX_REQUEST_BODY_SIZE} bytes" ) return await func(*args, **kwargs) return wrapper ``` --- ## 2. Error Handling Architecture ### 2.1 Error Classification **Hierarchical Error System:** ```python # src/errors/exceptions.py from enum import Enum from typing import Optional, Any class ErrorCategory(Enum): """Error categories for classification.""" CLIENT_ERROR = "client_error" # 4xx SERVER_ERROR = "server_error" # 5xx RATE_LIMIT = "rate_limit" # 429 EXTERNAL_API = "external_api" # Reddit API issues VALIDATION = "validation" AUTHENTICATION = "authentication" TIMEOUT = "timeout" class MCPError(Exception): """Base exception for all MCP errors.""" def __init__( self, code: int, message: str, category: ErrorCategory, data: Optional[dict] = None, retry_after: Optional[int] = None ): self.code = code self.message = message self.category = category self.data = data or {} self.retry_after = retry_after super().__init__(message) def to_json_rpc(self) -> dict: """Convert to JSON-RPC 2.0 error format.""" error = { "code": self.code, "message": self.message, "data": { "category": self.category.value, **self.data } } if self.retry_after: error["data"]["retry_after_seconds"] = self.retry_after return {"error": error} # Client Errors (4xx) class ValidationError(MCPError): """Invalid input parameters.""" def __init__(self, message: str, field: str = None, **kwargs): data = {"field": field} if field else {} data.update(kwargs) super().__init__( code=-32602, # Invalid params message=message, category=ErrorCategory.VALIDATION, data=data ) class AuthenticationError(MCPError): """Authentication failed.""" def __init__(self, message: str = "Authentication failed"): super().__init__( code=-32001, # Custom: Unauthorized message=message, category=ErrorCategory.AUTHENTICATION ) class AuthorizationError(MCPError): """Insufficient permissions.""" def __init__(self, message: str = "Insufficient permissions"): super().__init__( code=-32002, # Custom: Forbidden message=message, category=ErrorCategory.AUTHENTICATION ) class NotFoundError(MCPError): """Resource not found.""" def __init__(self, resource: str): super().__init__( code=-32003, # Custom: Not Found message=f"{resource} not found", category=ErrorCategory.CLIENT_ERROR, data={"resource": resource} ) # Rate Limit Errors class RateLimitError(MCPError): """Rate limit exceeded.""" def __init__(self, retry_after: int, limit_type: str = "user"): super().__init__( code=-32004, # Custom: Rate Limited message=f"Rate limit exceeded. Please retry after {retry_after} seconds", category=ErrorCategory.RATE_LIMIT, retry_after=retry_after, data={"limit_type": limit_type} ) class RedditRateLimitError(RateLimitError): """Reddit API rate limit.""" def __init__(self, retry_after: int): super().__init__( retry_after=retry_after, limit_type="reddit_api" ) self.message = f"Reddit API rate limit reached. Request queued, estimated wait: {retry_after}s" # Server Errors (5xx) class InternalServerError(MCPError): """Unexpected server error.""" def __init__(self, message: str = "Internal server error", trace_id: str = None): super().__init__( code=-32603, # Internal error message=message, category=ErrorCategory.SERVER_ERROR, data={"trace_id": trace_id} if trace_id else {} ) class ServiceUnavailableError(MCPError): """Service temporarily unavailable.""" def __init__(self, service: str, retry_after: int = 60): super().__init__( code=-32005, # Custom: Service Unavailable message=f"{service} is temporarily unavailable", category=ErrorCategory.SERVER_ERROR, retry_after=retry_after, data={"service": service} ) # External Errors class RedditAPIError(MCPError): """Reddit API returned an error.""" def __init__(self, status_code: int, message: str): super().__init__( code=-32006, # Custom: External API Error message=f"Reddit API error: {message}", category=ErrorCategory.EXTERNAL_API, data={ "status_code": status_code, "upstream_service": "reddit" } ) class TimeoutError(MCPError): """Request timeout.""" def __init__(self, operation: str, timeout_seconds: int): super().__init__( code=-32007, # Custom: Timeout message=f"Operation '{operation}' timed out after {timeout_seconds}s", category=ErrorCategory.TIMEOUT, data={"operation": operation, "timeout": timeout_seconds} ) ``` ### 2.2 Error Response Format **Standardized MCP Error Responses:** ```python # src/errors/formatter.py from typing import Optional import traceback import uuid class ErrorResponseFormatter: """Formats errors for MCP responses.""" def __init__(self, environment: str): self.environment = environment self.include_traces = environment == "development" def format_error( self, error: Exception, request_id: str = None ) -> dict: """Format error for JSON-RPC 2.0 response.""" # Generate trace ID for error tracking trace_id = request_id or str(uuid.uuid4()) # Handle known MCP errors if isinstance(error, MCPError): response = error.to_json_rpc() response["error"]["data"]["trace_id"] = trace_id if self.include_traces: response["error"]["data"]["stack_trace"] = traceback.format_exc() return response # Handle unexpected errors return { "error": { "code": -32603, "message": self._get_safe_message(error), "data": { "category": "server_error", "trace_id": trace_id, "error_type": type(error).__name__, **({"stack_trace": traceback.format_exc()} if self.include_traces else {}) } } } def _get_safe_message(self, error: Exception) -> str: """Get user-safe error message.""" if self.environment == "production": return "An unexpected error occurred. Please contact support with trace ID." return str(error) # Global error handler async def handle_error( error: Exception, request_id: str, logger: PIISafeLogger ) -> dict: """Global error handling pipeline.""" formatter = ErrorResponseFormatter(settings.environment) response = formatter.format_error(error, request_id) # Log error with appropriate severity log_level = "error" if isinstance(error, MCPError) else "critical" logger_method = getattr(logger, log_level) logger_method( f"Error in request {request_id}", exc_info=True, extra={ "trace_id": request_id, "error_code": response["error"]["code"], "error_category": response["error"]["data"].get("category") } ) # Send to error tracking (Sentry) if settings.sentry_dsn: sentry_sdk.capture_exception(error) return response ``` ### 2.3 Retry Strategy **Exponential Backoff with Jitter:** ```python # src/retry/backoff.py import asyncio import random from typing import Callable, TypeVar, Optional from functools import wraps T = TypeVar('T') class RetryConfig: """Configuration for retry behavior.""" def __init__( self, max_attempts: int = 3, base_delay: float = 1.0, max_delay: float = 60.0, exponential_base: float = 2.0, jitter: bool = True ): self.max_attempts = max_attempts self.base_delay = base_delay self.max_delay = max_delay self.exponential_base = exponential_base self.jitter = jitter def calculate_delay(self, attempt: int) -> float: """Calculate delay for attempt with exponential backoff and jitter.""" delay = min( self.base_delay * (self.exponential_base ** attempt), self.max_delay ) if self.jitter: # Add random jitter (±25%) jitter_range = delay * 0.25 delay += random.uniform(-jitter_range, jitter_range) return max(0, delay) class RetryStrategy: """Intelligent retry strategy with circuit breaker.""" # Different retry configs for different error types RETRY_CONFIGS = { ErrorCategory.RATE_LIMIT: RetryConfig( max_attempts=5, base_delay=2.0, max_delay=300.0 ), ErrorCategory.EXTERNAL_API: RetryConfig( max_attempts=3, base_delay=1.0, max_delay=30.0 ), ErrorCategory.TIMEOUT: RetryConfig( max_attempts=2, base_delay=0.5, max_delay=10.0 ), } # Don't retry these error types NO_RETRY_CATEGORIES = { ErrorCategory.CLIENT_ERROR, ErrorCategory.VALIDATION, ErrorCategory.AUTHENTICATION } def __init__(self, logger: PIISafeLogger): self.logger = logger async def execute_with_retry( self, func: Callable[..., T], *args, operation_name: str = None, **kwargs ) -> T: """Execute function with retry logic.""" operation_name = operation_name or func.__name__ attempt = 0 while True: try: result = await func(*args, **kwargs) if attempt > 0: self.logger.info( f"Operation '{operation_name}' succeeded after {attempt} retries" ) return result except MCPError as error: # Check if error is retryable if error.category in self.NO_RETRY_CATEGORIES: self.logger.info( f"Non-retryable error in '{operation_name}': {error.message}" ) raise # Get retry config for this error type config = self.RETRY_CONFIGS.get( error.category, RetryConfig() # Default config ) attempt += 1 if attempt >= config.max_attempts: self.logger.error( f"Operation '{operation_name}' failed after {attempt} attempts", extra={"error": str(error)} ) raise # Calculate delay if error.retry_after: delay = error.retry_after else: delay = config.calculate_delay(attempt) self.logger.warning( f"Retrying '{operation_name}' after {delay:.2f}s (attempt {attempt}/{config.max_attempts})", extra={"error_category": error.category.value} ) await asyncio.sleep(delay) except Exception as error: # Unexpected error - retry once if attempt > 0: self.logger.error( f"Unexpected error in '{operation_name}' after retry", exc_info=True ) raise attempt += 1 delay = 1.0 self.logger.warning( f"Unexpected error in '{operation_name}', retrying after {delay}s", exc_info=True ) await asyncio.sleep(delay) # Decorator for easy retry def with_retry( operation_name: str = None, max_attempts: int = 3 ): """Decorator to add retry logic to async functions.""" def decorator(func: Callable): @wraps(func) async def wrapper(*args, **kwargs): strategy = RetryStrategy(logger) return await strategy.execute_with_retry( func, *args, operation_name=operation_name or func.__name__, **kwargs ) return wrapper return decorator # Usage example @with_retry(operation_name="fetch_reddit_posts", max_attempts=3) async def fetch_posts_from_reddit(subreddit: str): """Fetch posts with automatic retry.""" async with reddit_client.get(f"/r/{subreddit}/hot.json") as response: if response.status == 429: retry_after = int(response.headers.get("Retry-After", 60)) raise RedditRateLimitError(retry_after) response.raise_for_status() return await response.json() ``` ### 2.4 Circuit Breaker Pattern **Fault Isolation:** ```python # src/resilience/circuit_breaker.py from enum import Enum from datetime import datetime, timedelta from typing import Callable, TypeVar from collections import deque T = TypeVar('T') class CircuitState(Enum): """Circuit breaker states.""" CLOSED = "closed" # Normal operation OPEN = "open" # Failing, reject requests HALF_OPEN = "half_open" # Testing if service recovered class CircuitBreaker: """Circuit breaker for external service calls.""" def __init__( self, service_name: str, failure_threshold: int = 5, recovery_timeout: int = 60, half_open_requests: int = 3, success_threshold: int = 2 ): self.service_name = service_name self.failure_threshold = failure_threshold self.recovery_timeout = timedelta(seconds=recovery_timeout) self.half_open_requests = half_open_requests self.success_threshold = success_threshold self.state = CircuitState.CLOSED self.failure_count = 0 self.success_count = 0 self.last_failure_time: Optional[datetime] = None self.half_open_attempts = 0 # Track recent errors for monitoring self.recent_errors = deque(maxlen=100) async def call(self, func: Callable[..., T], *args, **kwargs) -> T: """Execute function with circuit breaker protection.""" # Check circuit state if self.state == CircuitState.OPEN: # Check if recovery timeout has elapsed if datetime.utcnow() - self.last_failure_time > self.recovery_timeout: logger.info(f"Circuit breaker for {self.service_name}: OPEN -> HALF_OPEN") self.state = CircuitState.HALF_OPEN self.half_open_attempts = 0 else: raise ServiceUnavailableError( service=self.service_name, retry_after=int(self.recovery_timeout.total_seconds()) ) if self.state == CircuitState.HALF_OPEN: if self.half_open_attempts >= self.half_open_requests: raise ServiceUnavailableError( service=self.service_name, retry_after=30 ) self.half_open_attempts += 1 # Execute the function try: result = await func(*args, **kwargs) self._on_success() return result except Exception as error: self._on_failure(error) raise def _on_success(self): """Handle successful call.""" if self.state == CircuitState.HALF_OPEN: self.success_count += 1 if self.success_count >= self.success_threshold: logger.info(f"Circuit breaker for {self.service_name}: HALF_OPEN -> CLOSED") self.state = CircuitState.CLOSED self.failure_count = 0 self.success_count = 0 elif self.state == CircuitState.CLOSED: # Reset failure count on success self.failure_count = max(0, self.failure_count - 1) def _on_failure(self, error: Exception): """Handle failed call.""" self.failure_count += 1 self.last_failure_time = datetime.utcnow() self.recent_errors.append({ "timestamp": self.last_failure_time.isoformat(), "error": str(error) }) if self.state == CircuitState.HALF_OPEN: logger.warning(f"Circuit breaker for {self.service_name}: HALF_OPEN -> OPEN") self.state = CircuitState.OPEN self.success_count = 0 elif self.state == CircuitState.CLOSED: if self.failure_count >= self.failure_threshold: logger.error( f"Circuit breaker for {self.service_name}: CLOSED -> OPEN " f"({self.failure_count} failures)" ) self.state = CircuitState.OPEN def get_status(self) -> dict: """Get current circuit breaker status.""" return { "service": self.service_name, "state": self.state.value, "failure_count": self.failure_count, "last_failure": self.last_failure_time.isoformat() if self.last_failure_time else None, "recent_errors": list(self.recent_errors) } # Global circuit breakers class CircuitBreakerManager: """Manages circuit breakers for all services.""" def __init__(self): self.breakers: dict[str, CircuitBreaker] = {} def get_breaker(self, service_name: str) -> CircuitBreaker: """Get or create circuit breaker for service.""" if service_name not in self.breakers: self.breakers[service_name] = CircuitBreaker(service_name) return self.breakers[service_name] def get_all_status(self) -> dict: """Get status of all circuit breakers.""" return { name: breaker.get_status() for name, breaker in self.breakers.items() } # Global instance circuit_breakers = CircuitBreakerManager() # Usage decorator def with_circuit_breaker(service_name: str): """Decorator to add circuit breaker protection.""" def decorator(func: Callable): @wraps(func) async def wrapper(*args, **kwargs): breaker = circuit_breakers.get_breaker(service_name) return await breaker.call(func, *args, **kwargs) return wrapper return decorator # Example usage @with_circuit_breaker("reddit_api") async def fetch_from_reddit(endpoint: str): """Fetch data from Reddit with circuit breaker.""" async with httpx.AsyncClient() as client: response = await client.get(f"https://oauth.reddit.com{endpoint}") response.raise_for_status() return response.json() ``` --- ## 3. Rate Limit Management ### 3.1 Token Bucket Algorithm **Efficient Rate Limiting:** ```python # src/ratelimit/token_bucket.py import asyncio from datetime import datetime, timedelta from typing import Optional import time class TokenBucket: """Token bucket algorithm for rate limiting.""" def __init__( self, capacity: int, refill_rate: float, initial_tokens: Optional[int] = None ): """ Initialize token bucket. Args: capacity: Maximum number of tokens refill_rate: Tokens added per second initial_tokens: Starting tokens (defaults to capacity) """ self.capacity = capacity self.refill_rate = refill_rate self.tokens = initial_tokens if initial_tokens is not None else capacity self.last_refill = time.monotonic() self.lock = asyncio.Lock() async def consume(self, tokens: int = 1) -> bool: """ Try to consume tokens from bucket. Returns: True if tokens were consumed, False if insufficient tokens """ async with self.lock: self._refill() if self.tokens >= tokens: self.tokens -= tokens return True return False async def wait_for_token(self, tokens: int = 1) -> float: """ Wait until tokens are available. Returns: Wait time in seconds (0 if immediate) """ async with self.lock: self._refill() if self.tokens >= tokens: self.tokens -= tokens return 0.0 # Calculate wait time tokens_needed = tokens - self.tokens wait_seconds = tokens_needed / self.refill_rate # Wait and refill await asyncio.sleep(wait_seconds) self._refill() self.tokens -= tokens return wait_seconds def _refill(self): """Refill tokens based on time elapsed.""" now = time.monotonic() elapsed = now - self.last_refill # Calculate new tokens new_tokens = elapsed * self.refill_rate self.tokens = min(self.capacity, self.tokens + new_tokens) self.last_refill = now def get_status(self) -> dict: """Get current bucket status.""" return { "tokens_available": int(self.tokens), "capacity": self.capacity, "refill_rate": self.refill_rate, "utilization": 1 - (self.tokens / self.capacity) } class HierarchicalRateLimiter: """Rate limiter with multiple tiers (user, app, Reddit API).""" def __init__(self): # Reddit API limit: 100 requests per minute self.reddit_bucket = TokenBucket( capacity=100, refill_rate=100/60 # 1.67 per second ) # Per-user limits stored in Redis self.user_buckets: dict[str, TokenBucket] = {} def get_user_bucket(self, user_id: str, tier: str) -> TokenBucket: """Get or create token bucket for user.""" key = f"{user_id}:{tier}" if key not in self.user_buckets: # Free tier: 10 per minute, Paid: 50 per minute limits = { "free": (10, 10/60), "paid": (50, 50/60) } capacity, rate = limits.get(tier, limits["free"]) self.user_buckets[key] = TokenBucket(capacity, rate) return self.user_buckets[key] async def check_limits(self, user_id: str, tier: str) -> dict: """Check all rate limits for request.""" user_bucket = self.get_user_bucket(user_id, tier) # Check user limit first (faster fail) user_ok = await user_bucket.consume() if not user_ok: return { "allowed": False, "limited_by": "user_quota", "retry_after": await self._calculate_retry_time(user_bucket) } # Check Reddit API limit reddit_ok = await self.reddit_bucket.consume() if not reddit_ok: # Refund user token async with user_bucket.lock: user_bucket.tokens += 1 return { "allowed": False, "limited_by": "reddit_api", "retry_after": await self._calculate_retry_time(self.reddit_bucket) } return { "allowed": True, "user_tokens_remaining": int(user_bucket.tokens), "reddit_tokens_remaining": int(self.reddit_bucket.tokens) } async def _calculate_retry_time(self, bucket: TokenBucket) -> int: """Calculate seconds until token available.""" if bucket.tokens >= 1: return 0 tokens_needed = 1 - bucket.tokens return int(tokens_needed / bucket.refill_rate) + 1 def get_global_status(self) -> dict: """Get status of all rate limiters.""" return { "reddit_api": self.reddit_bucket.get_status(), "active_users": len(self.user_buckets) } # Global rate limiter rate_limiter = HierarchicalRateLimiter() # Middleware for rate limiting async def rate_limit_middleware(user_id: str, tier: str): """Middleware to enforce rate limits.""" result = await rate_limiter.check_limits(user_id, tier) if not result["allowed"]: raise RateLimitError( retry_after=result["retry_after"], limit_type=result["limited_by"] ) return result ``` ### 3.2 Request Queuing **Priority Queue System:** ```python # src/queue/priority_queue.py import asyncio from enum import IntEnum from dataclasses import dataclass, field from typing import Optional, Callable, Any from datetime import datetime import uuid class Priority(IntEnum): """Request priority levels.""" CRITICAL = 0 # Alerts, health checks HIGH = 1 # User-facing real-time queries NORMAL = 2 # Standard user requests LOW = 3 # Background jobs, analytics @dataclass(order=True) class QueuedRequest: """Request in the queue.""" priority: Priority = field(compare=True) timestamp: datetime = field(compare=True) request_id: str = field(default_factory=lambda: str(uuid.uuid4()), compare=False) user_id: str = field(compare=False) tool_name: str = field(compare=False) args: dict = field(default_factory=dict, compare=False) callback: Optional[Callable] = field(default=None, compare=False) class RequestQueue: """Priority queue for Reddit API requests.""" def __init__(self, max_size: int = 1000): self.max_size = max_size self.queue = asyncio.PriorityQueue(maxsize=max_size) self.processing = {} # request_id -> task self.results = {} # request_id -> result self.worker_task: Optional[asyncio.Task] = None async def start_worker(self): """Start background worker to process queue.""" if self.worker_task and not self.worker_task.done(): return self.worker_task = asyncio.create_task(self._process_queue()) logger.info("Request queue worker started") async def stop_worker(self): """Stop the background worker.""" if self.worker_task: self.worker_task.cancel() try: await self.worker_task except asyncio.CancelledError: pass logger.info("Request queue worker stopped") async def enqueue( self, user_id: str, tool_name: str, args: dict, priority: Priority = Priority.NORMAL ) -> str: """Add request to queue.""" if self.queue.full(): raise ServiceUnavailableError( service="request_queue", retry_after=30 ) request = QueuedRequest( priority=priority, timestamp=datetime.utcnow(), user_id=user_id, tool_name=tool_name, args=args ) await self.queue.put(request) logger.info( f"Request queued: {request.request_id}", extra={ "user_id": user_id, "tool": tool_name, "priority": priority.name, "queue_size": self.queue.qsize() } ) return request.request_id async def get_result( self, request_id: str, timeout: int = 300 ) -> Any: """Wait for and retrieve request result.""" start_time = datetime.utcnow() while True: # Check if result is ready if request_id in self.results: result = self.results.pop(request_id) if isinstance(result, Exception): raise result return result # Check timeout elapsed = (datetime.utcnow() - start_time).total_seconds() if elapsed > timeout: raise TimeoutError( operation="queue_wait", timeout_seconds=timeout ) # Wait a bit await asyncio.sleep(0.1) async def get_status(self, request_id: str) -> dict: """Get status of queued request.""" if request_id in self.results: return { "status": "completed", "request_id": request_id } if request_id in self.processing: return { "status": "processing", "request_id": request_id } # Check if still in queue (linear search - could be optimized) position = 0 for item in list(self.queue._queue): position += 1 if item.request_id == request_id: estimated_wait = position * 2 # ~2s per request return { "status": "queued", "request_id": request_id, "position": position, "estimated_wait_seconds": estimated_wait } return { "status": "not_found", "request_id": request_id } async def _process_queue(self): """Background worker to process queued requests.""" while True: try: # Get next request request = await self.queue.get() # Mark as processing self.processing[request.request_id] = asyncio.current_task() logger.info( f"Processing request: {request.request_id}", extra={ "tool": request.tool_name, "queue_size": self.queue.qsize() } ) # Wait for rate limit await rate_limiter.reddit_bucket.wait_for_token() # Execute the request try: # Import and execute tool tool = get_tool(request.tool_name) result = await tool.execute(**request.args) self.results[request.request_id] = result except Exception as error: self.results[request.request_id] = error logger.error( f"Request failed: {request.request_id}", exc_info=True ) finally: # Clean up self.processing.pop(request.request_id, None) self.queue.task_done() except asyncio.CancelledError: logger.info("Queue processor cancelled") break except Exception as error: logger.error("Queue processor error", exc_info=True) await asyncio.sleep(1) # Prevent tight loop on error # Global queue request_queue = RequestQueue() # Queue management decorator def queue_if_rate_limited(priority: Priority = Priority.NORMAL): """Decorator to automatically queue requests when rate limited.""" def decorator(func: Callable): @wraps(func) async def wrapper(user_id: str, *args, **kwargs): try: # Try direct execution return await func(user_id, *args, **kwargs) except RateLimitError as error: # Queue the request logger.info( f"Rate limited, queueing request", extra={"user_id": user_id} ) request_id = await request_queue.enqueue( user_id=user_id, tool_name=func.__name__, args=kwargs, priority=priority ) # Return status return { "status": "queued", "request_id": request_id, "message": f"Request queued due to rate limit. Retry after {error.retry_after}s", "check_status_url": f"/api/queue/status/{request_id}" } return wrapper return decorator ``` ### 3.3 Graceful Degradation **Serving Stale Cache During Rate Limits:** ```python # src/cache/graceful_degradation.py from typing import Optional, Any from datetime import datetime, timedelta class GracefulCacheManager: """Cache manager with graceful degradation.""" def __init__(self, redis_client): self.redis = redis_client async def get_with_fallback( self, key: str, max_stale_age: timedelta = timedelta(hours=24) ) -> Optional[dict]: """Get cached value, allowing stale data during outages.""" # Try to get fresh cache cached = await self.redis.get(key) if cached: data = json.loads(cached) cached_at = datetime.fromisoformat(data["cached_at"]) age = datetime.utcnow() - cached_at return { "data": data["value"], "cached_at": data["cached_at"], "age_seconds": int(age.total_seconds()), "is_stale": False } # No fresh cache, try stale cache stale_key = f"{key}:stale" stale_cached = await self.redis.get(stale_key) if stale_cached: data = json.loads(stale_cached) cached_at = datetime.fromisoformat(data["cached_at"]) age = datetime.utcnow() - cached_at # Check if stale data is too old if age > max_stale_age: return None logger.warning( f"Serving stale cache for {key}", extra={"age_hours": age.total_seconds() / 3600} ) return { "data": data["value"], "cached_at": data["cached_at"], "age_seconds": int(age.total_seconds()), "is_stale": True } return None async def set_with_stale_backup( self, key: str, value: Any, ttl: int, stale_ttl: int = 86400 # 24 hours ): """Set cache with long-lived stale backup.""" cache_data = { "value": value, "cached_at": datetime.utcnow().isoformat() } # Set fresh cache await self.redis.setex( key, ttl, json.dumps(cache_data) ) # Set stale backup (longer TTL) await self.redis.setex( f"{key}:stale", stale_ttl, json.dumps(cache_data) ) # Graceful degradation wrapper async def fetch_with_degradation( cache_key: str, fetch_func: Callable, cache_ttl: int = 300 ): """Fetch data with graceful degradation.""" cache_manager = GracefulCacheManager(redis_client) try: # Try to fetch fresh data data = await fetch_func() # Cache it await cache_manager.set_with_stale_backup( cache_key, data, ttl=cache_ttl ) return { "data": data, "from_cache": False, "is_stale": False } except RateLimitError: # Try to serve from cache cached = await cache_manager.get_with_fallback(cache_key) if cached: return cached # No cache available, queue request raise except RedditAPIError: # Reddit is down, serve stale cache cached = await cache_manager.get_with_fallback( cache_key, max_stale_age=timedelta(hours=48) # More lenient during outages ) if cached: return cached raise ServiceUnavailableError( service="reddit_api", retry_after=300 ) ``` --- ## 4. Monitoring & Observability ### 4.1 Metrics Collection **Comprehensive Metrics Tracking:** ```python # src/monitoring/metrics.py from dataclasses import dataclass from datetime import datetime, timedelta from collections import defaultdict from typing import Optional import time @dataclass class Metrics: """Metrics snapshot.""" timestamp: datetime request_count: int error_count: int cache_hits: int cache_misses: int latency_p50: float latency_p95: float latency_p99: float rate_limit_hits: int active_users: int class MetricsCollector: """Collects and aggregates metrics.""" def __init__(self): self.requests_by_tool = defaultdict(int) self.errors_by_type = defaultdict(int) self.cache_stats = {"hits": 0, "misses": 0} self.latencies = [] # Rolling window self.rate_limit_hits = 0 self.active_users = set() # Time series data (last 24 hours) self.time_series = [] self.window_size = 300 # 5 minutes def record_request( self, user_id: str, tool_name: str, latency_ms: float, cached: bool, error: Optional[str] = None ): """Record request metrics.""" # Count requests self.requests_by_tool[tool_name] += 1 self.active_users.add(user_id) # Record latency self.latencies.append(latency_ms) # Keep only recent latencies (last 1000) if len(self.latencies) > 1000: self.latencies = self.latencies[-1000:] # Cache stats if cached: self.cache_stats["hits"] += 1 else: self.cache_stats["misses"] += 1 # Error tracking if error: self.errors_by_type[error] += 1 def record_rate_limit(self): """Record rate limit hit.""" self.rate_limit_hits += 1 def get_snapshot(self) -> Metrics: """Get current metrics snapshot.""" # Calculate percentiles if self.latencies: sorted_latencies = sorted(self.latencies) p50 = sorted_latencies[len(sorted_latencies) // 2] p95 = sorted_latencies[int(len(sorted_latencies) * 0.95)] p99 = sorted_latencies[int(len(sorted_latencies) * 0.99)] else: p50 = p95 = p99 = 0.0 total_requests = sum(self.requests_by_tool.values()) total_errors = sum(self.errors_by_type.values()) return Metrics( timestamp=datetime.utcnow(), request_count=total_requests, error_count=total_errors, cache_hits=self.cache_stats["hits"], cache_misses=self.cache_stats["misses"], latency_p50=p50, latency_p95=p95, latency_p99=p99, rate_limit_hits=self.rate_limit_hits, active_users=len(self.active_users) ) def get_detailed_stats(self) -> dict: """Get detailed statistics.""" snapshot = self.get_snapshot() # Calculate cache hit rate total_cache_requests = self.cache_stats["hits"] + self.cache_stats["misses"] cache_hit_rate = ( self.cache_stats["hits"] / total_cache_requests if total_cache_requests > 0 else 0.0 ) # Error rate error_rate = ( snapshot.error_count / snapshot.request_count if snapshot.request_count > 0 else 0.0 ) return { "overview": { "total_requests": snapshot.request_count, "error_count": snapshot.error_count, "error_rate": error_rate, "active_users": snapshot.active_users, "rate_limit_hits": snapshot.rate_limit_hits }, "latency": { "p50_ms": snapshot.latency_p50, "p95_ms": snapshot.latency_p95, "p99_ms": snapshot.latency_p99 }, "cache": { "hits": self.cache_stats["hits"], "misses": self.cache_stats["misses"], "hit_rate": cache_hit_rate }, "requests_by_tool": dict(self.requests_by_tool), "errors_by_type": dict(self.errors_by_type) } def reset_window(self): """Reset metrics for new time window.""" # Save snapshot to time series self.time_series.append(self.get_snapshot()) # Keep only last 24 hours cutoff = datetime.utcnow() - timedelta(hours=24) self.time_series = [ m for m in self.time_series if m.timestamp > cutoff ] # Reset counters self.requests_by_tool.clear() self.errors_by_type.clear() self.cache_stats = {"hits": 0, "misses": 0} self.latencies.clear() self.rate_limit_hits = 0 self.active_users.clear() # Global metrics collector metrics_collector = MetricsCollector() # Metrics collection decorator def track_metrics(tool_name: str): """Decorator to automatically track metrics.""" def decorator(func: Callable): @wraps(func) async def wrapper(user_id: str, *args, **kwargs): start_time = time.time() error = None cached = False try: result = await func(user_id, *args, **kwargs) # Check if result was cached if isinstance(result, dict): cached = result.get("metadata", {}).get("cached", False) return result except MCPError as e: error = e.category.value raise finally: # Record metrics latency_ms = (time.time() - start_time) * 1000 metrics_collector.record_request( user_id=user_id, tool_name=tool_name, latency_ms=latency_ms, cached=cached, error=error ) return wrapper return decorator ``` ### 4.2 Structured Logging **JSON Logging with Context:** ```python # src/logging/structured_logger.py import logging import json from datetime import datetime from typing import Any, Optional import traceback import sys class StructuredFormatter(logging.Formatter): """JSON formatter for structured logging.""" def format(self, record: logging.LogRecord) -> str: """Format log record as JSON.""" log_data = { "timestamp": datetime.utcnow().isoformat() + "Z", "level": record.levelname, "logger": record.name, "message": record.getMessage(), "environment": settings.environment } # Add exception info if record.exc_info: log_data["exception"] = { "type": record.exc_info[0].__name__, "message": str(record.exc_info[1]), "traceback": traceback.format_exception(*record.exc_info) } # Add extra fields if hasattr(record, "extra"): log_data.update(record.extra) return json.dumps(log_data) def setup_logging(): """Configure structured logging.""" # Root logger root_logger = logging.getLogger() root_logger.setLevel(settings.log_level) # Console handler console_handler = logging.StreamHandler(sys.stdout) console_handler.setFormatter(StructuredFormatter()) root_logger.addHandler(console_handler) # File handler for production if settings.environment == "production": file_handler = logging.FileHandler("/var/log/reddit-mcp/app.log") file_handler.setFormatter(StructuredFormatter()) root_logger.addHandler(file_handler) # Usage example logger = PIISafeLogger(__name__) logger.info( "Tool execution started", extra={ "event": "tool_execution_start", "tool": "search_reddit", "user_id": user_id, "request_id": request_id, "query": query } ) ``` ### 4.3 Alerting Rules **Automated Alert System:** ```python # src/monitoring/alerts.py from enum import Enum from dataclasses import dataclass from datetime import datetime, timedelta from typing import Callable, List class AlertSeverity(Enum): """Alert severity levels.""" INFO = "info" WARNING = "warning" ERROR = "error" CRITICAL = "critical" @dataclass class AlertRule: """Alert rule definition.""" name: str condition: Callable[[Metrics], bool] severity: AlertSeverity message: str cooldown: timedelta = timedelta(minutes=15) last_triggered: Optional[datetime] = None class AlertManager: """Manages alerts based on metrics.""" def __init__(self): self.rules = self._define_rules() self.alert_channels = [] # Email, Slack, PagerDuty, etc. def _define_rules(self) -> List[AlertRule]: """Define alert rules.""" return [ # High error rate AlertRule( name="high_error_rate", condition=lambda m: ( m.error_count / m.request_count > 0.05 if m.request_count > 100 else False ), severity=AlertSeverity.ERROR, message="Error rate exceeded 5% (last 5 minutes)" ), # Rate limit near exhaustion AlertRule( name="rate_limit_high", condition=lambda m: ( rate_limiter.reddit_bucket.get_status()["utilization"] > 0.9 ), severity=AlertSeverity.WARNING, message="Reddit API rate limit >90% utilized" ), # Low cache hit rate AlertRule( name="low_cache_hit_rate", condition=lambda m: ( m.cache_hits / (m.cache_hits + m.cache_misses) < 0.6 if (m.cache_hits + m.cache_misses) > 50 else False ), severity=AlertSeverity.WARNING, message="Cache hit rate below 60%" ), # High latency AlertRule( name="high_latency_p95", condition=lambda m: m.latency_p95 > 2000, # 2 seconds severity=AlertSeverity.WARNING, message="P95 latency exceeded 2 seconds" ), # Circuit breaker open AlertRule( name="circuit_breaker_open", condition=lambda m: any( breaker.state == CircuitState.OPEN for breaker in circuit_breakers.breakers.values() ), severity=AlertSeverity.CRITICAL, message="Circuit breaker opened - service degraded" ), # Queue size high AlertRule( name="request_queue_full", condition=lambda m: ( request_queue.queue.qsize() > request_queue.max_size * 0.8 ), severity=AlertSeverity.WARNING, message="Request queue >80% full" ) ] async def check_alerts(self, metrics: Metrics): """Check all alert rules and trigger if needed.""" now = datetime.utcnow() for rule in self.rules: # Skip if in cooldown if rule.last_triggered: if now - rule.last_triggered < rule.cooldown: continue # Check condition try: if rule.condition(metrics): await self._trigger_alert(rule, metrics) rule.last_triggered = now except Exception as error: logger.error( f"Error checking alert rule: {rule.name}", exc_info=True ) async def _trigger_alert(self, rule: AlertRule, metrics: Metrics): """Trigger an alert.""" alert_data = { "rule": rule.name, "severity": rule.severity.value, "message": rule.message, "timestamp": datetime.utcnow().isoformat(), "metrics": { "error_rate": metrics.error_count / metrics.request_count if metrics.request_count > 0 else 0, "latency_p95": metrics.latency_p95, "cache_hit_rate": metrics.cache_hits / (metrics.cache_hits + metrics.cache_misses) if (metrics.cache_hits + metrics.cache_misses) > 0 else 0 } } # Log alert logger.error( f"ALERT: {rule.message}", extra=alert_data ) # Send to alert channels for channel in self.alert_channels: try: await channel.send_alert(alert_data) except Exception as error: logger.error( f"Failed to send alert to {channel}", exc_info=True ) # Global alert manager alert_manager = AlertManager() # Background task to check alerts async def alert_monitoring_task(): """Background task to monitor and trigger alerts.""" while True: try: metrics = metrics_collector.get_snapshot() await alert_manager.check_alerts(metrics) await asyncio.sleep(60) # Check every minute except Exception as error: logger.error("Alert monitoring task error", exc_info=True) await asyncio.sleep(60) ``` --- ## 5. Fault Tolerance & Recovery ### 5.1 Health Checks **Comprehensive Health Monitoring:** ```python # src/health/checks.py from enum import Enum from dataclasses import dataclass from datetime import datetime from typing import Optional, List import asyncio class HealthStatus(Enum): """Health check status.""" HEALTHY = "healthy" DEGRADED = "degraded" UNHEALTHY = "unhealthy" @dataclass class ComponentHealth: """Health status of a component.""" name: str status: HealthStatus message: str last_check: datetime response_time_ms: Optional[float] = None details: Optional[dict] = None class HealthChecker: """Performs health checks on dependencies.""" async def check_redis(self) -> ComponentHealth: """Check Redis connectivity.""" start = time.time() try: await redis_client.ping() response_time = (time.time() - start) * 1000 return ComponentHealth( name="redis", status=HealthStatus.HEALTHY, message="Redis connection OK", last_check=datetime.utcnow(), response_time_ms=response_time ) except Exception as error: return ComponentHealth( name="redis", status=HealthStatus.UNHEALTHY, message=f"Redis connection failed: {str(error)}", last_check=datetime.utcnow() ) async def check_reddit_api(self) -> ComponentHealth: """Check Reddit API availability.""" start = time.time() try: # Try to fetch Reddit status async with httpx.AsyncClient() as client: response = await client.get( "https://www.reddit.com/api/v1/me", headers={"Authorization": f"Bearer {await auth_manager.get_valid_token()}"}, timeout=5.0 ) response_time = (time.time() - start) * 1000 if response.status_code == 200: status = HealthStatus.HEALTHY message = "Reddit API responding" elif response.status_code == 401: status = HealthStatus.DEGRADED message = "Reddit API auth issue" else: status = HealthStatus.DEGRADED message = f"Reddit API returned {response.status_code}" return ComponentHealth( name="reddit_api", status=status, message=message, last_check=datetime.utcnow(), response_time_ms=response_time ) except asyncio.TimeoutError: return ComponentHealth( name="reddit_api", status=HealthStatus.UNHEALTHY, message="Reddit API timeout", last_check=datetime.utcnow() ) except Exception as error: return ComponentHealth( name="reddit_api", status=HealthStatus.UNHEALTHY, message=f"Reddit API error: {str(error)}", last_check=datetime.utcnow() ) async def check_apify_kv_store(self) -> ComponentHealth: """Check Apify KV Store.""" start = time.time() try: # Try to read/write test value test_key = "health_check" test_value = datetime.utcnow().isoformat() await kv_store.set_value(test_key, test_value) result = await kv_store.get_value(test_key) response_time = (time.time() - start) * 1000 if result == test_value: return ComponentHealth( name="apify_kv_store", status=HealthStatus.HEALTHY, message="KV Store read/write OK", last_check=datetime.utcnow(), response_time_ms=response_time ) else: return ComponentHealth( name="apify_kv_store", status=HealthStatus.DEGRADED, message="KV Store data mismatch", last_check=datetime.utcnow() ) except Exception as error: return ComponentHealth( name="apify_kv_store", status=HealthStatus.UNHEALTHY, message=f"KV Store error: {str(error)}", last_check=datetime.utcnow() ) async def check_all(self) -> dict: """Run all health checks.""" checks = await asyncio.gather( self.check_redis(), self.check_reddit_api(), self.check_apify_kv_store(), return_exceptions=True ) components = {} overall_status = HealthStatus.HEALTHY for check in checks: if isinstance(check, Exception): overall_status = HealthStatus.UNHEALTHY continue components[check.name] = { "status": check.status.value, "message": check.message, "response_time_ms": check.response_time_ms, "last_check": check.last_check.isoformat() } # Determine overall status if check.status == HealthStatus.UNHEALTHY: overall_status = HealthStatus.UNHEALTHY elif check.status == HealthStatus.DEGRADED and overall_status != HealthStatus.UNHEALTHY: overall_status = HealthStatus.DEGRADED return { "status": overall_status.value, "timestamp": datetime.utcnow().isoformat(), "components": components } # Health check endpoint health_checker = HealthChecker() @app.get("/health") async def health_endpoint(): """Health check endpoint.""" health = await health_checker.check_all() # Return appropriate HTTP status status_code = { "healthy": 200, "degraded": 200, # Still accepting requests "unhealthy": 503 }[health["status"]] return JSONResponse(health, status_code=status_code) @app.get("/health/ready") async def readiness_endpoint(): """Readiness check (can accept traffic).""" health = await health_checker.check_all() if health["status"] == "unhealthy": return JSONResponse( {"ready": False, "reason": "Dependencies unhealthy"}, status_code=503 ) return {"ready": True} @app.get("/health/live") async def liveness_endpoint(): """Liveness check (process is running).""" return {"alive": True} ``` ### 5.2 Recovery Strategies **Automated Recovery Procedures:** ```python # src/resilience/recovery.py from typing import Callable, Any from datetime import datetime import asyncio class RecoveryStrategy: """Automated recovery strategies.""" def __init__(self): self.recovery_attempts = {} async def recover_redis_connection(self) -> bool: """Attempt to reconnect to Redis.""" logger.info("Attempting Redis recovery") max_attempts = 5 for attempt in range(max_attempts): try: # Recreate Redis client global redis_client redis_client = await create_redis_client() # Test connection await redis_client.ping() logger.info(f"Redis recovery successful (attempt {attempt + 1})") return True except Exception as error: logger.warning( f"Redis recovery attempt {attempt + 1} failed: {error}" ) await asyncio.sleep(2 ** attempt) # Exponential backoff logger.error("Redis recovery failed after all attempts") return False async def recover_reddit_oauth(self) -> bool: """Refresh Reddit OAuth token.""" logger.info("Attempting Reddit OAuth recovery") try: await auth_manager._refresh_token() logger.info("Reddit OAuth recovery successful") return True except Exception as error: logger.error(f"Reddit OAuth recovery failed: {error}") return False async def switch_to_degraded_mode(self): """Switch to degraded mode (cache-only).""" logger.warning("Switching to degraded mode") # Set flag for degraded mode global DEGRADED_MODE DEGRADED_MODE = True # Notify users await self._notify_degraded_mode() async def exit_degraded_mode(self): """Exit degraded mode.""" logger.info("Exiting degraded mode") global DEGRADED_MODE DEGRADED_MODE = False # Notify recovery await self._notify_recovery() async def _notify_degraded_mode(self): """Notify users of degraded mode.""" # Send to monitoring channels notification = { "type": "service_degradation", "message": "Reddit MCP Server in degraded mode - serving cached data only", "timestamp": datetime.utcnow().isoformat() } logger.warning("Service degradation notification sent", extra=notification) async def _notify_recovery(self): """Notify users of service recovery.""" notification = { "type": "service_recovery", "message": "Reddit MCP Server recovered - full functionality restored", "timestamp": datetime.utcnow().isoformat() } logger.info("Service recovery notification sent", extra=notification) # Global recovery strategy recovery_strategy = RecoveryStrategy() # Background recovery task async def automatic_recovery_task(): """Background task for automatic recovery.""" while True: try: # Check health health = await health_checker.check_all() # Attempt recovery for unhealthy components if health["status"] == "unhealthy": logger.warning("Unhealthy components detected, attempting recovery") for component_name, component in health["components"].items(): if component["status"] == "unhealthy": if component_name == "redis": await recovery_strategy.recover_redis_connection() elif component_name == "reddit_api": await recovery_strategy.recover_reddit_oauth() # If still unhealthy, switch to degraded mode health_after = await health_checker.check_all() if health_after["status"] == "unhealthy": await recovery_strategy.switch_to_degraded_mode() # If healthy and in degraded mode, exit it elif health["status"] == "healthy" and DEGRADED_MODE: await recovery_strategy.exit_degraded_mode() # Sleep before next check await asyncio.sleep(30) except Exception as error: logger.error("Recovery task error", exc_info=True) await asyncio.sleep(30) ``` --- ## 6. Implementation Checklist ### Phase 1: Foundation (Week 1) - [ ] **Security Basics** - [ ] Implement OAuth2 flow for Reddit - [ ] Set up Apify KV Store for encrypted token storage - [ ] Configure environment variables in Apify Console - [ ] Create MCP client authentication (Bearer tokens) - [ ] Add PII-safe logging system - [ ] **Error Handling Core** - [ ] Define error hierarchy (MCPError base class) - [ ] Implement error formatter for JSON-RPC responses - [ ] Add global error handler - [ ] Create retry strategy with exponential backoff - [ ] Test error responses for all tools ### Phase 2: Rate Limiting (Week 2) - [ ] **Rate Limit Implementation** - [ ] Implement token bucket algorithm - [ ] Add per-user rate limiting (free vs paid tiers) - [ ] Create Reddit API rate limiter - [ ] Build priority request queue - [ ] Add queue status endpoint - [ ] **Graceful Degradation** - [ ] Implement stale cache fallback - [ ] Add degraded mode flag - [ ] Create user notifications for queued requests - [ ] Test rate limit scenarios ### Phase 3: Monitoring (Week 3) - [ ] **Metrics & Logging** - [ ] Set up structured JSON logging - [ ] Implement metrics collector - [ ] Add latency tracking (P50, P95, P99) - [ ] Track cache hit/miss rates - [ ] Create metrics dashboard endpoint - [ ] **Alerting** - [ ] Define alert rules - [ ] Implement alert manager - [ ] Set up Sentry error tracking - [ ] Configure alert channels (email, Slack) - [ ] Test alert triggers ### Phase 4: Resilience (Week 4) - [ ] **Fault Tolerance** - [ ] Implement circuit breaker pattern - [ ] Add health check endpoints (/health, /ready, /live) - [ ] Create automated recovery strategies - [ ] Test failure scenarios - [ ] Document manual intervention procedures - [ ] **Input Validation** - [ ] Create Pydantic models for all tool inputs - [ ] Add sanitization for SQL injection/XSS - [ ] Enforce length limits - [ ] Test validation edge cases ### Phase 5: Compliance (Ongoing) - [ ] **Data Privacy** - [ ] Implement GDPR data export - [ ] Add GDPR data deletion - [ ] Create retention policy enforcement - [ ] Document data handling procedures - [ ] Add privacy policy - [ ] **Security Hardening** - [ ] Security audit of all endpoints - [ ] Penetration testing - [ ] Dependency vulnerability scan - [ ] Secret rotation procedures - [ ] Incident response plan --- ## Configuration Examples ### Apify Actor Configuration ```json { "name": "reddit-mcp-server", "version": "1.0.0", "buildTag": "latest", "environmentVariables": { "REDDIT_CLIENT_ID": "your_client_id", "REDDIT_CLIENT_SECRET": "your_client_secret", "REDIS_URL": "redis://redis:6379", "MCP_MASTER_KEY": "generate_with_secrets_token_hex_32", "SENTRY_DSN": "https://your-sentry-dsn", "LOG_LEVEL": "INFO", "ENVIRONMENT": "production" }, "resourceRequirements": { "memoryMbytes": 2048, "diskMbytes": 1024 } } ``` ### Redis Configuration ```yaml # docker-compose.yml services: redis: image: redis:7-alpine command: redis-server --maxmemory 512mb --maxmemory-policy allkeys-lru volumes: - redis_data:/data ports: - "6379:6379" healthcheck: test: ["CMD", "redis-cli", "ping"] interval: 10s timeout: 3s retries: 3 ``` ### Sentry Configuration ```python # src/monitoring/sentry.py import sentry_sdk from sentry_sdk.integrations.asyncio import AsyncioIntegration from sentry_sdk.integrations.logging import LoggingIntegration def init_sentry(): """Initialize Sentry error tracking.""" if not settings.sentry_dsn: return sentry_sdk.init( dsn=settings.sentry_dsn.get_secret_value(), environment=settings.environment, integrations=[ AsyncioIntegration(), LoggingIntegration(level=logging.INFO, event_level=logging.ERROR) ], traces_sample_rate=0.1, # 10% of transactions profiles_sample_rate=0.1, before_send=filter_sensitive_data, ignore_errors=[ValidationError, RateLimitError] # Don't report client errors ) def filter_sensitive_data(event, hint): """Filter sensitive data from Sentry events.""" # Remove authorization headers if "request" in event and "headers" in event["request"]: headers = event["request"]["headers"] if "Authorization" in headers: headers["Authorization"] = "[REDACTED]" # Remove user data if "user" in event: event["user"] = {"id": event["user"].get("id")} # Keep only ID return event ``` --- ## Testing Strategy ### Security Testing ```python # tests/test_security.py import pytest from src.auth.mcp_auth import MCPAuthManager @pytest.mark.asyncio async def test_invalid_api_key(): """Test that invalid API keys are rejected.""" with pytest.raises(HTTPException) as exc_info: await auth_manager.validate_token("invalid_key") assert exc_info.value.status_code == 401 @pytest.mark.asyncio async def test_sql_injection_protection(): """Test SQL injection prevention in inputs.""" malicious_input = "'; DROP TABLE users; --" with pytest.raises(ValidationError): SearchRedditInput(query=malicious_input) @pytest.mark.asyncio async def test_xss_protection(): """Test XSS prevention in inputs.""" malicious_input = "<script>alert('xss')</script>" with pytest.raises(ValidationError): SearchRedditInput(query=malicious_input) @pytest.mark.asyncio async def test_pii_redaction_in_logs(caplog): """Test that PII is redacted from logs.""" logger.info("User email@example.com made request with token abc123") assert "email@example.com" not in caplog.text assert "abc123" not in caplog.text assert "[REDACTED" in caplog.text ``` ### Error Handling Testing ```python # tests/test_error_handling.py import pytest from src.errors.exceptions import * @pytest.mark.asyncio async def test_rate_limit_error_format(): """Test rate limit error format.""" error = RateLimitError(retry_after=60) response = error.to_json_rpc() assert response["error"]["code"] == -32004 assert response["error"]["data"]["retry_after_seconds"] == 60 @pytest.mark.asyncio async def test_retry_on_transient_error(): """Test retry logic for transient errors.""" attempts = 0 async def flaky_function(): nonlocal attempts attempts += 1 if attempts < 3: raise RedditAPIError(503, "Service unavailable") return "success" result = await retry_strategy.execute_with_retry(flaky_function) assert result == "success" assert attempts == 3 @pytest.mark.asyncio async def test_circuit_breaker_opens(): """Test circuit breaker opens after failures.""" breaker = CircuitBreaker("test_service", failure_threshold=3) # Trigger failures for _ in range(3): try: await breaker.call(lambda: raise RedditAPIError(500, "Error")) except: pass assert breaker.state == CircuitState.OPEN # Next call should fail fast with pytest.raises(ServiceUnavailableError): await breaker.call(lambda: "success") ``` --- **End of Security & Error Handling Architecture** This architecture provides production-grade security, comprehensive error handling, intelligent rate limiting, and robust monitoring for the Reddit MCP Server. All components are designed to work together to create a resilient, secure, and observable system.

Latest Blog Posts

MCP directory API

We provide all the information about MCP servers via our MCP API.

curl -X GET 'https://glama.ai/api/mcp/v1/servers/padak/apify-actor-reddit-mcp'

If you have feedback or need assistance with the MCP directory API, please join our Discord server