#!/usr/bin/env python3
"""Enhanced Search Capabilities for Session Management MCP Server.
Provides multi-modal search including code snippets, error patterns, and time-based queries.
"""
import ast
import contextlib
import operator
from datetime import datetime, timedelta
from typing import TYPE_CHECKING, Any, cast

if TYPE_CHECKING:
    from dateutil.parser import parse as parse_date
    from dateutil.relativedelta import relativedelta

try:
    from dateutil.parser import parse as parse_date
    from dateutil.relativedelta import relativedelta

    DATEUTIL_AVAILABLE = True
except ImportError:
    DATEUTIL_AVAILABLE = False

from .reflection_tools import ReflectionDatabase
from .session_types import TimeRange
from .utils.regex_patterns import SAFE_PATTERNS
class CodeSearcher:
"""AST-based code search for Python code snippets."""
def __init__(self) -> None:
self.search_types: dict[str, type[ast.AST] | tuple[type[ast.AST], ...]] = {
"function": ast.FunctionDef,
"class": ast.ClassDef,
"import": (ast.Import, ast.ImportFrom),
"assignment": ast.Assign,
"call": ast.Call,
"loop": (ast.For, ast.While),
"conditional": ast.If,
"try": ast.Try,
"async": (ast.AsyncFunctionDef, ast.AsyncWith, ast.AsyncFor),
}
def _extract_pattern_info(
self,
node: ast.AST,
pattern_type: str,
code: str,
block_index: int,
) -> dict[str, Any]:
"""Extract pattern information from AST node."""
        pattern_info: dict[str, Any] = {
"type": pattern_type,
"content": code,
"block_index": block_index,
"line_number": getattr(node, "lineno", 0),
}
# Extract specific information based on node type
        if isinstance(node, ast.FunctionDef | ast.AsyncFunctionDef):
            pattern_info["name"] = node.name
            pattern_info["args"] = [arg.arg for arg in node.args.args]
elif isinstance(node, ast.ClassDef):
pattern_info["name"] = node.name
        elif isinstance(node, ast.Import):
            pattern_info["modules"] = [alias.name for alias in node.names]
        elif isinstance(node, ast.ImportFrom):
            pattern_info["module"] = node.module
            pattern_info["names"] = [alias.name for alias in node.names]
return pattern_info
def _process_code_block(self, code: str, block_index: int) -> list[dict[str, Any]]:
"""Process a single code block and extract patterns."""
patterns = []
        # Blocks that are not valid Python are skipped silently
        with contextlib.suppress(SyntaxError, ValueError):
            tree = ast.parse(code)
for node in ast.walk(tree):
for pattern_type, node_types in self.search_types.items():
# Handle both single classes and tuples of classes
type_check = (
node_types if isinstance(node_types, tuple) else (node_types,)
)
if isinstance(node, type_check):
pattern_info = self._extract_pattern_info(
node,
pattern_type,
code,
block_index,
)
patterns.append(pattern_info)
return patterns
def extract_code_patterns(self, content: str) -> list[dict[str, Any]]:
"""Extract code patterns from conversation content."""
patterns = []
# Extract Python code blocks using validated patterns
python_code_blocks = SAFE_PATTERNS["python_code_block"].findall(content)
generic_code_blocks = SAFE_PATTERNS["generic_code_block"].findall(content)
code_blocks = python_code_blocks + generic_code_blocks
for i, code in enumerate(code_blocks):
block_patterns = self._process_code_block(code, i)
patterns.extend(block_patterns)
return patterns
class ErrorPatternMatcher:
"""Pattern matching for error messages and debugging contexts."""
def __init__(self) -> None:
# Map pattern names to our validated patterns
self.error_patterns = {
"python_traceback": "python_traceback",
"python_exception": "python_exception",
"javascript_error": "javascript_error",
"compile_error": "compile_error",
"warning": "warning_pattern",
"assertion": "assertion_error",
"import_error": "import_error",
"module_not_found": "module_not_found",
"file_not_found": "file_not_found",
"permission_denied": "permission_denied",
"network_error": "network_error",
}
# Map context pattern names to our validated patterns
self.context_patterns = {
"debugging": "debugging_context",
"testing": "testing_context",
"error_handling": "error_handling_context",
"performance": "performance_context",
"security": "security_context",
}
def extract_error_patterns(self, content: str) -> list[dict[str, Any]]:
"""Extract error patterns and debugging context from content."""
patterns = []
# Find error patterns using validated patterns
for pattern_name, safe_pattern_key in self.error_patterns.items():
safe_pattern = SAFE_PATTERNS[safe_pattern_key]
            # search() records only the first occurrence of each pattern, with position info
