#!/usr/bin/env python3
"""
Embedding script for preprocessed Lenny podcast data.
Uses bge-small-en-v1.5 embeddings with ChromaDB for vector storage.
Usage:
    python scripts/embed.py            # Build embeddings from preprocessed JSONs
    python scripts/embed.py --rebuild  # Clear and rebuild all embeddings
"""
import argparse
import json
import sys
from pathlib import Path
try:
import chromadb
from chromadb.utils import embedding_functions
except ImportError:
print("Error: chromadb package not installed.")
print("Install with: pip install chromadb")
sys.exit(1)
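# Note: the SentenceTransformerEmbeddingFunction used below also needs the
# sentence-transformers package (pip install sentence-transformers); chromadb
# does not pull it in by default.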
# Project paths
PROJECT_ROOT = Path(__file__).parent.parent
PREPROCESSED_DIR = PROJECT_ROOT / "preprocessed"
CHROMA_DIR = PROJECT_ROOT / "chroma_db"
TRANSCRIPTS_DIR = PROJECT_ROOT / "transcripts"
# Embedding model
EMBEDDING_MODEL = "BAAI/bge-small-en-v1.5"
COLLECTION_NAME = "lenny"
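# bge-small-en-v1.5 produces 384-dimensional embeddings; the model weights are
# downloaded from Hugging Face and cached locally on first use.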
def sanitize_metadata(metadata: dict) -> dict:
"""Sanitize metadata values - ChromaDB doesn't accept None values."""
result = {}
for key, value in metadata.items():
if value is None:
# Convert None to appropriate default based on expected type
if key in ("line_start", "line_end"):
result[key] = 0
else:
result[key] = ""
else:
result[key] = value
return result
def get_embedding_function():
"""Get the sentence transformer embedding function."""
return embedding_functions.SentenceTransformerEmbeddingFunction(
model_name=EMBEDDING_MODEL
)
def create_collection(client: chromadb.PersistentClient, rebuild: bool = False):
"""Create or get the Lenny collection."""
embedding_fn = get_embedding_function()
if rebuild:
# Delete existing collection if it exists
try:
client.delete_collection(COLLECTION_NAME)
print(f"Deleted existing collection: {COLLECTION_NAME}")
except Exception:
pass # Collection didn't exist
return client.get_or_create_collection(
name=COLLECTION_NAME,
embedding_function=embedding_fn,
metadata={"description": "Lenny podcast transcripts hierarchical RAG"}
)
def process_json_file(json_path: Path, collection, processed_episodes: set) -> dict:
"""
Process a preprocessed JSON file and add to ChromaDB.
Returns stats dict with counts.
"""
episode_file = json_path.stem + ".txt"
# Skip if already processed
if episode_file in processed_episodes:
return {"skipped": 1}
with open(json_path, "r", encoding="utf-8") as f:
data = json.load(f)
stats = {"episodes": 0, "topics": 0, "insights": 0, "examples": 0}
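    # Note: the episode-summary document added in step 1 is also what
    # get_processed_episodes() treats as the "already processed" marker, so a
    # crash partway through a file leaves it partially embedded until --rebuild.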
# 1. Embed episode summary
if "episode" in data:
ep = data["episode"]
episode_id = f"{episode_file}_episode"
# Combine summary with expertise tags for better retrieval
doc_text = f"{ep.get('summary', '')} {' '.join(ep.get('expertise_tags', []))}"
if ep.get("key_frameworks"):
doc_text += f" Frameworks: {', '.join(ep['key_frameworks'])}"
collection.add(
ids=[episode_id],
documents=[doc_text],
metadatas=[sanitize_metadata({
"type": "episode",
"episode_file": episode_file,
"guest": ep.get("guest", "Unknown"),
"expertise_tags": json.dumps(ep.get("expertise_tags", [])),
"key_frameworks": json.dumps(ep.get("key_frameworks", [])),
})]
)
stats["episodes"] = 1
# 2. Embed topics
for idx, topic in enumerate(data.get("topics", [])):
raw_id = topic.get('id', f"topic_{idx+1}")
topic_id = f"{episode_file}_{raw_id}"
# Combine title and summary
doc_text = f"{topic.get('title', '')}. {topic.get('summary', '')}"
collection.add(
ids=[topic_id],
documents=[doc_text],
metadatas=[sanitize_metadata({
"type": "topic",
"episode_file": episode_file,
"topic_id": raw_id,
"title": topic.get("title", ""),
"line_start": topic.get("line_start", 0),
"line_end": topic.get("line_end", 0),
})]
)
stats["topics"] += 1
# 3. Embed insights
for idx, insight in enumerate(data.get("insights", [])):
raw_id = insight.get('id', f"insight_{idx+1}")
insight_id = f"{episode_file}_{raw_id}"
# Combine insight text with context
doc_text = f"{insight.get('text', '')} {insight.get('context', '')}"
collection.add(
ids=[insight_id],
documents=[doc_text],
metadatas=[sanitize_metadata({
"type": "insight",
"episode_file": episode_file,
"topic_id": insight.get("topic_id", ""),
"text": insight.get("text", ""),
"line_start": insight.get("line_start", 0),
"line_end": insight.get("line_end", 0),
})]
)
stats["insights"] += 1
# 4. Embed examples (enriched for better retrieval)
for idx, example in enumerate(data.get("examples", [])):
# Generate ID if missing
raw_id = example.get('id', f"example_{idx+1}")
example_id = f"{episode_file}_{raw_id}"
# Enrich example text for better retrieval
parts = [example.get("explicit_text", "")]
if example.get("inferred_identity"):
parts.append(f"Company/Product: {example['inferred_identity']}")
if example.get("tags"):
parts.append(f"Keywords: {', '.join(example['tags'])}")
if example.get("lesson"):
parts.append(f"Lesson: {example['lesson']}")
doc_text = " ".join(parts)
collection.add(
ids=[example_id],
documents=[doc_text],
metadatas=[sanitize_metadata({
"type": "example",
"episode_file": episode_file,
"topic_id": example.get("topic_id", ""),
"explicit_text": example.get("explicit_text", ""),
"inferred_identity": example.get("inferred_identity", ""),
"confidence": example.get("confidence", "medium"),
"tags": json.dumps(example.get("tags", [])),
"lesson": example.get("lesson", ""),
"line_start": example.get("line_start", 0),
"line_end": example.get("line_end", 0),
})]
)
stats["examples"] += 1
return stats
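# A minimal retrieval sketch (not used by this script): the "type" metadata set
# above lets a downstream RAG step query one level of the hierarchy at a time,
# e.g.
#
#     hits = collection.query(
#         query_texts=["how should startups price a new product?"],
#         n_results=5,
#         where={"type": "insight"},
#     )
#
# and then hop to the parent episode via hits["metadatas"][0][i]["episode_file"].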
def get_processed_episodes(collection) -> set:
"""Get set of episode files already in the collection."""
try:
# Query for all episode-type documents
results = collection.get(
where={"type": "episode"},
include=["metadatas"]
)
return {m["episode_file"] for m in results.get("metadatas", [])}
except Exception:
return set()
def main():
parser = argparse.ArgumentParser(description="Build embeddings for Lenny RAG")
parser.add_argument(
"--rebuild",
action="store_true",
help="Clear and rebuild all embeddings"
)
args = parser.parse_args()
# Ensure directories exist
CHROMA_DIR.mkdir(exist_ok=True)
# Initialize ChromaDB
print(f"Initializing ChromaDB at: {CHROMA_DIR}")
client = chromadb.PersistentClient(path=str(CHROMA_DIR))
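    # PersistentClient flushes writes to disk as documents are added; recent
    # chromadb versions need no explicit persist() call.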
# Create/get collection
collection = create_collection(client, rebuild=args.rebuild)
# Get already processed episodes (unless rebuilding)
if args.rebuild:
processed_episodes = set()
else:
processed_episodes = get_processed_episodes(collection)
if processed_episodes:
print(f"Found {len(processed_episodes)} already processed episodes")
# Find all preprocessed JSON files
json_files = sorted(PREPROCESSED_DIR.glob("*.json"))
if not json_files:
print("No preprocessed JSON files found.")
print(f"Run 'python scripts/preprocess.py' first to process transcripts.")
return
print(f"Found {len(json_files)} preprocessed files")
# Process each file
total_stats = {"episodes": 0, "topics": 0, "insights": 0, "examples": 0, "skipped": 0}
for i, json_path in enumerate(json_files, 1):
print(f"[{i}/{len(json_files)}] {json_path.stem}...", end=" ")
try:
stats = process_json_file(json_path, collection, processed_episodes)
if stats.get("skipped"):
print("(skipped - already processed)")
total_stats["skipped"] += 1
else:
print(f"(+{stats['topics']} topics, +{stats['insights']} insights, +{stats['examples']} examples)")
for key in ["episodes", "topics", "insights", "examples"]:
total_stats[key] += stats.get(key, 0)
except Exception as e:
print(f"ERROR: {e}")
# Summary
print(f"\n{'='*50}")
print("Embedding complete!")
print(f" Episodes: {total_stats['episodes']}")
print(f" Topics: {total_stats['topics']}")
print(f" Insights: {total_stats['insights']}")
print(f" Examples: {total_stats['examples']}")
print(f" Skipped: {total_stats['skipped']}")
print(f"\nCollection stored at: {CHROMA_DIR}")
if __name__ == "__main__":
main()