#!/usr/bin/env python3
import os
import json
import asyncio
import hashlib
import sqlite3
import sys
from pathlib import Path
from typing import Dict, List, Optional, Any
from dataclasses import dataclass
import fnmatch
import logging
import re
from datetime import datetime
# Configure logging first
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
handlers=[logging.StreamHandler(sys.stderr)]
)
logger = logging.getLogger("code-context-mcp")
REDIS_DB_NUMBER = 0
REDIS_URL_DEFAULT = f"redis://192.168.0.200:6378/{REDIS_DB_NUMBER}"
# Import necessary libraries
try:
import numpy as np
from sentence_transformers import SentenceTransformer
import redis
from redis.commands.search.field import VectorField, TextField
from redis.commands.search.indexDefinition import IndexDefinition, IndexType
from redis.commands.search.query import Query
except ImportError as e:
logger.error(f"Missing required library: {e}")
sys.exit(1)
# Optional imports
try:
import esprima
except ImportError:
esprima = None
logger.warning("Optional library missing: esprima (JavaScript/TypeScript support)")
try:
import sqlparse
except ImportError:
sqlparse = None
logger.warning("Optional library missing: sqlparse (SQL support)")
# Import MCP components
from mcp.server.models import InitializationOptions
from mcp.server import NotificationOptions, Server
from mcp.types import Tool, TextContent
import mcp.server.stdio
@dataclass
class CodeEntity:
"""Represents a code entity (function, class, variable, etc.)"""
name: str
type: str
file_path: str
start_line: int
end_line: int
content: str
signature: str
docstring: Optional[str] = None
dependencies: Optional[List[str]] = None
embedding: Optional[List[float]] = None
def __post_init__(self):
if self.dependencies is None:
self.dependencies = []
@dataclass
class FileContext:
"""Represents file-level context"""
file_path: str
language: str
size: int
last_modified: float
imports: List[str]
exports: List[str]
entities: List[CodeEntity]
summary: str
embedding: Optional[List[float]] = None
file_hash: str = ""
# Code parser: Python via ast, JavaScript/TypeScript via esprima, SQL via sqlparse
class CodeParser:
def detect_language(self, file_path: str) -> Optional[str]:
ext = Path(file_path).suffix.lower()
return {
'.py': 'python',
'.js': 'javascript',
'.ts': 'typescript',
'.sql': 'sql'
}.get(ext)
def parse_file(self, file_path: str, content: str) -> Optional[FileContext]:
language = self.detect_language(file_path)
if not language:
return None
file_stats = os.stat(file_path)
file_hash = hashlib.md5(content.encode('utf-8')).hexdigest()
if language == 'python':
return self._parse_python_file(file_path, content, file_stats, file_hash)
elif language in ['javascript', 'typescript']:
if esprima is None:
return FileContext(
file_path=file_path, language=language, size=file_stats.st_size,
last_modified=file_stats.st_mtime, imports=[], exports=[],
entities=[], summary=f"File in {language} (parser not available)", file_hash=file_hash
)
return self._parse_javascript_file(file_path, content, file_stats, file_hash)
elif language == 'sql':
if sqlparse is None:
return FileContext(
file_path=file_path, language=language, size=file_stats.st_size,
last_modified=file_stats.st_mtime, imports=[], exports=[],
entities=[], summary=f"File in {language} (parser not available)", file_hash=file_hash
)
return self._parse_sql_file(file_path, content, file_stats, file_hash)
        else:
            # Defensive fallback for unmapped languages (unreachable with the current extension map).
return FileContext(
file_path=file_path, language=language, size=file_stats.st_size,
last_modified=file_stats.st_mtime, imports=[], exports=[],
entities=[], summary=f"File in {language}", file_hash=file_hash
)
def _parse_python_file(self, file_path: str, content: str, file_stats: os.stat_result, file_hash: str) -> FileContext:
import ast
try:
tree = ast.parse(content, filename=file_path)
except SyntaxError:
return FileContext(
file_path=file_path, language='python', size=file_stats.st_size,
last_modified=file_stats.st_mtime, imports=[], exports=[],
entities=[], summary="Syntax error in Python file", file_hash=file_hash
)
entities = []
imports = []
exports = []
class CodeVisitor(ast.NodeVisitor):
def __init__(self):
self.entities = []
self.imports = []
self.exports = []
def visit_FunctionDef(self, node):
entity = CodeEntity(
name=node.name,
type='function',
file_path=file_path,
start_line=node.lineno,
end_line=getattr(node, 'end_lineno', node.lineno),
signature=self._get_function_signature(node),
content=ast.get_source_segment(content, node) or '',
docstring=ast.get_docstring(node) or '',
dependencies=self._extract_dependencies(node)
)
                self.entities.append(entity)
            # Async defs would otherwise be skipped entirely.
            visit_AsyncFunctionDef = visit_FunctionDef
def visit_ClassDef(self, node):
entity = CodeEntity(
name=node.name,
type='class',
file_path=file_path,
start_line=node.lineno,
end_line=getattr(node, 'end_lineno', node.lineno),
signature=f"class {node.name}",
content=ast.get_source_segment(content, node) or '',
docstring=ast.get_docstring(node) or '',
dependencies=[]
)
                self.entities.append(entity)
                # Recurse so methods and nested definitions are indexed too.
                self.generic_visit(node)
def visit_Import(self, node):
for alias in node.names:
self.imports.append(alias.name)
def visit_ImportFrom(self, node):
module = node.module or ''
for alias in node.names:
if module:
self.imports.append(f"{module}.{alias.name}")
else:
self.imports.append(alias.name)
            def _get_function_signature(self, node):
                args = [arg.arg for arg in node.args.args]
                if node.args.vararg:
                    args.append(f"*{node.args.vararg.arg}")
                if node.args.kwarg:
                    args.append(f"**{node.args.kwarg.arg}")
                params = ', '.join(args)
                prefix = 'async def' if isinstance(node, ast.AsyncFunctionDef) else 'def'
                return f"{prefix} {node.name}({params})"
def _extract_dependencies(self, node):
deps = set()
for child in ast.walk(node):
if isinstance(child, ast.Name) and isinstance(child.ctx, ast.Load):
deps.add(child.id)
return list(deps)
visitor = CodeVisitor()
visitor.visit(tree)
# Generate summary
summary = f"Python file with {len(visitor.entities)} entities and {len(visitor.imports)} imports"
return FileContext(
file_path=file_path,
language='python',
size=file_stats.st_size,
last_modified=file_stats.st_mtime,
imports=visitor.imports,
exports=visitor.exports,
entities=visitor.entities,
summary=summary,
file_hash=file_hash
)
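    # Example (illustrative): running the Python parser standalone; the file
    # name below is hypothetical:
    #   parser = CodeParser()
    #   src = Path("example.py").read_text(encoding="utf-8")
    #   ctx = parser.parse_file("example.py", src)
    #   if ctx:
    #       for ent in ctx.entities:
    #           print(ent.type, ent.name, ent.signature)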
    def _parse_javascript_file(self, file_path: str, content: str, file_stats: os.stat_result, file_hash: str) -> FileContext:
        language = 'typescript' if file_path.endswith('.ts') else 'javascript'
        options = {'tolerant': True, 'loc': True, 'range': True}
        try:
            # Parse as a module first so import/export syntax is accepted, then
            # fall back to a classic script. Note that esprima understands
            # ECMAScript only; TypeScript-specific syntax (type annotations,
            # interfaces) will still fail to parse.
            try:
                tree = esprima.parseModule(content, options=options)
            except Exception:
                tree = esprima.parseScript(content, options=options)
        except Exception as e:
            return FileContext(
                file_path=file_path, language=language, size=file_stats.st_size,
                last_modified=file_stats.st_mtime, imports=[], exports=[],
                entities=[], summary=f"Parse error in {language} file: {e}", file_hash=file_hash
            )
entities = []
imports = []
exports = []
class JSVisitor:
def __init__(self):
self.entities = []
self.imports = []
self.exports = []
def visit(self, node):
if hasattr(node, 'type'):
if node.type == 'FunctionDeclaration':
self.visit_FunctionDeclaration(node)
elif node.type == 'ClassDeclaration':
self.visit_ClassDeclaration(node)
elif node.type == 'ImportDeclaration':
self.visit_ImportDeclaration(node)
elif node.type == 'ExportNamedDeclaration':
self.visit_ExportNamedDeclaration(node)
elif node.type == 'ExportDefaultDeclaration':
self.visit_ExportDefaultDeclaration(node)
# Recurse on children
for key, value in node.__dict__.items():
if isinstance(value, list):
for item in value:
if hasattr(item, 'type'):
self.visit(item)
elif hasattr(value, 'type'):
self.visit(value)
def visit_FunctionDeclaration(self, node):
name = node.id.name if node.id else 'anonymous'
params = ', '.join(p.name for p in node.params if hasattr(p, 'name'))
signature = f"function {name}({params})"
start_line = node.loc.start.line if node.loc else 0
end_line = node.loc.end.line if node.loc else start_line
content_snippet = content[node.range[0]:node.range[1]] if node.range else ''
entity = CodeEntity(
name=name,
type='function',
file_path=file_path,
start_line=start_line,
end_line=end_line,
signature=signature,
content=content_snippet,
docstring='',
dependencies=[]
)
self.entities.append(entity)
def visit_ClassDeclaration(self, node):
name = node.id.name
signature = f"class {name}"
start_line = node.loc.start.line if node.loc else 0
end_line = node.loc.end.line if node.loc else start_line
content_snippet = content[node.range[0]:node.range[1]] if node.range else ''
entity = CodeEntity(
name=name,
type='class',
file_path=file_path,
start_line=start_line,
end_line=end_line,
signature=signature,
content=content_snippet,
docstring='',
dependencies=[]
)
self.entities.append(entity)
            def visit_ImportDeclaration(self, node):
                # Record the module source plus named/default/namespace bindings.
                if getattr(node, 'source', None) is not None and getattr(node.source, 'value', None):
                    self.imports.append(node.source.value)
                for spec in node.specifiers:
                    if spec.type == 'ImportSpecifier':
                        self.imports.append(spec.imported.name)
                    elif spec.type in ('ImportDefaultSpecifier', 'ImportNamespaceSpecifier'):
                        self.imports.append(spec.local.name)
def visit_ExportNamedDeclaration(self, node):
if node.declaration:
self.visit(node.declaration)
def visit_ExportDefaultDeclaration(self, node):
if node.declaration:
self.visit(node.declaration)
visitor = JSVisitor()
visitor.visit(tree)
summary = f"{language.capitalize()} file with {len(visitor.entities)} entities and {len(visitor.imports)} imports"
return FileContext(
file_path=file_path,
language=language,
size=file_stats.st_size,
last_modified=file_stats.st_mtime,
imports=visitor.imports,
exports=visitor.exports,
entities=visitor.entities,
summary=summary,
file_hash=file_hash
)
def _parse_sql_file(self, file_path: str, content: str, file_stats: os.stat_result, file_hash: str) -> FileContext:
"""Parse SQL file and extract entities like tables, views, functions, etc."""
try:
# Parse the SQL content
parsed = sqlparse.parse(content)
except Exception as e:
return FileContext(
file_path=file_path, language='sql', size=file_stats.st_size,
last_modified=file_stats.st_mtime, imports=[], exports=[],
entities=[], summary=f"SQL parse error: {e}", file_hash=file_hash
)
entities = []
imports = [] # SQL doesn't have imports in the same way
exports = [] # SQL doesn't have exports
        for statement in parsed:
            stmt_content = str(statement).strip()
            stmt_type = statement.get_type()
            # sqlparse returns 'UNKNOWN' (never None) when it cannot classify a statement.
            if not stmt_content or stmt_type == 'UNKNOWN':
                continue
# Extract different SQL entities
if stmt_type in ['CREATE', 'CREATE OR REPLACE']:
entity = self._extract_sql_create_entity(statement, stmt_content, file_path)
if entity:
entities.append(entity)
elif stmt_type == 'SELECT':
# For complex SELECT statements, we might want to index them as queries
entity = CodeEntity(
name=f"query_{len(entities)}",
type='query',
file_path=file_path,
start_line=getattr(statement, 'start_line', 0),
end_line=getattr(statement, 'end_line', getattr(statement, 'start_line', 0)),
signature="SELECT query",
content=stmt_content,
docstring='',
dependencies=[]
)
entities.append(entity)
summary = f"SQL file with {len(entities)} entities"
return FileContext(
file_path=file_path,
language='sql',
size=file_stats.st_size,
last_modified=file_stats.st_mtime,
imports=imports,
exports=exports,
entities=entities,
summary=summary,
file_hash=file_hash
)
def _extract_sql_create_entity(self, statement, content: str, file_path: str) -> Optional[CodeEntity]:
"""Extract entity information from CREATE statements."""
content_upper = content.upper()
# Find the object type and name
if 'CREATE TABLE' in content_upper:
# Extract table name
match = re.search(r'CREATE\s+TABLE\s+(?:IF\s+NOT\s+EXISTS\s+)?["`\'"]?(\w+)["`\'"]?', content, re.IGNORECASE)
if match:
table_name = match.group(1)
return CodeEntity(
name=table_name,
type='table',
file_path=file_path,
start_line=getattr(statement, 'start_line', 0),
end_line=getattr(statement, 'end_line', getattr(statement, 'start_line', 0)),
signature=f"CREATE TABLE {table_name}",
content=content,
docstring='',
dependencies=[]
)
elif 'CREATE VIEW' in content_upper:
match = re.search(r'CREATE\s+(?:OR\s+REPLACE\s+)?VIEW\s+["`\'"]?(\w+)["`\'"]?', content, re.IGNORECASE)
if match:
view_name = match.group(1)
return CodeEntity(
name=view_name,
type='view',
file_path=file_path,
start_line=getattr(statement, 'start_line', 0),
end_line=getattr(statement, 'end_line', getattr(statement, 'start_line', 0)),
signature=f"CREATE VIEW {view_name}",
content=content,
docstring='',
dependencies=[]
)
elif 'CREATE FUNCTION' in content_upper or 'CREATE OR REPLACE FUNCTION' in content_upper:
match = re.search(r'CREATE\s+(?:OR\s+REPLACE\s+)?FUNCTION\s+["`\'"]?(\w+)\s*\(', content, re.IGNORECASE)
if match:
func_name = match.group(1)
return CodeEntity(
name=func_name,
type='function',
file_path=file_path,
start_line=getattr(statement, 'start_line', 0),
end_line=getattr(statement, 'end_line', getattr(statement, 'start_line', 0)),
signature=f"CREATE FUNCTION {func_name}",
content=content,
docstring='',
dependencies=[]
)
elif 'CREATE PROCEDURE' in content_upper:
match = re.search(r'CREATE\s+(?:OR\s+REPLACE\s+)?PROCEDURE\s+["`\'"]?(\w+)["`\'"]?', content, re.IGNORECASE)
if match:
proc_name = match.group(1)
return CodeEntity(
name=proc_name,
type='procedure',
file_path=file_path,
start_line=getattr(statement, 'start_line', 0),
end_line=getattr(statement, 'end_line', getattr(statement, 'start_line', 0)),
signature=f"CREATE PROCEDURE {proc_name}",
content=content,
docstring='',
dependencies=[]
)
return None
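    # Example (illustrative): what the SQL branch extracts; the statement below
    # is hypothetical:
    #   "CREATE TABLE IF NOT EXISTS users (id INTEGER PRIMARY KEY);"
    #   -> CodeEntity(name="users", type="table",
    #                 signature="CREATE TABLE users", ...)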
# --- VectorStore: RediSearch KNN vector store ---
class VectorStore:
EMBEDDING_DIM = 384
INDEX_NAME = 'code_index'
def __init__(self, redis_url: str = REDIS_URL_DEFAULT):
try:
self.redis_client = redis.from_url(redis_url, decode_responses=False)
self.redis_client.ping()
logger.info("Connected to Redis successfully")
except Exception as e:
logger.error(f"Failed to connect to Redis: {e}")
raise
        try:
            # all-MiniLM-L6-v2 is a small, general-purpose sentence embedder;
            # its dimension (384) drives the RediSearch vector field below.
            self.model = SentenceTransformer('all-MiniLM-L6-v2')
            self.EMBEDDING_DIM = self.model.get_sentence_embedding_dimension()
        except Exception as e:
            # There is no meaningful fallback model, so fail fast instead of
            # retrying the same load.
            logger.error(f"Failed to load embedding model: {e}")
            raise
self._create_search_index()
    def _create_search_index(self):
        """Create the RediSearch index for files and entities if it does not already exist."""
        try:
            # Reuse an existing index so previously indexed data survives restarts.
            self.redis_client.ft(self.INDEX_NAME).info()
            logger.info(f"RediSearch index '{self.INDEX_NAME}' already exists; reusing it.")
            return
        except redis.ResponseError:
            # Index does not exist yet; create it below.
            pass
schema = (
TextField(name="file_path", sortable=True),
TextField(name="name", sortable=True, no_stem=True),
TextField(name="type", sortable=True),
VectorField(
name="embedding",
algorithm="HNSW",
attributes={
"TYPE": "FLOAT32",
"DIM": self.EMBEDDING_DIM,
"DISTANCE_METRIC": "COSINE"
}
)
)
definition = IndexDefinition(
prefix=['file:', 'entity:'],
index_type=IndexType.HASH
)
self.redis_client.ft(self.INDEX_NAME).create_index(schema, definition=definition)
logger.info(f"Created RediSearch index '{self.INDEX_NAME}' successfully (DIM: {self.EMBEDDING_DIM}).")
def generate_embedding(self, text: str) -> List[float]:
embedding = self.model.encode([text], convert_to_numpy=True)[0]
return embedding.tolist()
def store_file_context(self, context: FileContext):
"""Store file context with embeddings in Redis"""
if not context.embedding:
context.embedding = self.generate_embedding(context.summary)
# Convert embedding to binary format (FLOAT32)
embedding_bytes = np.array(context.embedding, dtype=np.float32).tobytes()
# Store file context
key = f"file:{context.file_hash}"
self.redis_client.hset(key, mapping={
"file_path": context.file_path,
"summary": context.summary,
"embedding": embedding_bytes,
"type": "file",
# Add other fields needed for retrieval/context
"language": context.language,
})
# Store entities
for entity in context.entities:
if not entity.embedding:
entity_text = f"{entity.signature}\n{entity.docstring or ''}\n{entity.content[:500]}"
entity.embedding = self.generate_embedding(entity_text)
entity_embedding_bytes = np.array(entity.embedding, dtype=np.float32).tobytes()
entity_key = f"entity:{context.file_hash}:{entity.name}:{entity.start_line}"
self.redis_client.hset(entity_key, mapping={
"name": entity.name,
"type": entity.type,
"file_path": entity.file_path,
"start_line": str(entity.start_line),
"signature": entity.signature,
"content": entity.content,
"embedding": entity_embedding_bytes,
"dependencies": json.dumps(entity.dependencies) # Store dependencies as JSON string
})
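    # Resulting Redis key layout (for reference):
    #   file:<md5-of-file-content>                       -> HASH (summary, embedding, ...)
    #   entity:<md5-of-file-content>:<name>:<start-line> -> HASH (signature, content, embedding, ...)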
def similarity_search(self, query: str, top_k: int = 10, search_type: str = "both") -> List[Dict]:
"""Search for similar code entities or files using the RediSearch Query Engine (KNN)."""
query_embedding = self.generate_embedding(query)
query_vector_bytes = np.array(query_embedding, dtype=np.float32).tobytes()
filter_expr = ""
if search_type == "entities":
filter_expr = "(@type:function | @type:class)"
elif search_type == "files":
filter_expr = "(@type:file)"
# NOTE: Returning all relevant fields (content, signature, name, etc.)
if filter_expr:
knn_query = f"{filter_expr} =>[KNN {top_k} @embedding $vec AS score]"
else:
knn_query = f"* =>[KNN {top_k} @embedding $vec AS score]"
try:
            results = self.redis_client.ft(self.INDEX_NAME).search(
                Query(knn_query)
                .return_fields('file_path', 'name', 'type', 'signature', 'content', 'start_line', 'summary', 'score')
                # 'score' holds the cosine *distance* (lower = closer), so sort ascending.
                .sort_by('score', asc=True)
                .dialect(2),
                query_params={'vec': query_vector_bytes}
            )
except Exception as e:
logger.error(f"RediSearch error during similarity search: {e}")
return []
        final_results = []
        # Decode helper hoisted out of the loop (fields may arrive as bytes
        # because the client uses decode_responses=False).
        def decode_field(field):
            if field is None:
                return None
            if isinstance(field, bytes):
                return field.decode('utf-8')
            return str(field)
        for doc in results.docs:
            doc_dict = {
                'type': decode_field(doc.type) or 'unknown',
                # Convert cosine distance to a similarity score (higher = more similar).
                'similarity': 1.0 - float(doc.score),
'file_path': decode_field(doc.file_path) or 'N/A',
'name': decode_field(getattr(doc, 'name', None)),
'signature': decode_field(getattr(doc, 'signature', None)),
'content': decode_field(getattr(doc, 'content', None)),
'summary': decode_field(getattr(doc, 'summary', None)),
'start_line': int(decode_field(getattr(doc, 'start_line', None))) if getattr(doc, 'start_line', None) else 0,
}
final_results.append({k: v for k, v in doc_dict.items() if v is not None})
return final_results
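# Example (illustrative): an ad-hoc similarity query, assuming a reachable
# Redis Stack instance and previously indexed data:
#   store = VectorStore()
#   for hit in store.similarity_search("create the redis vector index", top_k=3):
#       print(f"{hit['similarity']:.3f}  {hit['file_path']}  {hit.get('name', '')}")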
# --- CodeContextManager Class ---
class CodeContextManager:
"""Main class for managing code context and indexing"""
def __init__(self, redis_url: str = REDIS_URL_DEFAULT):
self.parser = CodeParser()
self.vector_store = VectorStore(redis_url)
self.db_path = "code_context.db"
self._init_database()
logger.info("Code Context Manager initialized successfully")
def _init_database(self):
"""Initialize SQLite database for metadata"""
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()
cursor.execute('''
CREATE TABLE IF NOT EXISTS indexed_files (
file_path TEXT PRIMARY KEY,
file_hash TEXT,
last_indexed TIMESTAMP,
last_modified REAL,
entity_count INTEGER,
                imports TEXT -- JSON-encoded list of imports
)
''')
conn.commit()
conn.close()
    def index_file(self, file_path: str) -> Optional[FileContext]:
        """Index a single file and store its context and embeddings."""
        try:
            with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
                content = f.read()
            context = self.parser.parse_file(file_path, content)
            if context:
                # Purge a previously indexed version first: Redis keys are derived
                # from the content hash, so edits would otherwise leave stale vectors.
                conn = sqlite3.connect(self.db_path)
                old = conn.execute('SELECT file_hash FROM indexed_files WHERE file_path = ?', (file_path,)).fetchone()
                conn.close()
                if old and old[0] != context.file_hash:
                    rc = self.vector_store.redis_client
                    rc.delete(f"file:{old[0]}")
                    for key in rc.scan_iter(f"entity:{old[0]}:*"):
                        rc.delete(key)
                self.vector_store.store_file_context(context)
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()
cursor.execute('''
INSERT OR REPLACE INTO indexed_files
(file_path, file_hash, last_indexed, last_modified, entity_count, imports)
VALUES (?, ?, ?, ?, ?, ?)
''', (
file_path, context.file_hash, datetime.now().isoformat(),
context.last_modified, len(context.entities), json.dumps(context.imports)
))
conn.commit()
conn.close()
logger.info(f"Indexed {file_path}: {len(context.entities)} entities")
return context
except Exception as e:
logger.error(f"Error indexing {file_path}: {e}")
return None
    def index_directory(self, directory: str, patterns: Optional[List[str]] = None, ignore_patterns: Optional[List[str]] = None) -> Dict[str, Any]:
        """Index all files under a directory, honoring include/ignore glob patterns."""
        indexed_files = []
        for root, dirs, files in os.walk(directory):
            for file in files:
                file_path = os.path.join(root, file)
                # Match globs against the path relative to the indexed root so
                # patterns like 'venv/*' work regardless of how the root is spelled.
                rel_path = os.path.relpath(file_path, directory)
                if patterns and not any(fnmatch.fnmatch(rel_path, p) or fnmatch.fnmatch(file, p) for p in patterns):
                    continue
                if ignore_patterns and any(fnmatch.fnmatch(rel_path, p) or fnmatch.fnmatch(file, p) for p in ignore_patterns):
                    continue
context = self.index_file(file_path)
if context:
indexed_files.append({
"file_path": file_path,
"entities": len(context.entities),
"size": context.size
})
return {"status": "success", "indexed_files": indexed_files}
    def get_project_context(self, query: str, max_files: int = 5, max_entities: int = 10) -> List[Dict[str, Any]]:
        """Get relevant project context for a development task via similarity search.

        Note: the two limits are currently combined into a single top-k budget
        rather than enforced per result type.
        """
        return self.vector_store.similarity_search(query, top_k=max_files + max_entities)
    def get_file_dependencies(self, file_path: str) -> Dict[str, Any]:
        """Return the import list stored for an indexed file.

        This reads the imports recorded at index time from SQLite; it does not
        resolve transitive or entity-level dependency links.
        """
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()
# Retrieve the stored JSON string of imports
cursor.execute('SELECT imports FROM indexed_files WHERE file_path = ?', (file_path,))
result = cursor.fetchone()
conn.close()
if result and result[0]:
try:
imports = json.loads(result[0])
return {
"file_path": file_path,
"dependencies": imports,
"dependency_count": len(imports)
}
except json.JSONDecodeError:
logger.error(f"Failed to decode imports for {file_path}")
return {"file_path": file_path, "error": "Invalid dependency data stored."}
return {"file_path": file_path, "dependencies": [], "dependency_count": 0}
def remove_indexed_file(self, file_path: str) -> bool:
"""Remove a file and its entities from both Redis and SQLite."""
try:
# Get file_hash from SQLite
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()
cursor.execute('SELECT file_hash FROM indexed_files WHERE file_path = ?', (file_path,))
row = cursor.fetchone()
if not row:
logger.warning(f"No indexed data found for {file_path}")
conn.close()
return False
file_hash = row[0]
# Remove from Redis: file and all related entities
file_key = f"file:{file_hash}"
self.vector_store.redis_client.delete(file_key)
# Remove all entities for this file
entity_pattern = f"entity:{file_hash}:*"
for key in self.vector_store.redis_client.scan_iter(entity_pattern):
self.vector_store.redis_client.delete(key)
# Remove from SQLite
cursor.execute('DELETE FROM indexed_files WHERE file_path = ?', (file_path,))
conn.commit()
conn.close()
logger.info(f"Removed indexed data for {file_path}")
return True
except Exception as e:
logger.error(f"Error removing indexed data for {file_path}: {e}")
return False
def clear_all_indexed_data(self) -> bool:
"""Remove all indexed data from Redis and SQLite."""
try:
# Drop RediSearch index and all keys
try:
self.vector_store.redis_client.ft(self.vector_store.INDEX_NAME).dropindex(delete_documents=True)
except Exception as e:
logger.warning(f"RediSearch index drop failed or not found: {e}")
# Remove all file/entity keys (in case any remain)
for key in self.vector_store.redis_client.scan_iter("file:*"):
self.vector_store.redis_client.delete(key)
for key in self.vector_store.redis_client.scan_iter("entity:*"):
self.vector_store.redis_client.delete(key)
# Recreate the index
self.vector_store._create_search_index()
# Clear SQLite table
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()
cursor.execute('DELETE FROM indexed_files')
conn.commit()
conn.close()
logger.info("Cleared all indexed data from Redis and SQLite.")
return True
except Exception as e:
logger.error(f"Error clearing all indexed data: {e}")
return False
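# Example (illustrative): driving the manager directly from Python, assuming
# Redis Stack is reachable at REDIS_URL_DEFAULT; the paths are hypothetical:
#   manager = CodeContextManager()
#   manager.index_directory("src", patterns=["*.py"], ignore_patterns=["*/__pycache__/*"])
#   hits = manager.get_project_context("where is the retry logic?")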
# --- MCP Server Implementation ---
# Initialize the MCP server and the shared context manager.
app = Server("code-context-manager")
context_manager = CodeContextManager()
# Tool list handler
@app.list_tools()
async def handle_list_tools() -> list[Tool]:
"""Expose the available functions as tools for the LLM."""
return [
Tool(
name="index_directory",
description="Indexes an entire project directory for file and entity context, making them available for vector search. Supports Python, JavaScript, and TypeScript files. This must be called before using 'search_code_context' for a new or updated project.",
inputSchema={
"type": "object",
"properties": {
"directory": {"type": "string", "description": "The root directory to index (e.g., '.')."},
"patterns": {"type": "array", "items": {"type": "string"}, "description": "File patterns to include (e.g., ['*.py', '*.js']). Optional."},
"ignore_patterns": {"type": "array", "items": {"type": "string"}, "description": "Patterns to ignore (e.g., ['venv/*', '__pycache__/*']). Optional."}
},
"required": ["directory"]
},
),
Tool(
name="index_file",
description="Indexes a single file for code context and entities, making it available for vector search. Supports Python, JavaScript, and TypeScript files.",
inputSchema={
"type": "object",
"properties": {
"file_path": {"type": "string", "description": "The path to the file to index (e.g., 'code_context_mcp.py')."}
},
"required": ["file_path"]
},
),
Tool(
name="search_code_context",
description="Performs a **vector similarity search** across all indexed files and code entities (functions, classes) from Python, JavaScript, and TypeScript files to find the most relevant code snippets or files for a natural language query (Retrieval-Augmented Generation/RAG).",
inputSchema={
"type": "object",
"properties": {
"query": {"type": "string", "description": "The natural language query describing the code, concept, or feature needed."},
"max_files": {"type": "integer", "description": "Maximum number of relevant files to return.", "default": 5},
"max_entities": {"type": "integer", "description": "Maximum number of relevant code entities (functions/classes) to return.", "default": 10}
},
"required": ["query"]
},
),
Tool(
name="read_file",
description="Reads and returns the complete text content of a single file. Use this for specific files identified by other tools.",
inputSchema={
"type": "object",
"properties": {
"file_path": {"type": "string", "description": "The exact, relative path to the file to read."}
},
"required": ["file_path"]
},
),
Tool(
name="list_directory_contents",
description="Lists the files and subdirectories in a specific path. Useful for exploring the project structure.",
inputSchema={
"type": "object",
"properties": {
"path": {"type": "string", "description": "The relative path to list files from (e.g., 'src/utils').", "default": "."},
"recursive": {"type": "boolean", "description": "If true, recursively lists all files in subdirectories (use with caution for large folders).", "default": False}
},
"required": []
},
),
Tool(
name="get_file_dependencies",
description="Retrieves the imported dependencies and basic summary context for a specific, indexed file path.",
inputSchema={
"type": "object",
"properties": {
"file_path": {"type": "string", "description": "The exact path to the file whose dependencies are needed."}
},
"required": ["file_path"]
},
),
Tool(
name="remove_indexed_file",
description="Removes all indexed data for a specific file from both Redis and SQLite.",
inputSchema={
"type": "object",
"properties": {
"file_path": {"type": "string", "description": "The exact, relative path to the file to remove from the index."}
},
"required": ["file_path"]
},
),
Tool(
name="clear_all_indexed_data",
description="Removes all indexed data from both Redis and SQLite, resetting the index.",
inputSchema={
"type": "object",
"properties": {},
"required": []
},
),
]
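# Example (illustrative): the arguments handle_call_tool receives for a
# "search_code_context" request; the field values are hypothetical:
#   {"query": "redis vector index creation", "max_files": 3, "max_entities": 5}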
# Tool call handler
@app.call_tool()
async def handle_call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent]:
"""Map tool requests from the LLM to the CodeContextManager methods."""
tool_name = name
args = arguments or {}
try:
if tool_name == "index_directory":
result = context_manager.index_directory(**args)
return [TextContent(type="text", text=json.dumps(result))]
elif tool_name == "index_file":
file_path = args.get("file_path")
if not file_path:
return [TextContent(type="text", text=json.dumps({"error": "Missing required parameter: file_path"}))]
context = context_manager.index_file(file_path)
if context:
result = {
"status": "success",
"file_path": file_path,
"entities": len(context.entities),
"size": context.size
}
else:
result = {"status": "failure", "file_path": file_path}
return [TextContent(type="text", text=json.dumps(result))]
elif tool_name == "search_code_context":
result = context_manager.get_project_context(**args)
return [TextContent(type="text", text=json.dumps(result))]
elif tool_name == "get_file_dependencies":
result = context_manager.get_file_dependencies(**args)
return [TextContent(type="text", text=json.dumps(result))]
        # File-system helper tools
elif tool_name == "read_file":
file_path = args.get("file_path")
if not file_path or not os.path.exists(file_path):
return [TextContent(type="text", text=json.dumps({"error": f"File not found or path is invalid: {file_path}"}))]
try:
with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
content = f.read()
return [TextContent(type="text", text=json.dumps({"file_path": file_path, "content": content}))]
except Exception as e:
return [TextContent(type="text", text=json.dumps({"error": f"Failed to read file {file_path}: {e}"}))]
elif tool_name == "list_directory_contents":
path = args.get("path", ".")
recursive = args.get("recursive", False)
abs_path = os.path.join(os.getcwd(), path) # Assuming server runs from project root
if not os.path.isdir(abs_path):
return [TextContent(type="text", text=json.dumps({"error": f"Path is not a valid directory: {path}"}))]
file_list = []
try:
if recursive:
for root, dirs, files in os.walk(abs_path):
relative_root = os.path.relpath(root, abs_path)
for d in dirs:
file_list.append({"name": os.path.join(relative_root, d), "type": "directory"})
for f in files:
file_list.append({"name": os.path.join(relative_root, f), "type": "file"})
else:
for item in os.listdir(abs_path):
item_path = os.path.join(abs_path, item)
item_type = "directory" if os.path.isdir(item_path) else "file"
file_list.append({"name": item, "type": item_type})
return [TextContent(type="text", text=json.dumps({"path": path, "contents": file_list, "count": len(file_list)}))]
except Exception as e:
return [TextContent(type="text", text=json.dumps({"error": f"Failed to list directory contents for {path}: {e}"}))]
elif tool_name == "remove_indexed_file":
file_path = args.get("file_path")
if not file_path:
return [TextContent(type="text", text=json.dumps({"error": "Missing required parameter: file_path"}))]
result = context_manager.remove_indexed_file(file_path)
return [TextContent(type="text", text=json.dumps({"status": "success" if result else "failure"}))]
elif tool_name == "clear_all_indexed_data":
result = context_manager.clear_all_indexed_data()
return [TextContent(type="text", text=json.dumps({"status": "success" if result else "failure"}))]
else:
return [TextContent(type="text", text=json.dumps({"error": f"Unknown tool: {tool_name}"}))]
except Exception as e:
return [TextContent(type="text", text=json.dumps({"error": f"Tool execution failed: {str(e)}"}))]
def main():
"""Entry point for running the MCP server."""
asyncio.run(_main())
async def _main():
"""Main async function for the MCP server."""
async with mcp.server.stdio.stdio_server() as (read_stream, write_stream):
await app.run(
read_stream,
write_stream,
InitializationOptions(
server_name="code-context-manager",
server_version="1.0.0",
capabilities=app.get_capabilities(
notification_options=NotificationOptions(),
experimental_capabilities={},
),
),
)
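# Example (illustrative): registering this server with an MCP client over
# stdio, e.g. in a client's MCP configuration file (the path is hypothetical):
#   {
#     "mcpServers": {
#       "code-context-manager": {
#         "command": "python",
#         "args": ["/path/to/code_context_mcp.py"]
#       }
#     }
#   }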
if __name__ == "__main__":
if len(sys.argv) > 1 and sys.argv[1] == "index":
path = sys.argv[2] if len(sys.argv) > 2 else "."
        if os.path.isfile(path):
            print(f"Indexing file: {path}")
            context = context_manager.index_file(path)
            if context:
                result = {"status": "success", "indexed_files": [{"file_path": path, "entities": len(context.entities), "size": context.size}]}
            else:
                result = {"status": "failure", "indexed_files": []}
else:
print(f"Indexing directory: {path}")
result = context_manager.index_directory(path, ignore_patterns=["venv/*", "__pycache__/*", "*.pyc", "*.db"])
print(f"Indexing completed: {result}")
else:
main()
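# CLI usage (for reference):
#   python code_context_mcp.py              # run as an MCP stdio server
#   python code_context_mcp.py index .      # one-off: index the current directory
#   python code_context_mcp.py index app.py # one-off: index a single file ('app.py' is hypothetical)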