match = safe_pattern.search(content)
if match:
patterns.append(
{
"type": "error",
"subtype": pattern_name,
"content": match.group(0),
"start": match.start(),
"end": match.end(),
"groups": match.groups() or [],
},
)
# Find context patterns using validated patterns
for context_name, safe_pattern_key in self.context_patterns.items():
safe_pattern = SAFE_PATTERNS[safe_pattern_key]
if safe_pattern.test(content):
patterns.append(
{
"type": "context",
"subtype": context_name,
"content": content,
"relevance": "high"
if context_name in {"debugging", "error_handling"}
else "medium",
},
)
return patterns
class TemporalSearchParser:
"""Parse natural language time expressions for conversation search."""
def __init__(self) -> None:
        # Deltas are interpreted by _parse_relative_patterns: "this X" spans
        # [now - delta, now] and "last X" spans [now - 2*delta, now - delta].
        # "today" and "yesterday" are handled there as calendar days.
        self.relative_patterns = {
            "this week": timedelta(weeks=1),
            "last week": timedelta(weeks=1),
            "this month": relativedelta(months=1)
            if DATEUTIL_AVAILABLE
            else timedelta(days=30),
            "last month": relativedelta(months=1)
            if DATEUTIL_AVAILABLE
            else timedelta(days=30),
            "this year": relativedelta(years=1)
            if DATEUTIL_AVAILABLE
            else timedelta(days=365),
        }
# Map to validated time parsing patterns
self.time_patterns = [
"time_ago_pattern",
"relative_time_pattern",
"since_time_pattern",
"last_duration_pattern",
"iso_date_pattern",
"us_date_pattern",
]
def _calculate_delta(self, amount: int, unit: str) -> timedelta:
"""Calculate timedelta from amount and unit."""
if unit == "minute":
return timedelta(minutes=amount)
if unit == "hour":
return timedelta(hours=amount)
if unit == "day":
return timedelta(days=amount)
if unit == "week":
return timedelta(weeks=amount)
if unit == "month":
# Always use timedelta approximation for type safety
return timedelta(days=amount * 30)
if unit == "year":
# Always use timedelta approximation for type safety
return timedelta(days=amount * 365)
return timedelta()
    def _parse_relative_patterns(
        self,
        expression: str,
        now: datetime,
    ) -> TimeRange:
        """Parse relative expressions such as "today" or "last week"."""
        start_of_day = now.replace(hour=0, minute=0, second=0, microsecond=0)
        if "yesterday" in expression:
            yesterday = start_of_day - timedelta(days=1)
            return TimeRange(start=yesterday, end=start_of_day)
        if "today" in expression:
            return TimeRange(start=start_of_day, end=now)
        for pattern, delta in self.relative_patterns.items():
            if pattern in expression:
                if pattern.startswith("last"):
                    end_time = now - delta
                    start_time = end_time - delta
                else:
                    start_time = now - delta
                    end_time = now
                return TimeRange(start=start_time, end=end_time)
        return TimeRange()
def _parse_ago_pattern(
self,
expression: str,
now: datetime,
) -> TimeRange:
"""Parse 'X time units ago' pattern."""
match = SAFE_PATTERNS["time_ago_pattern"].search(expression)
if match:
amount = int(match.group(1))
unit = match.group(2)
delta = self._calculate_delta(amount, unit)
            start_time = now - delta
            return TimeRange(start=start_time, end=now)
