"""Retrieval logic for Lenny RAG using ChromaDB."""
import json
from typing import Literal
import chromadb
from chromadb.utils import embedding_functions
from .utils import (
get_chroma_dir,
get_examples_for_topic,
get_insights_for_topic,
get_topic_by_id,
get_transcript_segment,
list_available_episodes,
load_preprocessed,
load_transcript,
)
# Embedding model (must match what was used in embed.py)
EMBEDDING_MODEL = "BAAI/bge-small-en-v1.5"
COLLECTION_NAME = "lenny"
class LennyRetriever:
"""Retriever for Lenny podcast content."""
def __init__(self):
"""Initialize the retriever with ChromaDB connection."""
self._client = None
self._collection = None
self._embedding_fn = None
@property
def client(self) -> chromadb.PersistentClient:
"""Lazy-load ChromaDB client."""
if self._client is None:
chroma_path = get_chroma_dir()
if not chroma_path.exists():
raise RuntimeError(
f"ChromaDB not found at {chroma_path}. "
"Run 'python scripts/embed.py' first."
)
self._client = chromadb.PersistentClient(path=str(chroma_path))
return self._client
@property
def embedding_fn(self):
"""Lazy-load embedding function."""
if self._embedding_fn is None:
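            # Instantiating this downloads the model from the Hugging Face hub
            # on first use if it is not already cached locally.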
self._embedding_fn = embedding_functions.SentenceTransformerEmbeddingFunction(
model_name=EMBEDDING_MODEL
)
return self._embedding_fn
@property
def collection(self):
"""Lazy-load the Lenny collection."""
if self._collection is None:
self._collection = self.client.get_collection(
name=COLLECTION_NAME,
embedding_function=self.embedding_fn
)
return self._collection
def search(
self,
query: str,
top_k: int = 5,
type_filter: Literal["episode", "topic", "insight", "example"] | None = None,
) -> list[dict]:
"""
Semantic search across the Lenny corpus.
Args:
query: Search query
top_k: Number of results to return
type_filter: Optional filter by content type
Returns:
List of search results with metadata
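
        Example (illustrative; assumes the index built by scripts/embed.py,
        and the query string is made up):
            retriever = LennyRetriever()
            hits = retriever.search("finding product-market fit", top_k=3,
                                    type_filter="insight")
            for hit in hits:
                print(hit["relevance_score"], hit.get("text"))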
"""
where_clause = None
if type_filter:
where_clause = {"type": type_filter}
results = self.collection.query(
query_texts=[query],
n_results=top_k,
where=where_clause,
include=["documents", "metadatas", "distances"]
)
# Format results
formatted = []
for i, doc_id in enumerate(results["ids"][0]):
metadata = results["metadatas"][0][i]
distance = results["distances"][0][i]
            doc_type = metadata.get("type")
            result = {
                "id": doc_id,
                "type": doc_type,
                "episode_file": metadata.get("episode_file"),
                # Convert distance to similarity; assumes the collection uses
                # cosine distance, where distance = 1 - cosine similarity
                # (with Chroma's default L2 space this is not a bounded score)
                "relevance_score": 1 - distance,
            }
            # Add type-specific fields
            if doc_type == "episode":
                result["guest"] = metadata.get("guest")
                result["expertise_tags"] = json.loads(metadata.get("expertise_tags", "[]"))
            elif doc_type == "topic":
                result["topic_id"] = metadata.get("topic_id")
                result["title"] = metadata.get("title")
                result["line_start"] = metadata.get("line_start")
                result["line_end"] = metadata.get("line_end")
            elif doc_type == "insight":
                result["topic_id"] = metadata.get("topic_id")
                result["text"] = metadata.get("text")
                result["line_start"] = metadata.get("line_start")
                result["line_end"] = metadata.get("line_end")
            elif doc_type == "example":
                result["topic_id"] = metadata.get("topic_id")
                result["explicit_text"] = metadata.get("explicit_text")
                result["inferred_identity"] = metadata.get("inferred_identity")
                result["confidence"] = metadata.get("confidence")
                result["tags"] = json.loads(metadata.get("tags", "[]"))
                result["lesson"] = metadata.get("lesson")
                result["line_start"] = metadata.get("line_start")
                result["line_end"] = metadata.get("line_end")
formatted.append(result)
return formatted
def get_chapter(self, episode_file: str, topic_id: str) -> dict | None:
"""
Get full chapter context including topic, insights, examples, and transcript.
Args:
episode_file: Episode filename (e.g., "Brian Chesky.txt")
topic_id: Topic ID (e.g., "topic_1")
Returns:
Chapter data with full context or None if not found
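
        Example (illustrative; the filename and topic ID are placeholders):
            chapter = get_retriever().get_chapter("Brian Chesky.txt", "topic_1")
            if chapter is not None:
                print(chapter["topic"].get("title"))
                print(len(chapter["insights"]), "insights,",
                      len(chapter["examples"]), "examples")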
"""
topic = get_topic_by_id(episode_file, topic_id)
if topic is None:
return None
insights = get_insights_for_topic(episode_file, topic_id)
examples = get_examples_for_topic(episode_file, topic_id)
# Get transcript segment
transcript_segment = get_transcript_segment(
episode_file,
topic.get("line_start", 1),
topic.get("line_end", 1)
)
return {
"episode_file": episode_file,
"topic": topic,
"insights": insights,
"examples": examples,
"transcript_segment": transcript_segment,
}
def get_full_transcript(self, episode_file: str) -> dict | None:
"""
Get full transcript with metadata.
Args:
episode_file: Episode filename (e.g., "Brian Chesky.txt")
Returns:
Full episode data with transcript or None if not found
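
        Example (illustrative):
            episode = get_retriever().get_full_transcript("Brian Chesky.txt")
            if episode is not None:
                print(len(episode.get("topics", [])), "topics annotated")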
"""
        transcript = load_transcript(episode_file)
        if transcript is None:
            return None
        # Preprocessed topic data is optional; the transcript alone is enough
        preprocessed = load_preprocessed(episode_file)
result = {
"episode_file": episode_file,
"transcript": transcript,
}
if preprocessed:
result["metadata"] = preprocessed.get("episode", {})
result["topics"] = preprocessed.get("topics", [])
return result
def list_episodes(self, expertise_filter: str | None = None) -> list[dict]:
"""
List all available episodes.
Args:
expertise_filter: Optional expertise tag to filter by
Returns:
List of episode info dicts
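
        Example (illustrative; "growth" is a made-up tag, and the per-episode
        keys depend on what list_available_episodes returns):
            for ep in get_retriever().list_episodes(expertise_filter="growth"):
                print(ep.get("episode_file"), ep.get("expertise_tags"))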
"""
episodes = list_available_episodes()
if expertise_filter:
expertise_lower = expertise_filter.lower()
episodes = [
ep for ep in episodes
if any(
expertise_lower in tag.lower()
for tag in ep.get("expertise_tags", [])
)
]
return episodes
# Global retriever instance
_retriever: LennyRetriever | None = None
def get_retriever() -> LennyRetriever:
"""Get the global retriever instance."""
global _retriever
if _retriever is None:
_retriever = LennyRetriever()
return _retriever
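

if __name__ == "__main__":
    # Minimal smoke test (a sketch: assumes ChromaDB was populated by
    # 'python scripts/embed.py', and the query is a made-up example).
    # Because this module uses relative imports, run it as a module,
    # e.g. 'python -m <package>.retrieval' from the project root.
    retriever = get_retriever()
    for hit in retriever.search("hiring your first product manager", top_k=3):
        print(f"{hit['relevance_score']:.3f}  {hit['type']}  {hit['id']}")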