Skip to main content
Glama
discover.py — 19.3 kB
"""Subreddit discovery using semantic vector search.""" import os import json import statistics from typing import Dict, List, Optional, Union, Any from dataclasses import dataclass from fastmcp import Context from ..chroma_client import get_chroma_client, get_collection @dataclass class SearchConfig: """Centralized configuration for semantic search behavior. All parameters are tunable to adjust search sensitivity and result filtering. """ # Distance-to-tier thresholds (Euclidean distance) EXACT_DISTANCE_THRESHOLD: float = 0.2 # Very relevant matches SEMANTIC_DISTANCE_THRESHOLD: float = 0.35 # Relevant matches ADJACENT_DISTANCE_THRESHOLD: float = 0.65 # Somewhat relevant # Confidence score mapping (maps distance ranges to confidence 0.0-1.0) # Lower distance = higher confidence CONFIDENCE_DISTANCE_BREAKPOINTS: Dict[float, float] = None # Generic subreddit filtering GENERIC_SUBREDDITS: List[str] = None GENERIC_PENALTY_MULTIPLIER: float = 0.3 # Penalty when generic sub matches broadly # Subscriber-based scoring LARGE_SUB_THRESHOLD: int = 1_000_000 # Boost applied above this LARGE_SUB_BOOST_MULTIPLIER: float = 1.1 # Boost factor for large subs SMALL_SUB_THRESHOLD: int = 10_000 # Penalty applied below this SMALL_SUB_PENALTY_MULTIPLIER: float = 0.9 # Penalty factor for small subs # Search behavior SEARCH_MULTIPLIER: int = 3 # Fetch 3x limit for filtering MAX_SEARCH_RESULTS: int = 100 # Hard cap on returned results def __post_init__(self): """Initialize default values for mutable fields.""" if self.GENERIC_SUBREDDITS is None: self.GENERIC_SUBREDDITS = [ 'funny', 'pics', 'videos', 'gifs', 'memes', 'aww', 'news', 'science' ] if self.CONFIDENCE_DISTANCE_BREAKPOINTS is None: # Maps distance thresholds to confidence ranges # Distance range -> confidence range self.CONFIDENCE_DISTANCE_BREAKPOINTS = { 0.8: 0.95, # <0.8: 0.9-1.0 1.0: 0.80, # 0.8-1.0: 0.7-0.9 1.2: 0.60, # 1.0-1.2: 0.5-0.7 1.4: 0.40, # 1.2-1.4: 0.3-0.5 2.0: 0.15 # 1.4-2.0: 0.1-0.3 } # Global default configuration 
DEFAULT_SEARCH_CONFIG = SearchConfig()


def classify_match_tier(distance: float, config: Optional[SearchConfig] = None) -> str:
    """
    Classify semantic match quality based on distance.

    Uses configurable thresholds to categorize match relevance.

    Distance scale (Euclidean, lower is better):
    - 0.0-0.2: exact (highly relevant)
    - 0.2-0.35: semantic (very relevant)
    - 0.35-0.65: adjacent (somewhat relevant)
    - 0.65+: peripheral (weakly relevant)

    Args:
        distance: Euclidean distance from query embedding
        config: SearchConfig instance with tunable thresholds

    Returns:
        Match tier label: "exact", "semantic", "adjacent", or "peripheral"
    """
    if config is None:
        config = DEFAULT_SEARCH_CONFIG

    if distance < config.EXACT_DISTANCE_THRESHOLD:
        return "exact"
    if distance < config.SEMANTIC_DISTANCE_THRESHOLD:
        return "semantic"
    if distance < config.ADJACENT_DISTANCE_THRESHOLD:
        return "adjacent"
    return "peripheral"


def calculate_confidence_stats(confidence_scores: List[float]) -> Dict[str, float]:
    """
    Calculate statistics about confidence score distribution.

    Provides insight into the quality and variance of matched results.

    Args:
        confidence_scores: List of confidence scores (0.0-1.0)

    Returns:
        Dictionary with mean, median, min, max, and standard deviation
    """
    if not confidence_scores:
        return {
            "mean": 0.0,
            "median": 0.0,
            "min": 0.0,
            "max": 0.0,
            "std_dev": 0.0
        }

    return {
        "mean": round(statistics.mean(confidence_scores), 3),
        # FIX: the previous implementation used sorted_scores[len // 2],
        # which returns the upper-middle element for even-length lists.
        # statistics.median correctly averages the two middle values.
        "median": round(statistics.median(confidence_scores), 3),
        "min": round(min(confidence_scores), 3),
        "max": round(max(confidence_scores), 3),
        # stdev requires at least two samples.
        "std_dev": round(statistics.stdev(confidence_scores), 3)
        if len(confidence_scores) > 1 else 0.0
    }


