pipeline.py•4.01 kB
"""RAG pipeline orchestration with configurable strategies."""
from typing import List, Optional
from ..storage.schema import SearchResult
from .retriever import Retriever
from .reranker import Reranker, NoOpReranker
from .expander import QueryExpander, NoOpExpander
from .generator import LLMGenerator
class RAGPipeline:
    """Configurable RAG pipeline with strategy pattern.

    Orchestrates query expansion -> retrieval -> de-duplication ->
    reranking -> generation. The expander and reranker are pluggable; when
    omitted they default to ``NoOpExpander()`` and ``NoOpReranker()``.
    """

    # Number of leading characters of a chunk's text shown in a source snippet.
    _SNIPPET_CHARS = 200

    def __init__(
        self,
        retriever: Retriever,
        generator: LLMGenerator,
        reranker: Optional[Reranker] = None,
        query_expander: Optional[QueryExpander] = None,
        top_k: int = 5
    ):
        """Create a pipeline from its strategy components.

        Args:
            retriever: Fetches candidate chunks for a query.
            generator: Produces the final answer from question + context.
            reranker: Optional reranker; ``NoOpReranker()`` is used when None.
            query_expander: Optional expander; ``NoOpExpander()`` is used
                when None.
            top_k: Number of results requested per expanded query, and the
                number of reranked results kept as generation context.
        """
        self.retriever = retriever
        self.generator = generator
        # Compare against None explicitly: ``reranker or NoOpReranker()`` would
        # also discard a caller-supplied strategy that happens to be falsy
        # (e.g. one that defines ``__len__`` or ``__bool__``).
        self.reranker = NoOpReranker() if reranker is None else reranker
        self.query_expander = (
            NoOpExpander() if query_expander is None else query_expander
        )
        self.top_k = top_k

    async def query(
        self,
        question: str,
        namespace: Optional[str] = None,
        content_type: Optional[str] = None,
        system_prompt: Optional[str] = None
    ) -> dict:
        """Execute full RAG pipeline.

        Args:
            question: User question
            namespace: Optional namespace filter
            content_type: Optional content type filter
            system_prompt: Optional system prompt for generation

        Returns:
            Dict with 'answer' and 'sources' keys. 'sources' is a list of
            dicts with 'document_id', 'chunk_id', 'score', 'title',
            'namespace' and 'snippet'. When nothing survives retrieval and
            reranking, the generator is not called and a fixed
            "couldn't find" answer with empty sources is returned.
        """
        # 1. Query expansion.
        queries = await self.query_expander.expand(question)

        # 2. Retrieve for every expanded query, then drop repeated chunks.
        candidates = await self._retrieve_all(queries, namespace, content_type)
        unique_results = self._deduplicate(candidates)

        # 3. Rerank against the ORIGINAL question, then keep the best top_k.
        reranked_results = await self.reranker.rerank(question, unique_results)
        top_results = reranked_results[:self.top_k]

        # 4. Generate response (skipped entirely when there is no context).
        if not top_results:
            return {
                "answer": "I couldn't find any relevant information to answer your question.",
                "sources": []
            }
        context = [result.text for result in top_results]
        answer = await self.generator.generate(
            question=question,
            context=context,
            system_prompt=system_prompt
        )

        # 5. Format sources.
        return {
            "answer": answer,
            "sources": [self._format_source(result) for result in top_results]
        }

    async def _retrieve_all(
        self,
        queries,
        namespace: Optional[str],
        content_type: Optional[str]
    ) -> List[SearchResult]:
        """Retrieve up to ``top_k`` results per query, concatenated in query order."""
        all_results: List[SearchResult] = []
        # Sequential on purpose: the retriever is not known to tolerate
        # concurrent calls, so retrieval order and call pattern are preserved.
        for query in queries:
            results = await self.retriever.retrieve(
                query=query,
                limit=self.top_k,
                namespace=namespace,
                content_type=content_type
            )
            all_results.extend(results)
        return all_results

    @staticmethod
    def _deduplicate(results: List[SearchResult]) -> List[SearchResult]:
        """Drop repeats by (document_id, chunk_id), keeping each first occurrence in order."""
        by_key = {}
        for result in results:
            # setdefault keeps the first result seen for a key; dicts preserve
            # insertion order, so the surviving results stay in original order.
            by_key.setdefault((result.document_id, result.chunk_id), result)
        return list(by_key.values())

    def _format_source(self, result: SearchResult) -> dict:
        """Build the public source record for one result, with a truncated snippet."""
        text = result.text
        limit = self._SNIPPET_CHARS
        snippet = text[:limit] + "..." if len(text) > limit else text
        return {
            "document_id": result.document_id,
            "chunk_id": result.chunk_id,
            "score": result.score,
            "title": result.metadata.title,
            "namespace": result.metadata.namespace,
            "snippet": snippet
        }

    async def search(
        self,
        query: str,
        limit: int = 5,
        namespace: Optional[str] = None,
        content_type: Optional[str] = None
    ) -> List[SearchResult]:
        """Search without generation (just retrieval).

        No query expansion, de-duplication or reranking is applied; the
        retriever's results are returned as-is.

        Args:
            query: Search query
            limit: Maximum results
            namespace: Optional namespace filter
            content_type: Optional content type filter

        Returns:
            List of SearchResult objects
        """
        return await self.retriever.retrieve(
            query=query,
            limit=limit,
            namespace=namespace,
            content_type=content_type
        )