return TimeRange()
def _parse_last_pattern(
self,
expression: str,
now: datetime,
) -> TimeRange:
"""Parse 'in the last X units' pattern."""
match = SAFE_PATTERNS["last_duration_pattern"].search(expression)
if match:
amount = int(match.group(1))
unit = match.group(2)
delta = self._calculate_delta(amount, unit)
start_time = now - delta
return TimeRange(start=start_time, end=now)
return TimeRange()
def _parse_absolute_date(
self,
expression: str,
) -> TimeRange:
"""Parse absolute date expressions."""
if not DATEUTIL_AVAILABLE:
return TimeRange()
        with contextlib.suppress(ValueError, TypeError):
parsed_date = parse_date(expression)
# Ensure parsed_date is a datetime object
if isinstance(parsed_date, datetime):
# Return day range (start of day to end of day)
start_time = parsed_date.replace(
hour=0,
minute=0,
second=0,
microsecond=0,
)
end_time = start_time + timedelta(days=1)
return TimeRange(start=start_time, end=end_time)
return TimeRange()
def parse_time_expression(
self,
expression: str,
) -> TimeRange:
"""Parse time expression into start and end datetime."""
expression = expression.lower().strip()
now = datetime.now()
# Try different parsing strategies
parsers = [
self._parse_relative_patterns,
self._parse_ago_pattern,
self._parse_last_pattern,
            lambda expr, _now: self._parse_absolute_date(expr),
]
for parser in parsers:
result = parser(expression, now)
if result.start is not None or result.end is not None:
return result
return TimeRange()
class EnhancedSearchEngine:
"""Main search engine that combines all enhanced search capabilities."""
def __init__(self, reflection_db: ReflectionDatabase) -> None:
self.reflection_db = reflection_db
self.code_searcher = CodeSearcher()
self.error_matcher = ErrorPatternMatcher()
self.temporal_parser = TemporalSearchParser()
async def search_code_patterns(
self,
query: str,
pattern_type: str | None = None,
limit: int = 10,
) -> list[dict[str, Any]]:
"""Search for code patterns in conversations."""
conversations = self._get_all_conversations()
if not conversations:
return []
results = []
for conv in conversations:
conv_results = self._process_conversation_for_code_patterns(
conv,
query,
pattern_type,
)
results.extend(conv_results)
return self._sort_and_limit_results(results, limit)
def _get_all_conversations(self) -> list[tuple[str, str, str, str, str]]:
"""Get all conversations from database."""
if not hasattr(self.reflection_db, "conn") or not self.reflection_db.conn:
return []
cursor = self.reflection_db.conn.execute(
"SELECT id, content, project, timestamp, metadata FROM conversations",
)
return cast("list[tuple[str, str, str, str, str]]", cursor.fetchall())
def _process_conversation_for_code_patterns(
self,
conv: tuple[str, str, str, str, str],
query: str,
pattern_type: str | None,
) -> list[dict[str, Any]]:
"""Process a single conversation for code patterns."""
conv_id, content, project, timestamp, _metadata = conv
patterns = self.code_searcher.extract_code_patterns(content)
results = []
for pattern in patterns:
if pattern_type and pattern["type"] != pattern_type:
continue
relevance = self._calculate_code_relevance(pattern, query)
if relevance > 0.3: # Threshold for relevance
results.append(
{
"conversation_id": conv_id,
"project": project,
"timestamp": timestamp,
"pattern": pattern,
"relevance": relevance,
"snippet": content[:500] + "..."
if len(content) > 500
else content,
},
)
return results
def _sort_and_limit_results(
self,
results: list[dict[str, Any]],
limit: int,
) -> list[dict[str, Any]]:
"""Sort results by relevance and limit."""
results.sort(key=operator.itemgetter("relevance"), reverse=True)
return results[:limit]
async def search_error_patterns(
self,
query: str,
error_type: str | None = None,
limit: int = 10,
) -> list[dict[str, Any]]:
"""Search for error patterns and debugging contexts."""
conversations = self._get_all_conversations()
if not conversations:
return []
results = []
for conv in conversations:
conv_results = self._process_conversation_for_error_patterns(
conv,
query,
error_type,
)
results.extend(conv_results)
return self._sort_and_limit_results(results, limit)
def _process_conversation_for_error_patterns(
self,
conv: tuple[str, str, str, str, str],
query: str,
error_type: str | None,
) -> list[dict[str, Any]]:
"""Process a single conversation for error patterns."""
conv_id, content, project, timestamp, _metadata = conv
patterns = self.error_matcher.extract_error_patterns(content)
results = []
for pattern in patterns:
if error_type and pattern["subtype"] != error_type:
continue
relevance = self._calculate_error_relevance(pattern, query)
if relevance > 0.2: # Lower threshold for errors
results.append(
{
"conversation_id": conv_id,
"project": project,
"timestamp": timestamp,
"pattern": pattern,
"relevance": relevance,
"snippet": content[:500] + "..."
if len(content) > 500
else content,
},
)
return results
async def search_temporal(
self,
time_expression: str,
query: str | None = None,
limit: int = 10,
) -> list[dict[str, Any]]:
"""Search conversations within a time range."""
time_range = self.temporal_parser.parse_time_expression(
time_expression,
)
if not time_range.start or not time_range.end:
return [{"error": f"Could not parse time expression: {time_expression}"}]
start_time = time_range.start
end_time = time_range.end
        if not hasattr(self.reflection_db, "conn") or not self.reflection_db.conn:
            return []
        # Convert to ISO format for lexicographic comparison in SQLite
        start_iso = start_time.isoformat()
        end_iso = end_time.isoformat()
        sql_query = """
            SELECT id, content, project, timestamp, metadata
            FROM conversations
            WHERE timestamp BETWEEN ? AND ?
            ORDER BY timestamp DESC
        """
        cursor = self.reflection_db.conn.execute(sql_query, (start_iso, end_iso))
        results = []
        for conv in cursor.fetchall():
            conv_id, content, project, timestamp, _metadata = conv
            # If a query is provided, filter by content relevance
            if query:
                relevance = self._calculate_text_relevance(content, query)
                if relevance < 0.3:
                    continue
            else:
                relevance = 1.0
            results.append(
                {
                    "conversation_id": conv_id,
                    "project": project,
                    "timestamp": timestamp,
                    "content": content[:500] + "..."
                    if len(content) > 500
                    else content,
                    "relevance": relevance,
                },
            )
        return results[:limit]