def _get_vector_collection(collection_name: str = "dialog-app-prod-db"):
    """
    Initialize ChromaDB client and retrieve a collection.

    Helper function to reduce code duplication across discovery operations.

    Args:
        collection_name: Name of the ChromaDB collection to retrieve

    Returns:
        ChromaDB collection object

    Raises:
        Exception: If connection or collection retrieval fails
    """
    client = get_chroma_client()
    return get_collection(collection_name, client)


def calculate_confidence_from_distance(
    distance: float,
    config: Optional[SearchConfig] = None
) -> float:
    """
    Convert Euclidean distance to confidence score (0.0-1.0).

    Uses a piecewise linear interpolation model mapping distance ranges to
    confidence bands. Lower distance = higher confidence.

    Distance ranges map as follows:
    - <0.8: confidence 0.90-1.0 (excellent match)
    - 0.8-1.0: confidence 0.70-0.9 (strong match)
    - 1.0-1.2: confidence 0.50-0.7 (moderate match)
    - 1.2-1.4: confidence 0.30-0.5 (weak match)
    - 1.4-2.0: confidence 0.10-0.3 (very weak match)

    Args:
        distance: Euclidean distance from query embedding
        config: SearchConfig with confidence mapping parameters

    Returns:
        Confidence score between 0.0 and 1.0

    Example:
        >>> calculate_confidence_from_distance(0.5)
        0.969  # Very high confidence for close match
        >>> calculate_confidence_from_distance(1.2)
        0.6  # Moderate confidence
    """
    if config is None:
        config = DEFAULT_SEARCH_CONFIG

    breakpoints = sorted(config.CONFIDENCE_DISTANCE_BREAKPOINTS.items())

    # Find which bracket the distance falls into and interpolate linearly
    # between the prior breakpoint's confidence and this one's.
    for i, (threshold, confidence_at_threshold) in enumerate(breakpoints):
        if distance <= threshold:
            if i == 0:
                # Before first breakpoint: interpolate down from 1.0.
                prior_threshold, prior_confidence = 0.0, 1.0
            else:
                prior_threshold, prior_confidence = breakpoints[i - 1]

            # Degenerate bracket (duplicate threshold): no interpolation.
            if threshold == prior_threshold:
                return confidence_at_threshold

            interpolation = (distance - prior_threshold) / (threshold - prior_threshold)
            confidence = prior_confidence - (prior_confidence - confidence_at_threshold) * interpolation
            return round(max(0.0, min(1.0, confidence)), 3)

    # Beyond last breakpoint, return minimum confidence.
    return 0.1


def calculate_tier_distribution(results: List[Dict[str, Any]]) -> Dict[str, int]:
    """
    Count results by match tier.

    Args:
        results: List of result dictionaries with 'match_tier' field

    Returns:
        Dictionary with counts for each tier
    """
    tier_counts = {"exact": 0, "semantic": 0, "adjacent": 0, "peripheral": 0}
    for result in results:
        tier = result.get('match_tier', 'peripheral')
        if tier in tier_counts:
            tier_counts[tier] += 1
    return tier_counts


