rag_pipeline.py
""" RAG Pipeline for Multi-Modal Medical Query Processing This module implements the Retrieval-Augmented Generation (RAG) pipeline for querying clinical notes using vector similarity search and LLM generation. Main Components: - Query embedding generation (NVIDIA NIM embeddings) - Vector similarity search (IRIS database) - Context assembly from retrieved documents - LLM response generation (NVIDIA NIM LLM) - Citation extraction and formatting Usage: from query.rag_pipeline import RAGPipeline pipeline = RAGPipeline() result = pipeline.process_query( query_text="What are the patient's chronic conditions?", top_k=10, patient_id="patient-123" # Optional filter ) print(f"Response: {result['response']}") print(f"Citations: {result['citations']}") Performance Target (SC-007): <5 seconds end-to-end query latency """ import os import sys import logging import time import requests from typing import List, Dict, Any, Optional, Tuple from datetime import datetime from pathlib import Path # Add parent directory to path for imports sys.path.insert(0, str(Path(__file__).parent.parent)) from vectorization.embedding_client import NVIDIAEmbeddingsClient from vectorization.vector_db_client import IRISVectorDBClient # Configure logging logging.basicConfig( level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s', datefmt='%Y-%m-%d %H:%M:%S' ) logger = logging.getLogger(__name__) class RAGPipeline: """ Main RAG pipeline for medical query processing. Coordinates query embedding, vector search, context assembly, and LLM generation with citation tracking. """ # Default configuration DEFAULT_TOP_K = 10 DEFAULT_SIMILARITY_THRESHOLD = 0.5 DEFAULT_MAX_CONTEXT_TOKENS = 4000 # Approximate token budget for context DEFAULT_LLM_MAX_TOKENS = 500 DEFAULT_LLM_TEMPERATURE = 0.7 def __init__( self, iris_host: str = None, iris_port: int = None, iris_namespace: str = None, iris_username: str = None, iris_password: str = None, llm_endpoint: str = None, embedding_api_key: str = None, vector_dimension: int = 1024 ): """ Initialize RAG pipeline with database and API connections. 
        Args:
            iris_host: IRIS database host (default: from env IRIS_HOST or localhost)
            iris_port: IRIS database port (default: from env IRIS_PORT or 1972)
            iris_namespace: IRIS namespace (default: from env IRIS_NAMESPACE or DEMO)
            iris_username: IRIS username (default: from env IRIS_USERNAME or _SYSTEM)
            iris_password: IRIS password (default: from env IRIS_PASSWORD or SYS)
            llm_endpoint: NIM LLM endpoint URL (default: http://localhost:8001)
            embedding_api_key: NVIDIA API key (default: from env NVIDIA_API_KEY)
            vector_dimension: Embedding dimension (default: 1024)
        """
        # IRIS database configuration
        self.iris_host = iris_host or os.getenv("IRIS_HOST", "localhost")
        self.iris_port = iris_port or int(os.getenv("IRIS_PORT", "1972"))
        self.iris_namespace = iris_namespace or os.getenv("IRIS_NAMESPACE", "DEMO")
        self.iris_username = iris_username or os.getenv("IRIS_USERNAME", "_SYSTEM")
        self.iris_password = iris_password or os.getenv("IRIS_PASSWORD", "SYS")
        self.vector_dimension = vector_dimension

        # LLM configuration
        self.llm_endpoint = llm_endpoint or "http://localhost:8001"
        self.llm_completions_url = f"{self.llm_endpoint}/v1/chat/completions"

        # Initialize clients
        logger.info("Initializing NVIDIA NIM embeddings client...")
        self.embedding_client = NVIDIAEmbeddingsClient(api_key=embedding_api_key)

        logger.info("Initializing IRIS vector database client...")
        self.vector_db_client = IRISVectorDBClient(
            host=self.iris_host,
            port=self.iris_port,
            namespace=self.iris_namespace,
            username=self.iris_username,
            password=self.iris_password,
            vector_dimension=self.vector_dimension
        )

        # Connect to IRIS
        self.vector_db_client.connect()
        logger.info(f"✓ Connected to IRIS: {self.iris_host}:{self.iris_port}/{self.iris_namespace}")

    def __del__(self):
        """Cleanup database connection."""
        try:
            if hasattr(self, 'vector_db_client'):
                self.vector_db_client.disconnect()
        except Exception:
            pass

    def generate_query_embedding(self, query_text: str) -> List[float]:
        """
        Generate embedding vector for query text.

        Uses the same NVIDIA NIM embeddings API as the vectorization pipeline
        to ensure query and document embeddings are in the same space.

        Args:
            query_text: Natural language query

        Returns:
            1024-dimensional embedding vector

        Raises:
            RuntimeError: If embedding generation fails
        """
        try:
            logger.info(f"Generating embedding for query: '{query_text[:100]}...'")
            embedding = self.embedding_client.embed(query_text)
            logger.info(f"✓ Generated {len(embedding)}-dimensional embedding")
            return embedding
        except Exception as e:
            logger.error(f"Failed to generate query embedding: {e}")
            raise RuntimeError(f"Query embedding generation failed: {e}")

    def search_similar_documents(
        self,
        query_vector: List[float],
        top_k: int = DEFAULT_TOP_K,
        patient_id: Optional[str] = None,
        document_type: Optional[str] = None,
        similarity_threshold: float = DEFAULT_SIMILARITY_THRESHOLD
    ) -> List[Dict[str, Any]]:
        """
        Search for similar documents using vector similarity.

        Args:
            query_vector: Query embedding vector
            top_k: Number of results to retrieve
            patient_id: Optional patient ID filter
            document_type: Optional document type filter
            similarity_threshold: Minimum cosine similarity (0-1)

        Returns:
            List of documents with metadata and similarity scores:
            [
                {
                    "resource_id": str,
                    "patient_id": str,
                    "document_type": str,
                    "text_content": str,
                    "similarity": float
                },
                ...
] """ logger.info(f"Searching for top-{top_k} similar documents...") if patient_id: logger.info(f" Filtering by patient_id: {patient_id}") if document_type: logger.info(f" Filtering by document_type: {document_type}") logger.info(f" Similarity threshold: {similarity_threshold}") try: results = self.vector_db_client.search_similar( query_vector=query_vector, top_k=top_k, patient_id=patient_id, document_type=document_type ) # Apply similarity threshold filter filtered_results = [ result for result in results if result["similarity"] >= similarity_threshold ] logger.info(f"✓ Retrieved {len(filtered_results)} documents above threshold") return filtered_results except Exception as e: logger.error(f"Vector search failed: {e}") raise RuntimeError(f"Vector similarity search failed: {e}") def filter_and_rank_results( self, results: List[Dict[str, Any]], min_similarity: float = DEFAULT_SIMILARITY_THRESHOLD ) -> List[Dict[str, Any]]: """ Filter and rank search results. Applies additional filtering and re-ranking logic: - Filters by minimum similarity threshold - Re-ranks by document recency if timestamps available - Deduplicates similar content Args: results: Raw search results min_similarity: Minimum similarity score Returns: Filtered and ranked results """ # Filter by minimum similarity filtered = [r for r in results if r["similarity"] >= min_similarity] # Sort by similarity (already sorted by vector_db_client, but ensure) filtered.sort(key=lambda x: x["similarity"], reverse=True) logger.info(f"Filtered to {len(filtered)} results (min_similarity={min_similarity})") return filtered def assemble_context( self, documents: List[Dict[str, Any]], max_tokens: int = DEFAULT_MAX_CONTEXT_TOKENS ) -> Tuple[str, List[Dict[str, Any]]]: """ Assemble context from retrieved documents for LLM prompt. Concatenates document text with metadata and source identifiers. Truncates if total exceeds max_tokens budget. Args: documents: Retrieved documents with metadata max_tokens: Approximate maximum tokens for context Returns: Tuple of (assembled_context_text, source_documents) """ if not documents: return "", [] context_parts = [] sources = [] # Rough token estimation: ~4 characters per token chars_per_token = 4 max_chars = max_tokens * chars_per_token current_chars = 0 for i, doc in enumerate(documents): # Format document with metadata doc_text = f""" [Document {i+1}] Source ID: {doc['resource_id']} Patient: {doc['patient_id']} Type: {doc['document_type']} Similarity: {doc['similarity']:.3f} Content: {doc['text_content']} --- """ # Check if adding this document exceeds budget if current_chars + len(doc_text) > max_chars: logger.warning(f"Context budget exceeded at document {i+1}/{len(documents)}") break context_parts.append(doc_text) sources.append({ "index": i + 1, "resource_id": doc["resource_id"], "patient_id": doc["patient_id"], "document_type": doc["document_type"], "similarity": doc["similarity"] }) current_chars += len(doc_text) assembled_context = "\n".join(context_parts) logger.info(f"✓ Assembled context from {len(sources)} documents (~{current_chars//chars_per_token} tokens)") return assembled_context, sources def create_llm_prompt( self, query_text: str, context: str ) -> List[Dict[str, str]]: """ Create LLM prompt with system instructions, context, and query. 
        Formats a chat completion prompt with:
        - System message instructing to answer only from context
        - User message with retrieved context and query

        Args:
            query_text: User's natural language query
            context: Assembled context from retrieved documents

        Returns:
            Messages array for chat completion API:
            [
                {"role": "system", "content": "..."},
                {"role": "user", "content": "..."}
            ]
        """
        system_prompt = """You are a helpful medical AI assistant. Answer the user's question based ONLY on the clinical notes provided in the context below.

IMPORTANT INSTRUCTIONS:
- Use only information from the provided clinical notes
- If the context doesn't contain relevant information, say "I don't have enough information to answer this question based on the available clinical notes"
- Cite specific document numbers when referencing information (e.g., "According to Document 1...")
- Be concise and factual
- Do not make up or infer information not present in the context"""

        user_prompt = f"""Context (Retrieved Clinical Notes):
{context}

Question: {query_text}

Please answer the question based on the clinical notes provided above. Cite the document numbers when referencing specific information."""

        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt}
        ]

        return messages

    def call_llm_api(
        self,
        messages: List[Dict[str, str]],
        max_tokens: int = DEFAULT_LLM_MAX_TOKENS,
        temperature: float = DEFAULT_LLM_TEMPERATURE
    ) -> str:
        """
        Call NVIDIA NIM LLM API for response generation.

        Sends chat completion request to NIM LLM service
        (meta/llama-3.1-8b-instruct) running on port 8001.

        Args:
            messages: Chat messages array
            max_tokens: Maximum tokens in response
            temperature: Sampling temperature (0-1)

        Returns:
            Generated response text

        Raises:
            RuntimeError: If API call fails
        """
        payload = {
            "model": "meta/llama-3.1-8b-instruct",
            "messages": messages,
            "max_tokens": max_tokens,
            "temperature": temperature
        }

        try:
            logger.info(f"Calling NIM LLM API at {self.llm_completions_url}...")
            response = requests.post(
                self.llm_completions_url,
                json=payload,
                timeout=60  # 60 second timeout
            )
            response.raise_for_status()

            response_data = response.json()
            generated_text = response_data["choices"][0]["message"]["content"]

            logger.info(f"✓ LLM response generated ({len(generated_text)} chars)")
            return generated_text
        except requests.exceptions.Timeout:
            logger.error("LLM API request timed out")
            raise RuntimeError("LLM API request timed out after 60 seconds")
        except requests.exceptions.ConnectionError:
            logger.error(f"Could not connect to LLM API at {self.llm_completions_url}")
            raise RuntimeError(f"LLM service unavailable at {self.llm_completions_url}")
        except requests.exceptions.HTTPError as e:
            logger.error(f"LLM API returned error: {e}")
            raise RuntimeError(f"LLM API error: {e}")
        except (KeyError, IndexError) as e:
            logger.error(f"Unexpected LLM API response format: {e}")
            raise RuntimeError(f"Invalid LLM API response: {e}")

    def extract_citations(
        self,
        response_text: str,
        sources: List[Dict[str, Any]]
    ) -> List[Dict[str, Any]]:
        """
        Extract citations from LLM response.

        Maps LLM response back to source documents by detecting references
        to document numbers (e.g., "Document 1", "Document 2").

        Args:
            response_text: Generated LLM response
            sources: Source documents with metadata

        Returns:
            List of cited sources:
            [
                {
                    "resource_id": str,
                    "patient_id": str,
                    "document_type": str,
                    "similarity": float,
                    "cited_in_response": bool
                },
                ...
] """ citations = [] for source in sources: doc_number = source["index"] # Check if this document number is mentioned in response mentioned = ( f"Document {doc_number}" in response_text or f"document {doc_number}" in response_text or f"[{doc_number}]" in response_text ) citations.append({ "resource_id": source["resource_id"], "patient_id": source["patient_id"], "document_type": source["document_type"], "similarity": source["similarity"], "cited_in_response": mentioned }) cited_count = sum(1 for c in citations if c["cited_in_response"]) logger.info(f"✓ Extracted {cited_count}/{len(citations)} citations mentioned in response") return citations def handle_no_results(self, query_text: str) -> Dict[str, Any]: """ Handle case when no relevant documents are found. Returns explicit "no information found" message instead of allowing LLM to hallucinate an answer. Args: query_text: Original query Returns: Result dict with no-results message """ logger.warning("No relevant documents found above similarity threshold") return { "query": query_text, "response": "I don't have enough information to answer this question based on the available clinical notes. No relevant documents were found with sufficient similarity to your query.", "retrieved_documents": [], "citations": [], "metadata": { "documents_retrieved": 0, "documents_used_in_context": 0, "processing_time_seconds": 0, "timestamp": datetime.now().isoformat() } } def process_query( self, query_text: str, top_k: int = DEFAULT_TOP_K, patient_id: Optional[str] = None, document_type: Optional[str] = None, similarity_threshold: float = DEFAULT_SIMILARITY_THRESHOLD, max_context_tokens: int = DEFAULT_MAX_CONTEXT_TOKENS, llm_max_tokens: int = DEFAULT_LLM_MAX_TOKENS, llm_temperature: float = DEFAULT_LLM_TEMPERATURE ) -> Dict[str, Any]: """ Process a complete RAG query from start to finish. Main entry point for the RAG pipeline. Coordinates all steps: 1. Generate query embedding 2. Search for similar documents 3. Filter and rank results 4. Assemble context 5. Create LLM prompt 6. Call LLM API 7. Extract citations Args: query_text: Natural language query top_k: Number of documents to retrieve patient_id: Optional patient ID filter document_type: Optional document type filter similarity_threshold: Minimum similarity score (0-1) max_context_tokens: Maximum tokens for context llm_max_tokens: Maximum tokens in LLM response llm_temperature: LLM temperature parameter Returns: Complete query result: { "query": str, "response": str, "retrieved_documents": List[Dict], "citations": List[Dict], "metadata": Dict } Example: >>> pipeline = RAGPipeline() >>> result = pipeline.process_query( ... query_text="What are the patient's chronic conditions?", ... top_k=10, ... patient_id="patient-123" ... 
            ... )
            >>> print(result["response"])
        """
        start_time = time.time()

        logger.info("="*80)
        logger.info("Starting RAG Query Processing")
        logger.info(f"Query: {query_text}")
        logger.info("="*80)

        try:
            # Step 1: Generate query embedding
            query_vector = self.generate_query_embedding(query_text)

            # Step 2: Search for similar documents
            search_results = self.search_similar_documents(
                query_vector=query_vector,
                top_k=top_k,
                patient_id=patient_id,
                document_type=document_type,
                similarity_threshold=similarity_threshold
            )

            # Step 3: Handle no results case
            if not search_results:
                return self.handle_no_results(query_text)

            # Step 4: Filter and rank results (additional processing)
            filtered_results = self.filter_and_rank_results(
                results=search_results,
                min_similarity=similarity_threshold
            )

            # Step 5: Assemble context
            context, sources = self.assemble_context(
                documents=filtered_results,
                max_tokens=max_context_tokens
            )

            # Step 6: Create LLM prompt
            messages = self.create_llm_prompt(
                query_text=query_text,
                context=context
            )

            # Step 7: Call LLM API
            response_text = self.call_llm_api(
                messages=messages,
                max_tokens=llm_max_tokens,
                temperature=llm_temperature
            )

            # Step 8: Extract citations
            citations = self.extract_citations(
                response_text=response_text,
                sources=sources
            )

            # Calculate processing time
            processing_time = time.time() - start_time

            # Assemble result
            result = {
                "query": query_text,
                "response": response_text,
                "retrieved_documents": filtered_results,
                "citations": citations,
                "metadata": {
                    "documents_retrieved": len(search_results),
                    "documents_used_in_context": len(sources),
                    "processing_time_seconds": round(processing_time, 2),
                    "timestamp": datetime.now().isoformat(),
                    "parameters": {
                        "top_k": top_k,
                        "similarity_threshold": similarity_threshold,
                        "patient_id": patient_id,
                        "document_type": document_type
                    }
                }
            }

            logger.info("="*80)
            logger.info(f"✅ Query processing complete in {processing_time:.2f}s")
            logger.info(f"   Retrieved: {len(search_results)} documents")
            logger.info(f"   Used in context: {len(sources)} documents")
            logger.info(f"   Response length: {len(response_text)} chars")
            logger.info("="*80)

            return result
        except Exception as e:
            logger.error(f"❌ Query processing failed: {e}")
            raise


if __name__ == "__main__":
    """
    Example usage and testing.
    """
    print("RAG Pipeline Module")
    print("=" * 80)
    print()
    print("This module provides the main RAG query processing functionality.")
    print()
    print("Example usage:")
    print()
    print("  from query.rag_pipeline import RAGPipeline")
    print()
    print("  pipeline = RAGPipeline()")
    print("  result = pipeline.process_query(")
    print("      query_text='What medications is the patient taking?',")
    print("      top_k=10,")
    print("      patient_id='patient-123'")
    print("  )")
    print()
    print("  print(f\"Response: {result['response']}\")")
    print("  print(f\"Citations: {len(result['citations'])} sources\")")
    print()
    print("For CLI testing, use: python src/validation/test_rag_query.py")
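
    # Optional live demo: a minimal sketch showing the documented process_query
    # entry point being exercised end-to-end, guarded so it never runs by accident.
    # Assumptions: the RAG_PIPELINE_DEMO environment variable is a hypothetical
    # opt-in (not an existing convention), and the IRIS database plus the NIM LLM
    # endpoint configured in RAGPipeline.__init__ are reachable.
    if os.getenv("RAG_PIPELINE_DEMO"):
        demo_pipeline = RAGPipeline()
        demo_result = demo_pipeline.process_query(
            query_text="What medications is the patient taking?",
            top_k=5
        )
        print()
        print(f"Response: {demo_result['response']}")
        print(f"Citations: {len(demo_result['citations'])} sources")
        print(f"Latency: {demo_result['metadata']['processing_time_seconds']}s")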
