"""
Script to index products into ChromaDB using Azure OpenAI embeddings
"""
import os
import pandas as pd
import chromadb
from chromadb.config import Settings
from openai import AzureOpenAI
from dotenv import load_dotenv
from typing import List, Dict
# Load environment variables
load_dotenv()
class ProductIndexer:
    """Index product records from a CSV into ChromaDB using Azure OpenAI embeddings."""

    def __init__(self):
        """Initialize the ProductIndexer with Azure OpenAI and ChromaDB clients.

        All configuration is read from environment variables (loaded via
        dotenv at module import time).
        """
        # Azure OpenAI configuration
        self.azure_client = AzureOpenAI(
            api_key=os.getenv("AZURE_OPENAI_API_KEY"),
            api_version=os.getenv("AZURE_OPENAI_API_VERSION"),
            azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
        )
        self.embedding_deployment = os.getenv("AZURE_OPENAI_EMBEDDING_DEPLOYMENT")

        # ChromaDB configuration: persist to disk so the index survives restarts.
        persist_directory = os.getenv("CHROMA_PERSIST_DIRECTORY", "./chroma_db")
        self.chroma_client = chromadb.PersistentClient(path=persist_directory)

        # Create or get collection
        self.collection = self.chroma_client.get_or_create_collection(
            name="products",
            metadata={"description": "Product catalog with embeddings"}
        )
        print(f"ChromaDB initialized with persist directory: {persist_directory}")
        print(f"Collection 'products' ready with {self.collection.count()} documents")

    def get_embedding(self, text: str) -> List[float]:
        """
        Get embedding from Azure OpenAI for the given text

        Args:
            text: Text to embed

        Returns:
            List of floats representing the embedding
        """
        response = self.azure_client.embeddings.create(
            input=text,
            model=self.embedding_deployment
        )
        return response.data[0].embedding

    def create_document_text(self, row: pd.Series) -> str:
        """
        Create document text by concatenating relevant fields including price category

        Args:
            row: DataFrame row containing product information

        Returns:
            Concatenated text for embedding
        """
        # Concatenate: name, category, brand, description, and price_category
        doc_text = f"{row['name']} {row['category']} {row['brand']} {row['description']} {row['price_category']}"
        return doc_text

    def index_products(self, csv_path: str = "products_v2.csv"):
        """
        Index all products from CSV into ChromaDB

        Args:
            csv_path: Path to products_v2.csv file (default: products_v2.csv)
        """
        # Read the CSV
        df = pd.read_csv(csv_path)
        print(f"\nLoading {len(df)} products from {csv_path}")

        # Clear existing collection.
        # BUG FIX: the original used `self.collection.delete(where={})`, but
        # ChromaDB rejects an empty `where` filter (it requires at least one
        # field/operator). Deleting by explicit document ids is the supported
        # way to empty a collection without dropping it.
        if self.collection.count() > 0:
            print(f"Clearing existing {self.collection.count()} documents from collection...")
            existing_ids = self.collection.get()["ids"]
            if existing_ids:
                self.collection.delete(ids=existing_ids)

        # Prepare data for indexing
        documents = []
        embeddings = []
        metadatas = []
        ids = []

        print("\nGenerating embeddings and preparing documents...")
        for idx, row in df.iterrows():
            # Create document text
            doc_text = self.create_document_text(row)
            documents.append(doc_text)

            # Generate embedding (one API call per product)
            embedding = self.get_embedding(doc_text)
            embeddings.append(embedding)

            # Create metadata (all fields except id); values coerced to
            # str/float because ChromaDB metadata accepts only scalars.
            metadata = {
                "name": str(row['name']),
                "category": str(row['category']),
                "brand": str(row['brand']),
                "description": str(row['description']),
                "price": float(row['price']),
                "price_category": str(row['price_category'])
            }
            metadatas.append(metadata)

            # Use product id as document id
            ids.append(str(row['id']))

            # Periodic progress report every 5 rows.
            if (idx + 1) % 5 == 0:
                print(f"  Processed {idx + 1}/{len(df)} products...")

        # Add to ChromaDB in a single batched call.
        print("\nAdding documents to ChromaDB...")
        self.collection.add(
            documents=documents,
            embeddings=embeddings,
            metadatas=metadatas,
            ids=ids
        )

        print(f"\n✓ Successfully indexed {len(df)} products into ChromaDB")
        print(f"✓ Collection now contains {self.collection.count()} documents")

        # Display sample
        print("\nSample indexed products:")
        for i in range(min(3, len(df))):
            print(f"\n  ID: {ids[i]}")
            print(f"  Name: {metadatas[i]['name']}")
            print(f"  Category: {metadatas[i]['category']}")
            print(f"  Brand: {metadatas[i]['brand']}")
            print(f"  Price: ${metadatas[i]['price']}")
            print(f"  Price Category: {metadatas[i]['price_category']}")

    def verify_indexing(self):
        """Verify the indexing by performing a test query"""
        print("\n" + "=" * 60)
        print("VERIFICATION: Testing indexed data")
        print("=" * 60)

        # Test query
        test_query = "affordable camping tent"
        print(f"\nTest Query: '{test_query}'")

        # Get embedding for test query
        query_embedding = self.get_embedding(test_query)

        # Search
        results = self.collection.query(
            query_embeddings=[query_embedding],
            n_results=3
        )

        print(f"\nTop 3 Results:")
        # Results come back as per-query lists; index [0] is our single query.
        for i, (doc_id, metadata, distance) in enumerate(zip(
            results['ids'][0],
            results['metadatas'][0],
            results['distances'][0]
        )):
            print(f"\n  {i+1}. {metadata['name']}")
            print(f"     Category: {metadata['category']}")
            print(f"     Brand: {metadata['brand']}")
            print(f"     Price: ${metadata['price']} ({metadata['price_category']})")
            print(f"     Distance: {distance:.4f}")
if __name__ == "__main__":
# Initialize indexer
indexer = ProductIndexer()
# Index products from current directory
csv_path = "products_v2.csv"
indexer.index_products(csv_path)
# Verify indexing
indexer.verify_indexing()