async def discover_subreddits(
    query: Optional[str] = None,
    queries: Optional[Union[List[str], str]] = None,
    limit: int = 10,
    include_nsfw: bool = False,
    min_confidence: float = 0.0,
    config: Optional[SearchConfig] = None,
    ctx: Context = None
) -> Dict[str, Any]:
    """
    Search for subreddits using semantic similarity search.

    Finds relevant subreddits based on semantic embeddings of subreddit
    names, descriptions, and community metadata.

    Args:
        query: Single search term to find subreddits
        queries: List of search terms for batch discovery (more efficient)
                 Can also be a JSON string like '["term1", "term2"]'
        limit: Maximum number of results per query (default 10)
        include_nsfw: Whether to include NSFW subreddits (default False)
        min_confidence: Minimum confidence score to include (0.0-1.0, default 0.0)
        config: SearchConfig instance for tuning behavior (uses defaults if None)
        ctx: FastMCP context (auto-injected by decorator)

    Returns:
        Dictionary with discovered subreddits and their metadata
    """
    if config is None:
        config = DEFAULT_SEARCH_CONFIG

    # Initialize ChromaDB client
    try:
        collection = _get_vector_collection("dialog-app-prod-db")
    except Exception as e:
        return {
            "error": f"Failed to connect to vector database: {str(e)}",
            "results": [],
            "summary": {
                "total_found": 0,
                "returned": 0,
                "coverage": "error"
            }
        }

    # Handle batch queries - convert string to list if needed
    if queries:
        # Handle case where LLM passes JSON string instead of array
        if isinstance(queries, str):
            stripped = queries.strip()
            if stripped.startswith('[') and stripped.endswith(']'):
                # Looks like a JSON array; fall back to a single-item list
                # if it turns out not to parse.
                try:
                    queries = json.loads(stripped)
                except (json.JSONDecodeError, ValueError):
                    queries = [queries]
            else:
                # Single string query, convert to single-item list
                queries = [queries]

        batch_results = {}
        total_api_calls = 0

        for search_query in queries:
            result = await _search_vector_db(
                search_query, collection, limit, include_nsfw,
                min_confidence, config, ctx
            )
            batch_results[search_query] = result
            total_api_calls += 1

        return {
            "batch_mode": True,
            "total_queries": len(queries),
            "api_calls_made": total_api_calls,
            "results": batch_results,
            "tip": "Batch mode reduces API calls. Use the exact 'name' field when calling other tools."
        }

    # Handle single query
    elif query:
        return await _search_vector_db(
            query, collection, limit, include_nsfw, min_confidence, config, ctx
        )
    else:
        return {
            "error": "Either 'query' or 'queries' parameter must be provided",
            "subreddits": [],
            "summary": {
                "total_found": 0,
                "returned": 0,
                "coverage": "error"
            }
        }


async def _search_vector_db(
    query: str,
    collection,
    limit: int,
    include_nsfw: bool,
    min_confidence: float,
    config: Optional[SearchConfig] = None,
    ctx: Context = None
) -> Dict[str, Any]:
    """
    Internal function to perform semantic search for a single query.

    Args:
        query: Search query string
        collection: ChromaDB collection object
        limit: Max results to return
        include_nsfw: Whether to include NSFW subreddits
        min_confidence: Minimum confidence threshold
        config: SearchConfig with tunable parameters
        ctx: FastMCP context for progress reporting

    Returns:
        Dictionary with search results and statistics
    """
    if config is None:
        config = DEFAULT_SEARCH_CONFIG

    try:
        # Search with a larger limit to allow for filtering
        search_limit = min(limit * config.SEARCH_MULTIPLIER, config.MAX_SEARCH_RESULTS)

        # Perform semantic search
        results = collection.query(
            query_texts=[query],
            n_results=search_limit
        )

        if not results or not results['metadatas'] or not results['metadatas'][0]:
            return {
                "query": query,
                "subreddits": [],
                "summary": {
                    "total_found": 0,
                    "returned": 0,
                    "has_more": False
                },
                "next_actions": ["Try different search terms"]
            }

        # Process results
        processed_results = []
        nsfw_filtered = 0
        total_results = len(results['metadatas'][0])

        for i, (metadata, distance) in enumerate(zip(
            results['metadatas'][0],
            results['distances'][0]
        )):
            # Report progress
            if ctx:
                await ctx.report_progress(
                    progress=i + 1,
                    total=total_results,
                    message=f"Analyzing r/{metadata.get('name', 'unknown')}"
                )

            # Skip NSFW if not requested
            if metadata.get('nsfw', False) and not include_nsfw:
                nsfw_filtered += 1
                continue

            # Convert distance to confidence score using configurable model
            confidence = calculate_confidence_from_distance(distance, config)

            # Apply penalties for generic subreddits (configurable)
            subreddit_name = metadata.get('name', '').lower()
            if subreddit_name in config.GENERIC_SUBREDDITS and query.lower() not in subreddit_name:
                confidence *= config.GENERIC_PENALTY_MULTIPLIER

            # Apply boosts/penalties based on subscriber count
            subscribers = metadata.get('subscribers', 0)
            if subscribers > config.LARGE_SUB_THRESHOLD:
                confidence = min(1.0, confidence * config.LARGE_SUB_BOOST_MULTIPLIER)
            elif subscribers < config.SMALL_SUB_THRESHOLD:
                confidence *= config.SMALL_SUB_PENALTY_MULTIPLIER

            # NOTE: a dead "match_type" classification (exact/strong/partial/
            # weak based on raw distance) was removed here -- its value was
            # never stored in the result dict. Tiering is handled below.

            # Classify match tier based on distance
            match_tier = classify_match_tier(distance, config)

            processed_results.append({
                "name": metadata.get('name', 'unknown'),
                "subscribers": metadata.get('subscribers', 0),
                "confidence": round(confidence, 3),
                "distance": round(distance, 3),
                "match_tier": match_tier,
                "url": metadata.get('url', f"https://reddit.com/r/{metadata.get('name', '')}")
            })

        # Filter by minimum confidence if specified (Phase 2a.3)
        if min_confidence > 0.0:
            processed_results = [
                r for r in processed_results
                if r['confidence'] >= min_confidence
            ]

        # Sort by confidence (highest first), then by subscribers
        processed_results.sort(key=lambda x: (-x['confidence'], -(x['subscribers'] or 0)))

        # Limit to requested number
        limited_results = processed_results[:limit]

        # Calculate basic stats
        total_found = len(processed_results)

        # Calculate confidence statistics (Phase 2a.4)
        confidence_scores = [r['confidence'] for r in limited_results]
        confidence_stats = calculate_confidence_stats(confidence_scores)
        tier_distribution = calculate_tier_distribution(limited_results)

        # Generate next actions (only meaningful ones)
        next_actions = []
        if len(processed_results) > limit:
            next_actions.append(f"{len(processed_results)} total results found, showing {limit}")
        if nsfw_filtered > 0:
            next_actions.append(f"{nsfw_filtered} NSFW subreddits filtered")

        return {
            "query": query,
            "subreddits": limited_results,
            "summary": {
                "total_found": total_found,
                "returned": len(limited_results),
                "has_more": total_found > len(limited_results),
                "confidence_stats": confidence_stats,
                "tier_distribution": tier_distribution
            },
            "next_actions": next_actions
        }

    except Exception as e:
        # Map error patterns to specific recovery actions
        error_str = str(e).lower()
        if "not found" in error_str:
            guidance = "Verify subreddit name spelling"
        elif "rate" in error_str:
            guidance = "Rate limited - wait 60 seconds"
        elif "timeout" in error_str:
            guidance = "Reduce limit parameter to 10"
        else:
            guidance = "Try simpler search terms"

        return {
            "error": f"Failed to search vector database: {str(e)}",
            "query": query,
            "subreddits": [],
            "summary": {
                "total_found": 0,
                "returned": 0,
                "has_more": False
            },
            "next_actions": [guidance]
        }