def _calculate_code_relevance(self, pattern: dict[str, Any], query: str) -> float:
"""Calculate relevance score for code patterns."""
relevance = 0.0
query_lower = query.lower()
# Type matching
if pattern["type"] in query_lower:
relevance += 0.5
# Name matching (for functions/classes)
if "name" in pattern and pattern["name"].lower() in query_lower:
relevance += 0.7
# Content matching
if query_lower in pattern["content"].lower():
relevance += 0.4
# Module/import matching
if "modules" in pattern:
for module in pattern["modules"]:
if module.lower() in query_lower:
relevance += 0.3
return min(relevance, 1.0)
def _calculate_error_relevance(self, pattern: dict[str, Any], query: str) -> float:
"""Calculate relevance score for error patterns."""
relevance = 0.0
query_lower = query.lower()
# Error type matching
if pattern["subtype"] in query_lower:
relevance += 0.6
# Content matching
if "content" in pattern and query_lower in pattern["content"].lower():
relevance += 0.5
# Context relevance boost
if pattern["type"] == "context" and pattern.get("relevance") == "high":
relevance += 0.3
return min(relevance, 1.0)
def _calculate_text_relevance(self, content: str, query: str) -> float:
"""Simple text relevance calculation."""
        query_words = query.lower().split()
        # A set makes per-word membership checks O(1)
        content_words = set(content.lower().split())
        matches = sum(1 for word in query_words if word in content_words)
        return matches / len(query_words) if query_words else 0.0