"""
Token embedding support for efficient context storage.
This module provides the EmbeddingManager class for converting text to
vector embeddings and storing/searching historical context efficiently.
GitHub Issue #9: Token Embedding for Context Efficiency
"""
import hashlib
import json
import sqlite3
from pathlib import Path
from typing import Dict, Any, List, Optional
import numpy as np
# Lazy import to avoid loading model at module import time
_SentenceTransformer = None
def _get_sentence_transformer():
"""Lazy import of SentenceTransformer."""
global _SentenceTransformer
if _SentenceTransformer is None:
from sentence_transformers import SentenceTransformer
_SentenceTransformer = SentenceTransformer
return _SentenceTransformer
class EmbeddingManager:
"""Manages text embeddings and vector storage for historical context."""
# Model configuration
MODEL_NAME = "all-MiniLM-L6-v2"
EMBEDDING_DIM = 384
def __init__(self, db_path: Path):
"""
Initialize the embedding manager.
Args:
db_path: Path to SQLite database for vector storage
"""
self.db_path = db_path
self.model: Optional[Any] = None
self._ensure_schema()
def _ensure_schema(self):
"""Create database schema if it doesn't exist."""
# Ensure parent directory exists
self.db_path.parent.mkdir(parents=True, exist_ok=True)
conn = sqlite3.connect(self.db_path)
conn.execute("""
CREATE TABLE IF NOT EXISTS task_archive (
id TEXT PRIMARY KEY,
title TEXT NOT NULL,
description TEXT,
embedding BLOB NOT NULL,
created_at TEXT,
completed_at TEXT,
metadata TEXT
)
""")
conn.execute("""
CREATE TABLE IF NOT EXISTS message_archive (
id TEXT PRIMARY KEY,
sender TEXT NOT NULL,
message TEXT NOT NULL,
embedding BLOB NOT NULL,
timestamp TEXT,
metadata TEXT
)
""")
conn.execute("""
CREATE INDEX IF NOT EXISTS idx_task_created
ON task_archive(created_at)
""")
conn.execute("""
CREATE INDEX IF NOT EXISTS idx_message_timestamp
ON message_archive(timestamp)
""")
conn.commit()
conn.close()
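    # Storage note: embeddings are persisted as raw float32 bytes in the BLOB
    # columns above; a 384-dimensional vector round-trips via
    #     np.frombuffer(vec.astype(np.float32).tobytes(), dtype=np.float32)
    # and costs roughly 1.5 KB per row, with no extra vector-store dependency.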
def _get_model(self):
"""Lazy-load the embedding model."""
if self.model is None:
SentenceTransformer = _get_sentence_transformer()
self.model = SentenceTransformer(self.MODEL_NAME)
return self.model
def encode_text(self, text: str) -> np.ndarray:
"""
Convert text to embedding vector.
Args:
text: Input text to encode
Returns:
384-dimensional embedding vector (float32)
"""
model = self._get_model()
embedding = model.encode(text, convert_to_numpy=True)
return embedding.astype(np.float32)
def archive_task(self, task: Dict[str, Any]) -> None:
"""
Archive a completed task with its embedding.
Args:
task: Task dictionary with id, title, description, etc.
"""
# Create searchable text from task
title = task.get("title", task.get("task", ""))
description = task.get("description", "")
search_text = f"{title} {description}".strip()
if not search_text:
return # Don't archive empty tasks
embedding = self.encode_text(search_text)
conn = sqlite3.connect(self.db_path)
conn.execute(
"""
INSERT OR REPLACE INTO task_archive
(id, title, description, embedding, created_at, completed_at, metadata)
VALUES (?, ?, ?, ?, ?, ?, ?)
""",
(
task.get("id", task.get("task", "")[:50]),
title,
description,
embedding.tobytes(),
task.get("created_at", ""),
task.get("completed_at", ""),
json.dumps(task.get("metadata", {})),
),
)
conn.commit()
conn.close()
def archive_message(self, message: Dict[str, Any]) -> None:
"""
Archive a message with its embedding.
Args:
message: Message dictionary with sender, message text, etc.
"""
msg_text = message.get("message", "")
if not msg_text:
return # Don't archive empty messages
embedding = self.encode_text(msg_text)
# Generate ID if not provided
msg_id = message.get("id")
        if not msg_id:
            sender = message.get("sender", "unknown")
            timestamp = message.get("timestamp", "")
            # Use a deterministic digest so the same message maps to the same
            # ID across runs (the builtin hash() is randomized per process,
            # which would defeat the INSERT OR REPLACE deduplication).
            digest = hashlib.sha1(msg_text.encode("utf-8")).hexdigest()[:8]
            msg_id = f"{sender}_{timestamp}_{digest}"
conn = sqlite3.connect(self.db_path)
conn.execute(
"""
INSERT OR REPLACE INTO message_archive
(id, sender, message, embedding, timestamp, metadata)
VALUES (?, ?, ?, ?, ?, ?)
""",
(
msg_id,
message.get("sender", "unknown"),
msg_text,
embedding.tobytes(),
message.get("timestamp", ""),
json.dumps(message.get("metadata", {})),
),
)
conn.commit()
conn.close()
def search_tasks(self, query: str, limit: int = 5) -> List[Dict[str, Any]]:
"""
Search archived tasks by semantic similarity.
Args:
query: Search query
limit: Maximum number of results
Returns:
List of matching tasks with similarity scores
"""
if not query.strip():
return []
query_embedding = self.encode_text(query)
conn = sqlite3.connect(self.db_path)
cursor = conn.execute(
"""
SELECT id, title, description, embedding, created_at, completed_at, metadata
FROM task_archive
"""
)
results = []
for row in cursor.fetchall():
task_embedding = np.frombuffer(row[3], dtype=np.float32)
similarity = self._cosine_similarity(query_embedding, task_embedding)
results.append(
{
"id": row[0],
"title": row[1],
"description": row[2],
"created_at": row[4],
"completed_at": row[5],
"metadata": json.loads(row[6]) if row[6] else {},
"similarity": float(similarity),
}
)
conn.close()
# Sort by similarity and return top results
results.sort(key=lambda x: x["similarity"], reverse=True)
return results[:limit]
def search_messages(self, query: str, limit: int = 5) -> List[Dict[str, Any]]:
"""
Search archived messages by semantic similarity.
Args:
query: Search query
limit: Maximum number of results
Returns:
List of matching messages with similarity scores
"""
if not query.strip():
return []
query_embedding = self.encode_text(query)
conn = sqlite3.connect(self.db_path)
cursor = conn.execute(
"""
SELECT id, sender, message, embedding, timestamp, metadata
FROM message_archive
"""
)
results = []
for row in cursor.fetchall():
msg_embedding = np.frombuffer(row[3], dtype=np.float32)
similarity = self._cosine_similarity(query_embedding, msg_embedding)
results.append(
{
"id": row[0],
"sender": row[1],
"message": row[2],
"timestamp": row[4],
"metadata": json.loads(row[5]) if row[5] else {},
"similarity": float(similarity),
}
)
conn.close()
results.sort(key=lambda x: x["similarity"], reverse=True)
return results[:limit]
def get_task_count(self) -> int:
"""Get the number of archived tasks."""
conn = sqlite3.connect(self.db_path)
cursor = conn.execute("SELECT COUNT(*) FROM task_archive")
count = cursor.fetchone()[0]
conn.close()
return count
def get_message_count(self) -> int:
"""Get the number of archived messages."""
conn = sqlite3.connect(self.db_path)
cursor = conn.execute("SELECT COUNT(*) FROM message_archive")
count = cursor.fetchone()[0]
conn.close()
return count
@staticmethod
def _cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
"""Calculate cosine similarity between two vectors."""
norm_a = np.linalg.norm(a)
norm_b = np.linalg.norm(b)
if norm_a == 0 or norm_b == 0:
return 0.0
return float(np.dot(a, b) / (norm_a * norm_b))
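

if __name__ == "__main__":
    # Quick smoke test (a sketch, not part of the public API): archive a couple
    # of messages in a throwaway database and run a semantic search. The first
    # run downloads the all-MiniLM-L6-v2 model.
    import tempfile

    with tempfile.TemporaryDirectory() as tmp:
        manager = EmbeddingManager(Path(tmp) / "context.db")
        manager.archive_message(
            {"sender": "agent-1", "message": "Deployed build to staging", "timestamp": "t1"}
        )
        manager.archive_message(
            {"sender": "agent-2", "message": "Unit tests are failing on CI", "timestamp": "t2"}
        )
        for hit in manager.search_messages("deployment status", limit=2):
            print(f"{hit['similarity']:.3f}  {hit['sender']}: {hit['message']}")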