def validate_subreddit(
    subreddit_name: str,
    ctx: Context = None
) -> Dict[str, Any]:
    """
    Validate if a subreddit exists in the indexed database.

    Checks if the subreddit exists in our semantic search index and
    returns its metadata if found.

    Args:
        subreddit_name: Name of the subreddit to validate
        ctx: FastMCP context (optional)

    Returns:
        Dictionary with validation result and subreddit info if found
    """
    # Clean the subreddit name.
    # FIX: the original chained .replace("r/", "").replace("/r/", ""),
    # which turned "/r/python" into "/python" (the inner "r/" was removed
    # first, so "/r/" never matched). Strip known prefixes instead.
    clean_name = subreddit_name.strip()
    if clean_name.startswith("/r/"):
        clean_name = clean_name[3:]
    elif clean_name.startswith("r/"):
        clean_name = clean_name[2:]
    clean_name = clean_name.strip()

    try:
        # Search for exact match in vector database
        collection = _get_vector_collection("dialog-app-prod-db")

        # Search for the exact subreddit name
        results = collection.query(
            query_texts=[clean_name],
            n_results=5
        )

        if results and results['metadatas'] and results['metadatas'][0]:
            # Look for exact match in results
            for metadata in results['metadatas'][0]:
                if metadata.get('name', '').lower() == clean_name.lower():
                    return {
                        "valid": True,
                        "name": metadata.get('name'),
                        "subscribers": metadata.get('subscribers', 0),
                        "is_private": False,  # We only index public subreddits
                        "over_18": metadata.get('nsfw', False),
                        "indexed": True
                    }

        return {
            "valid": False,
            "name": clean_name,
            "error": f"Subreddit '{clean_name}' not found",
            "suggestion": "Use discover_subreddits to find similar communities"
        }

    except Exception as e:
        return {
            "valid": False,
            "name": clean_name,
            "error": f"Database error: {str(e)}",
            "suggestion": "Check database connection and retry"
        }

Latest Blog Posts

MCP directory API

We provide all the information about MCP servers via our MCP API.

curl -X GET 'https://glama.ai/api/mcp/v1/servers/king-of-the-grackles/dialog-reddit-tools'

If you have feedback or need assistance with the MCP directory API, please join our Discord server