#!/usr/bin/env python3
"""
Embeddings searcher for Claude code documentation.
Focused on navigating markdown documentation in the repos directory.
"""
import sys
import argparse
import sqlite3
import hashlib
from pathlib import Path
from typing import List, Dict, Tuple
from dataclasses import dataclass
import threading
try:
    import numpy as np
    from sklearn.metrics.pairwise import cosine_similarity
    from tqdm import tqdm
except ImportError:
    print("Required packages not found. Install with:")
    print("uv add numpy scikit-learn tqdm")
    exit(1)

# Try to import sentence-transformers or fall back to ONNX
try:
    from sentence_transformers import SentenceTransformer
    HAS_SENTENCE_TRANSFORMERS = True
except ImportError:
    HAS_SENTENCE_TRANSFORMERS = False
    try:
        import onnxruntime as ort
        from transformers import AutoTokenizer
        HAS_ONNX = True
    except ImportError:
        HAS_ONNX = False
        print("Neither sentence-transformers nor onnxruntime available.")
        print("Install with: uv add sentence-transformers")
        exit(1)


@dataclass
class DocumentChunk:
    """Represents a chunk of documentation."""
    content: str
    title: str
    file_path: str
    chunk_index: int
    repo_name: str
    section_header: str = ""
    line_start: int = 0
    line_end: int = 0


@dataclass
class SearchResult:
    """Search result with similarity score."""
    chunk: DocumentChunk
    similarity: float

    def __str__(self) -> str:
        return f"{self.chunk.repo_name}/{self.chunk.file_path} (score: {self.similarity:.3f})"


class DocumentationSearcher:
    """Embeddings-based searcher for markdown documentation."""

    def __init__(self, kb_path: str, db_path: str = "embeddings_docs.db",
                 model_name: str = "all-MiniLM-L6-v2", ignore_dirs: List[str] = None):
        self.kb_path = Path(kb_path)
        self.repos_path = self.kb_path / "repos"
        self.db_path = db_path
        self.model_name = model_name

        # Default ignore patterns
        default_ignores = ['.git', 'node_modules', '__pycache__', '.vscode', '.idea',
                           'target', 'build', 'dist', '.svn', '.hg', 'third_party',
                           'out', 'crates', 'vendor', 'test', 'tests']
        self.ignore_dirs = ignore_dirs if ignore_dirs is not None else default_ignores

        # Initialize embedding model
        self._init_model()

        # Initialize database
        self.conn = sqlite3.connect(self.db_path, check_same_thread=False)
        self._db_lock = threading.Lock()
        self._init_database()

        # Chunking parameters
        self.chunk_size = 500  # words per chunk

    def _init_model(self):
        """Initialize the embedding model."""
        if HAS_SENTENCE_TRANSFORMERS:
            print(f"Loading sentence-transformers model: {self.model_name}")
            self.model = SentenceTransformer(self.model_name)
            self.encode_func = self.model.encode
        elif HAS_ONNX:
            # Try to use ONNX model if available
            onnx_path = self.kb_path / "sentence_model.onnx"
            tokenizer_path = self.kb_path / "tokenizer"
            if onnx_path.exists() and tokenizer_path.exists():
                print("Loading ONNX model")
                self.onnx_session = ort.InferenceSession(str(onnx_path))
                self.tokenizer = AutoTokenizer.from_pretrained(str(tokenizer_path))
                self.encode_func = self._encode_with_onnx
            else:
                print("ONNX model not found, falling back to dummy embeddings")
                self.encode_func = self._dummy_encode
        else:
            print("No embedding model available, using dummy embeddings")
            self.encode_func = self._dummy_encode

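    # Note: the ONNX path below hand-rolls the standard sentence-embedding recipe:
    # tokenize, run the transformer, mean-pool the token embeddings weighted by the
    # attention mask, then L2-normalize. The output width depends on the exported
    # model (384 dimensions for all-MiniLM-L6-v2, matching _dummy_encode below).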
    def _encode_with_onnx(self, texts: List[str]) -> np.ndarray:
        """Encode text using ONNX model."""
        embeddings = []
        for text in texts:
            # Tokenize
            inputs = self.tokenizer(text, return_tensors="np", padding=True,
                                    truncation=True, max_length=256)
            # Run ONNX inference
            outputs = self.onnx_session.run(None, {
                'input_ids': inputs['input_ids'].astype(np.int64),
                'attention_mask': inputs['attention_mask'].astype(np.int64)
            })
            # Mean pooling
            token_embeddings = outputs[0]
            attention_mask = inputs['attention_mask']
            input_mask_expanded = np.expand_dims(attention_mask, axis=-1)
            masked_embeddings = token_embeddings * input_mask_expanded
            sum_embeddings = np.sum(masked_embeddings, axis=1)
            sum_mask = np.sum(attention_mask, axis=1, keepdims=True)
            mean_embeddings = sum_embeddings / sum_mask
            # Normalize
            norm = np.linalg.norm(mean_embeddings, axis=1, keepdims=True)
            normalized_embeddings = mean_embeddings / norm
            embeddings.append(normalized_embeddings[0])
        return np.array(embeddings, dtype=np.float32)

    def _dummy_encode(self, texts: List[str]) -> np.ndarray:
        """Dummy encoding for testing without models."""
        return np.random.randn(len(texts), 384).astype(np.float32)

    def _init_database(self):
        """Initialize SQLite database for storing embeddings."""
        with self._db_lock:
            # Create tables
            self.conn.execute('''
                CREATE TABLE IF NOT EXISTS documents (
                    id INTEGER PRIMARY KEY,
                    file_path TEXT UNIQUE,
                    repo_name TEXT,
                    title TEXT,
                    content_hash TEXT,
                    indexed_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                )
            ''')
            self.conn.execute('''
                CREATE TABLE IF NOT EXISTS chunks (
                    id INTEGER PRIMARY KEY,
                    document_id INTEGER,
                    chunk_index INTEGER,
                    content TEXT,
                    section_header TEXT,
                    line_start INTEGER,
                    line_end INTEGER,
                    embedding BLOB,
                    FOREIGN KEY (document_id) REFERENCES documents (id)
                )
            ''')
            # Create lookup indexes; IF NOT EXISTS makes these safe to re-run
            self.conn.execute('CREATE INDEX IF NOT EXISTS idx_doc_path ON documents(file_path)')
            self.conn.execute('CREATE INDEX IF NOT EXISTS idx_chunk_doc ON chunks(document_id)')
            self.conn.commit()

    def _get_content_hash(self, content: str) -> str:
        """Generate hash for content to detect changes."""
        return hashlib.md5(content.encode('utf-8')).hexdigest()

    def _extract_title(self, content: str, file_path: str) -> str:
        """Extract title from markdown content."""
        lines = content.split('\n')
        # Look for first H1 header
        for line in lines[:20]:
            line = line.strip()
            if line.startswith('# '):
                return line[2:].strip()
        # Fallback to filename
        return Path(file_path).stem.replace('-', ' ').replace('_', ' ').title()

    def _chunk_markdown(self, content: str) -> List[Tuple[str, str, int, int]]:
        """
        Chunk markdown content by section headers and word count.
        Returns a list of (chunk_content, section_header, line_start, line_end) tuples.
        """
        lines = content.split('\n')
        chunks = []
        current_chunk = []
        current_section = ""
        chunk_start_line = 0
        word_count = 0

        for i, line in enumerate(lines):
            line_stripped = line.strip()

            # Detect headers
            if line_stripped.startswith('#'):
                # Save current chunk if it has content
                if current_chunk and word_count > 50:
                    chunk_content = '\n'.join(current_chunk).strip()
                    chunks.append((chunk_content, current_section, chunk_start_line, i - 1))
                # Start new chunk
                current_section = line_stripped
                current_chunk = [line]
                chunk_start_line = i
                word_count = len(line.split())
            else:
                current_chunk.append(line)
                word_count += len(line.split())

                # Split chunk if too large
                if word_count > self.chunk_size:
                    chunk_content = '\n'.join(current_chunk).strip()
                    if chunk_content:
                        chunks.append((chunk_content, current_section, chunk_start_line, i))
                    # Start the next chunk from the trailing lines as overlap;
                    # the current line is already the last element of current_chunk
                    overlap_lines = max(1, len(current_chunk) // 4)
                    current_chunk = current_chunk[-overlap_lines:]
                    chunk_start_line = i - overlap_lines + 1
                    word_count = sum(len(l.split()) for l in current_chunk)

        # Add final chunk
        if current_chunk and word_count > 50:
            chunk_content = '\n'.join(current_chunk).strip()
            chunks.append((chunk_content, current_section, chunk_start_line, len(lines) - 1))
        return chunks

    def index_document(self, file_path: Path, repo_name: str, force: bool = False) -> bool:
        """Index a single markdown document."""
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                content = f.read()
        except Exception as e:
            print(f"Error reading {file_path}: {e}")
            return False

        content_hash = self._get_content_hash(content)
        rel_path = str(file_path.relative_to(self.repos_path))
        title = self._extract_title(content, str(file_path))

        # Check if already indexed with same content (skipped when force is set)
        with self._db_lock:
            cursor = self.conn.execute(
                'SELECT id, content_hash FROM documents WHERE file_path = ?',
                (rel_path,)
            )
            row = cursor.fetchone()
            if row and row[1] == content_hash and not force:
                return True  # Already up to date

        # Chunk the document
        chunks = self._chunk_markdown(content)
        if not chunks:
            return False

        # Generate embeddings for all chunks
        chunk_texts = [chunk[0] for chunk in chunks]
        embeddings = self.encode_func(chunk_texts)

        # Store in database
        with self._db_lock:
            if row:
                # Update existing document
                doc_id = row[0]
                self.conn.execute(
                    'UPDATE documents SET title = ?, content_hash = ?, indexed_at = CURRENT_TIMESTAMP WHERE id = ?',
                    (title, content_hash, doc_id)
                )
                # Delete old chunks
                self.conn.execute('DELETE FROM chunks WHERE document_id = ?', (doc_id,))
            else:
                # Insert new document
                cursor = self.conn.execute(
                    'INSERT INTO documents (file_path, repo_name, title, content_hash) VALUES (?, ?, ?, ?)',
                    (rel_path, repo_name, title, content_hash)
                )
                doc_id = cursor.lastrowid

            # Insert chunks with embeddings (stored as raw float32 bytes)
            for i, ((chunk_content, section_header, line_start, line_end), embedding) in enumerate(zip(chunks, embeddings)):
                embedding_blob = embedding.tobytes()
                self.conn.execute(
                    '''INSERT INTO chunks (document_id, chunk_index, content, section_header,
                       line_start, line_end, embedding) VALUES (?, ?, ?, ?, ?, ?, ?)''',
                    (doc_id, i, chunk_content, section_header, line_start, line_end, embedding_blob)
                )
            self.conn.commit()
        return True

    def index_repository(self, repo_path: Path, force_reindex: bool = False) -> int:
        """Index all documentation files in a repository."""
        repo_name = repo_path.name
        print(f"Indexing documentation in {repo_name}")

        # Find all markdown and plain-text files
        md_files = []
        for pattern in ['**/*.md', '**/*.markdown', '**/*.txt']:
            md_files.extend(repo_path.glob(pattern))

        # Filter out files in ignored directories and files that are too large
        filtered_files = []
        for file_path in md_files:
            # Skip if any path component is an ignored directory
            if any(ignore_dir in file_path.parts for ignore_dir in self.ignore_dirs):
                continue
            # Skip if too large (>1 MB)
            try:
                if file_path.stat().st_size > 1024 * 1024:
                    continue
                filtered_files.append(file_path)
            except OSError:
                continue
        print(f"Found {len(filtered_files)} documentation files")

        # Index files with progress bar
        indexed_count = 0
        with tqdm(filtered_files, desc=f"Indexing {repo_name}", unit="file") as pbar:
            for file_path in pbar:
                if self.index_document(file_path, repo_name, force=force_reindex):
                    indexed_count += 1
                    pbar.set_postfix({"indexed": indexed_count})
        return indexed_count

    def index_all_repositories(self, force_reindex: bool = False) -> int:
        """Index all repositories in the repos directory."""
        if not self.repos_path.exists():
            print(f"Repositories directory not found: {self.repos_path}")
            return 0

        repo_dirs = [d for d in self.repos_path.iterdir() if d.is_dir() and not d.name.startswith('.')]
        print(f"Found {len(repo_dirs)} repositories")

        total_indexed = 0
        for repo_dir in repo_dirs:
            indexed = self.index_repository(repo_dir, force_reindex)
            total_indexed += indexed
            print(f"Indexed {indexed} files from {repo_dir.name}")
        return total_indexed

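    # Retrieval is a brute-force scan: every stored chunk embedding is loaded and
    # compared to the query with cosine similarity. That is fine for a few thousand
    # chunks; an approximate-nearest-neighbour index would be the natural next step
    # if the corpus grows much larger.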
    def search(self, query: str, max_results: int = 10, min_similarity: float = 0.1) -> List[SearchResult]:
        """Search for relevant documentation chunks."""
        if not query.strip():
            return []

        # Generate query embedding
        query_embedding = self.encode_func([query])[0]

        # Get all chunks from database
        with self._db_lock:
            cursor = self.conn.execute('''
                SELECT c.content, c.section_header, c.line_start, c.line_end, c.chunk_index,
                       d.file_path, d.repo_name, d.title, c.embedding
                FROM chunks c
                JOIN documents d ON c.document_id = d.id
            ''')
            rows = cursor.fetchall()

        results = []
        for row in rows:
            content, section_header, line_start, line_end, chunk_index, file_path, repo_name, title, embedding_blob = row
            # Reconstruct embedding
            embedding = np.frombuffer(embedding_blob, dtype=np.float32)
            # Calculate similarity
            similarity = cosine_similarity(
                query_embedding.reshape(1, -1),
                embedding.reshape(1, -1)
            )[0][0]
            if similarity >= min_similarity:
                chunk = DocumentChunk(
                    content=content,
                    title=title,
                    file_path=file_path,
                    chunk_index=chunk_index,
                    repo_name=repo_name,
                    section_header=section_header,
                    line_start=line_start,
                    line_end=line_end
                )
                results.append(SearchResult(chunk=chunk, similarity=similarity))

        # Sort by similarity and return top results
        results.sort(key=lambda x: x.similarity, reverse=True)
        return results[:max_results]

    def search_by_repo(self, query: str, repo_name: str, max_results: int = 10) -> List[SearchResult]:
        """Search within a specific repository."""
        if not query.strip():
            return []

        query_embedding = self.encode_func([query])[0]

        with self._db_lock:
            cursor = self.conn.execute('''
                SELECT c.content, c.section_header, c.line_start, c.line_end, c.chunk_index,
                       d.file_path, d.repo_name, d.title, c.embedding
                FROM chunks c
                JOIN documents d ON c.document_id = d.id
                WHERE d.repo_name = ?
            ''', (repo_name,))
            rows = cursor.fetchall()

        results = []
        for row in rows:
            content, section_header, line_start, line_end, chunk_index, file_path, repo_name, title, embedding_blob = row
            embedding = np.frombuffer(embedding_blob, dtype=np.float32)
            similarity = cosine_similarity(
                query_embedding.reshape(1, -1),
                embedding.reshape(1, -1)
            )[0][0]
            chunk = DocumentChunk(
                content=content,
                title=title,
                file_path=file_path,
                chunk_index=chunk_index,
                repo_name=repo_name,
                section_header=section_header,
                line_start=line_start,
                line_end=line_end
            )
            results.append(SearchResult(chunk=chunk, similarity=similarity))

        results.sort(key=lambda x: x.similarity, reverse=True)
        return results[:max_results]

    def list_repositories(self) -> List[str]:
        """List all indexed repositories."""
        with self._db_lock:
            cursor = self.conn.execute('SELECT DISTINCT repo_name FROM documents ORDER BY repo_name')
            return [row[0] for row in cursor.fetchall()]

    def get_stats(self) -> Dict[str, int]:
        """Get indexing statistics."""
        with self._db_lock:
            doc_count = self.conn.execute('SELECT COUNT(*) FROM documents').fetchone()[0]
            chunk_count = self.conn.execute('SELECT COUNT(*) FROM chunks').fetchone()[0]
            repo_count = self.conn.execute('SELECT COUNT(DISTINCT repo_name) FROM documents').fetchone()[0]
        return {
            'repositories': repo_count,
            'documents': doc_count,
            'chunks': chunk_count
        }

    def close(self):
        """Close database connection."""
        self.conn.close()


def format_search_results(results: List[SearchResult], query: str) -> str:
    """Format search results for display."""
    if not results:
        return f"No results found for query: '{query}'"

    output = [f"Found {len(results)} results for query: '{query}'\n"]
    for i, result in enumerate(results, 1):
        chunk = result.chunk
        output.append(f"{i}. **{chunk.title}** ({chunk.repo_name})")
        output.append(f" File: {chunk.file_path}")
        if chunk.section_header:
            output.append(f" Section: {chunk.section_header}")
        output.append(f" Lines: {chunk.line_start}-{chunk.line_end}")
        output.append(f" Similarity: {result.similarity:.3f}")

        # Show content preview (first 200 chars)
        content_preview = chunk.content[:200].replace('\n', ' ')
        if len(chunk.content) > 200:
            content_preview += "..."
        output.append(f" Preview: {content_preview}")
        output.append("")
    return "\n".join(output)


def main():
    parser = argparse.ArgumentParser(description="Documentation embeddings searcher for Claude code")
    parser.add_argument("--kb-path", default="/Users/thypon/kb", help="Path to knowledge base")
    parser.add_argument("--db-path", default="embeddings_docs.db", help="Path to embeddings database")
    parser.add_argument("--model", default="all-MiniLM-L6-v2", help="Sentence transformer model name")
    parser.add_argument("--ignore-dirs", nargs="*",
                        help="Directories to ignore during indexing (default: .git, node_modules, __pycache__, .vscode, .idea, target, build, dist, .svn, .hg, third_party, out, crates, vendor, test, tests)")

    # Actions
    parser.add_argument("--index", action="store_true", help="Index all repositories")
    parser.add_argument("--force", action="store_true", help="Force reindex of all documents")
    parser.add_argument("--stats", action="store_true", help="Show indexing statistics")
    parser.add_argument("--list-repos", action="store_true", help="List indexed repositories")

    # Search
    parser.add_argument("--query", help="Search query")
    parser.add_argument("--repo", help="Search within a specific repository")
    parser.add_argument("--max-results", type=int, default=10, help="Maximum results to return")
    parser.add_argument("--min-similarity", type=float, default=0.1, help="Minimum similarity threshold")
    args = parser.parse_args()

    # Show help if no arguments were given
    if len(sys.argv) == 1:
        parser.print_help()
        return

    searcher = DocumentationSearcher(args.kb_path, args.db_path, args.model, args.ignore_dirs)
    try:
        if args.index:
            print("Indexing all repositories...")
            total = searcher.index_all_repositories(args.force)
            print(f"Indexed {total} documents total")
        if args.stats:
            stats = searcher.get_stats()
            print("Statistics:")
            print(f" Repositories: {stats['repositories']}")
            print(f" Documents: {stats['documents']}")
            print(f" Chunks: {stats['chunks']}")
        if args.list_repos:
            repos = searcher.list_repositories()
            print("Indexed repositories:")
            for repo in repos:
                print(f" - {repo}")
        if args.query:
            if args.repo:
                print(f"Searching in repository '{args.repo}' for: {args.query}")
                results = searcher.search_by_repo(args.query, args.repo, args.max_results)
            else:
                print(f"Searching all repositories for: {args.query}")
                results = searcher.search(args.query, args.max_results, args.min_similarity)
            output = format_search_results(results, args.query)
            print(output)
    finally:
        searcher.close()


if __name__ == "__main__":
    main()