"""Populate Extracted texts into a database that supports RAG lookup"""
import json
import os
import sys
from pathlib import Path
from typing import Dict, Optional
from haystack import Document, Pipeline
from haystack.components.embedders import SentenceTransformersDocumentEmbedder
from haystack.components.preprocessors import DocumentCleaner, DocumentSplitter
from haystack.components.writers import DocumentWriter
from haystack_integrations.components.retrievers.chroma import ChromaEmbeddingRetriever
from haystack_integrations.document_stores.chroma import ChromaDocumentStore

# Add the parent directory to sys.path so `parameters` can be imported
# (assumes parameters.py lives one level above this script's directory)
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))

from parameters import (  # noqa: E402
CHUNK_OVERLAP,
CHUNK_SIZE,
COLLECTION_NAME,
DEFAULT_TOP_K,
EMBEDDING_MODEL,
PERSIST_PATH,
)


class RAGDatabasePopulator:
"""Manages population of extracted texts into a ChromaDB vector database for RAG"""

    def __init__(
self,
collection_name: str = COLLECTION_NAME,
persist_path: str = PERSIST_PATH,
embedding_model: str = EMBEDDING_MODEL,
chunk_size: int = CHUNK_SIZE,
chunk_overlap: int = CHUNK_OVERLAP,
):
"""
Initialize the RAG database populator.
Args:
collection_name: Name of the ChromaDB collection
persist_path: Path where ChromaDB will persist data
embedding_model: Sentence transformer model for embeddings
chunk_size: Size of text chunks in words
chunk_overlap: Overlap between chunks in words
"""
self.collection_name = collection_name
self.persist_path = persist_path
self.embedding_model = embedding_model
self.chunk_size = chunk_size
self.chunk_overlap = chunk_overlap
# Initialize document store
self.document_store = ChromaDocumentStore(
collection_name=collection_name,
persist_path=persist_path,
embedding_function="default",
)
# Initialize indexing pipeline
self.indexing_pipeline = self._create_indexing_pipeline()

    def _create_indexing_pipeline(self) -> Pipeline:
"""Create a Haystack pipeline for document indexing"""
pipeline = Pipeline()
# Add components
pipeline.add_component("cleaner", DocumentCleaner())
pipeline.add_component(
"splitter",
DocumentSplitter(
split_by="word",
split_length=self.chunk_size,
split_overlap=self.chunk_overlap,
),
)
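        # Overlapping word windows keep chunk boundaries from severing
        # context: with split_length=200 and split_overlap=50 (illustrative
        # numbers; the real values come from parameters.py), each chunk
        # repeats the last 50 words of the previous one.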
pipeline.add_component(
"embedder",
SentenceTransformersDocumentEmbedder(
model=self.embedding_model, progress_bar=True
),
)
pipeline.add_component(
"writer", DocumentWriter(document_store=self.document_store)
)
# Connect components
pipeline.connect("cleaner", "splitter")
pipeline.connect("splitter", "embedder")
pipeline.connect("embedder", "writer")
return pipeline

    def load_document_from_files(
self, text_path: str, meta_path: Optional[str] = None
) -> Document:
"""
Load a document from text file and optional metadata file.
Args:
text_path: Path to the text file
meta_path: Optional path to metadata JSON file
Returns:
Haystack Document object
"""
# Read text content
with open(text_path, "r", encoding="utf-8") as f:
content = f.read()
# Load metadata if available
metadata = {}
if meta_path and os.path.exists(meta_path):
with open(meta_path, "r", encoding="utf-8") as f:
metadata = json.load(f)
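            # Illustrative sidecar contents (these fields are an assumption,
            # not a required schema):
            #   {"title": "Quarterly Report", "author": "Jane Doe", "pages": 12}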
# Add source information
metadata["source_file"] = os.path.abspath(text_path)
metadata["file_name"] = os.path.basename(text_path)
return Document(content=content, meta=metadata)

    def populate_from_directory(
        self,
        directory: str,
        text_pattern: str = "*.txt",
    ) -> Dict[str, int]:
        """
        Populate the database from all text files in a directory.

        For each text file, a sidecar metadata file is looked up by swapping
        the suffix for ".meta.json" (e.g. "report.txt" -> "report.meta.json").

        Args:
            directory: Directory containing extracted text files
            text_pattern: Glob pattern for text files

        Returns:
            Dictionary with statistics about the population run
        """
directory_path = Path(directory)
# Find all text files
text_files = list(directory_path.glob(text_pattern))
if not text_files:
print(f"No text files found in {directory}")
return {"documents_processed": 0, "chunks_created": 0}
print(f"Found {len(text_files)} text files to process")
documents = []
for text_file in text_files:
# Look for corresponding metadata file
meta_file = text_file.with_suffix(".meta.json")
print(f"Loading: {text_file.name}")
doc = self.load_document_from_files(
str(text_file), str(meta_file) if meta_file.exists() else None
)
documents.append(doc)
# Process documents through pipeline
print(f"\nProcessing {len(documents)} documents through indexing pipeline...")
result = self.indexing_pipeline.run({"cleaner": {"documents": documents}})
# Get statistics
docs_written = result.get("writer", {}).get("documents_written", 0)
stats = {
"documents_processed": len(documents),
"chunks_created": docs_written,
"collection_name": self.collection_name,
"persist_path": self.persist_path,
}
print("\n✓ Successfully populated database:")
print(f" - Documents processed: {stats['documents_processed']}")
print(f" - Chunks created: {stats['chunks_created']}")
print(f" - Collection: {stats['collection_name']}")
print(f" - Persist path: {stats['persist_path']}")
return stats
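
    # Programmatic usage sketch (the directory name is illustrative):
    #   populator = RAGDatabasePopulator()
    #   stats = populator.populate_from_directory("./ExtractedText")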

    def get_document_count(self) -> int:
"""Get the total number of documents in the store"""
return self.document_store.count_documents()

    def clear_database(self):
        """Clear all documents from the database"""
        count = self.get_document_count()
        if count > 0:
            print(f"Clearing {count} documents from database...")
            # ChromaDocumentStore has no single "clear" call, so fetch every
            # document and delete them by id. (Merely re-instantiating the
            # store, as before, reopens the same persisted collection and
            # deletes nothing.)
            all_docs = self.document_store.filter_documents()
            self.document_store.delete_documents([doc.id for doc in all_docs])
            print("✓ Database cleared")
        else:
            print("Database is already empty")


def create_retrieval_pipeline(
document_store: ChromaDocumentStore,
embedding_model: str = EMBEDDING_MODEL,
top_k: int = DEFAULT_TOP_K,
) -> Pipeline:
"""
Create a retrieval pipeline for querying the database.
Args:
document_store: The ChromaDB document store
embedding_model: Model for query embeddings
top_k: Number of documents to retrieve
Returns:
Haystack Pipeline for retrieval
"""
from haystack.components.embedders import SentenceTransformersTextEmbedder
pipeline = Pipeline()
pipeline.add_component(
"embedder", SentenceTransformersTextEmbedder(model=embedding_model)
)
pipeline.add_component(
"retriever",
ChromaEmbeddingRetriever(document_store=document_store, top_k=top_k),
)
pipeline.connect("embedder.embedding", "retriever.query_embedding")
return pipeline
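

# Usage sketch (the query text is illustrative):
#   store = ChromaDocumentStore(collection_name=COLLECTION_NAME, persist_path=PERSIST_PATH)
#   rp = create_retrieval_pipeline(store)
#   docs = rp.run({"embedder": {"text": "example question"}})["retriever"]["documents"]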


def main():
"""Main function to populate the database"""
import argparse
parser = argparse.ArgumentParser(
description="Populate RAG database from extracted text files"
)
parser.add_argument(
"--directory",
type=str,
default="./ExtractedText",
help="Directory containing extracted text files",
)
parser.add_argument(
"--collection",
type=str,
default=COLLECTION_NAME,
help="ChromaDB collection name",
)
parser.add_argument(
"--persist-path",
type=str,
default=PERSIST_PATH,
help="Path to persist ChromaDB data",
)
parser.add_argument(
"--chunk-size",
type=int,
default=CHUNK_SIZE,
help="Size of text chunks in words",
)
parser.add_argument(
"--chunk-overlap",
type=int,
default=CHUNK_OVERLAP,
help="Overlap between chunks in words",
)
parser.add_argument(
"--clear", action="store_true", help="Clear existing database before populating"
)
parser.add_argument(
"--test-query", type=str, help="Test the database with a query after population"
)
args = parser.parse_args()
# Initialize populator
populator = RAGDatabasePopulator(
collection_name=args.collection,
persist_path=args.persist_path,
chunk_size=args.chunk_size,
chunk_overlap=args.chunk_overlap,
)
# Clear if requested
if args.clear:
populator.clear_database()
# Populate database
populator.populate_from_directory(args.directory)
# Test query if provided
if args.test_query:
print(f"\n{'=' * 60}")
print(f"Testing with query: '{args.test_query}'")
print(f"{'=' * 60}\n")
retrieval_pipeline = create_retrieval_pipeline(
populator.document_store, top_k=3
)
result = retrieval_pipeline.run({"embedder": {"text": args.test_query}})
documents = result.get("retriever", {}).get("documents", [])
if documents:
print(f"Found {len(documents)} relevant chunks:\n")
for i, doc in enumerate(documents, 1):
print(f"Result {i} (Score: {doc.score:.4f}):")
print(f" Content: {doc.content[:200]}...")
if doc.meta:
print(f" Source: {doc.meta.get('file_name', 'Unknown')}")
print()
else:
print("No relevant documents found")
if __name__ == "__main__":
main()