text_vectorizer.py
#!/usr/bin/env python3
"""
Clinical Note Vectorization Pipeline

Main entry point for vectorizing clinical notes into the IRIS vector database.
Uses the NVIDIA NIM embeddings API and resumable batch processing with SQLite
checkpoints.

Usage:
    # Initial run
    python text_vectorizer.py --input synthea_clinical_notes.json --batch-size 50

    # Resume from checkpoint
    python text_vectorizer.py --input synthea_clinical_notes.json --resume

    # Test search after vectorization
    python text_vectorizer.py --input synthea_clinical_notes.json --test-search

Dependencies:
    - src/vectorization/embedding_client.py (NVIDIA NIM embeddings)
    - src/vectorization/vector_db_client.py (IRIS vector database)
    - src/vectorization/batch_processor.py (resumable batch processing)

Environment Variables:
    - NVIDIA_API_KEY: NVIDIA NGC API key for embeddings
    - IRIS_HOST: IRIS server host (default: localhost)
    - IRIS_PORT: IRIS SQL port (default: 1972)
    - IRIS_NAMESPACE: IRIS namespace (default: DEMO)
    - IRIS_USERNAME: IRIS username (default: _SYSTEM)
    - IRIS_PASSWORD: IRIS password (default: SYS)
"""

import argparse
import json
import logging
import os
import sys
from datetime import datetime
from pathlib import Path
from typing import List, Dict, Any, Optional

# Add parent directory to path for imports
sys.path.insert(0, str(Path(__file__).parent.parent))

from vectorization.embedding_client import NVIDIAEmbeddingsClient
from vectorization.vector_db_client import IRISVectorDBClient
from vectorization.batch_processor import BatchProcessor

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)


class ClinicalNoteVectorizer:
    """
    Main vectorization pipeline for clinical notes.

    Coordinates embedding generation, vector storage, and checkpoint
    management.
    """

    def __init__(
        self,
        embedding_client: NVIDIAEmbeddingsClient,
        vector_db_client: IRISVectorDBClient,
        checkpoint_db: str = "vectorization_state.db",
        error_log: str = "vectorization_errors.log"
    ):
        """
        Initialize vectorization pipeline.

        Args:
            embedding_client: NVIDIA NIM embeddings client
            vector_db_client: IRIS vector database client
            checkpoint_db: Path to SQLite checkpoint database
            error_log: Path to error log file
        """
        self.embedding_client = embedding_client
        self.vector_db_client = vector_db_client
        self.checkpoint_db = checkpoint_db
        self.error_log = error_log

        # Statistics
        self.stats = {
            "total_documents": 0,
            "processed": 0,
            "successful": 0,
            "failed": 0,
            "validation_errors": 0,
            "start_time": None,
            "end_time": None
        }

        logger.info("Initialized ClinicalNoteVectorizer")
    def load_documents(self, input_file: str) -> List[Dict[str, Any]]:
        """
        Load clinical notes from JSON file.

        Args:
            input_file: Path to JSON file with clinical notes

        Returns:
            List of document dictionaries

        Raises:
            FileNotFoundError: If input file doesn't exist
            json.JSONDecodeError: If file is not valid JSON
        """
        input_path = Path(input_file)

        if not input_path.exists():
            raise FileNotFoundError(f"Input file not found: {input_file}")

        logger.info(f"Loading documents from {input_file}")

        try:
            with open(input_path, 'r', encoding='utf-8') as f:
                documents = json.load(f)

            if not isinstance(documents, list):
                raise ValueError("Input file must contain a JSON array of documents")

            self.stats["total_documents"] = len(documents)
            logger.info(f"✓ Loaded {len(documents):,} documents")

            return documents

        except json.JSONDecodeError as e:
            logger.error(f"✗ Invalid JSON in {input_file}: {e}")
            raise
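    # Each input document is expected to look like the sketch below. The four
    # required fields are the ones enforced by validate_document(); the example
    # values themselves are hypothetical, for illustration only:
    #
    #   {
    #       "resource_id": "doc-0001",            # hypothetical ID
    #       "patient_id": "patient-1234",         # hypothetical ID
    #       "document_type": "progress-note",     # hypothetical type
    #       "text_content": "Patient presents with ..."
    #   }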
    def validate_document(self, doc: Dict[str, Any]) -> Optional[str]:
        """
        Validate document has required fields.

        Args:
            doc: Document dictionary

        Returns:
            None if valid, error message if invalid
        """
        required_fields = ["resource_id", "patient_id", "document_type", "text_content"]

        for field in required_fields:
            if field not in doc or not doc[field]:
                return f"Missing required field: {field}"

        # Validate text content is not empty
        if not doc["text_content"].strip():
            return "Empty text_content"

        return None

    def preprocess_document(self, doc: Dict[str, Any]) -> Dict[str, Any]:
        """
        Preprocess document for vectorization.

        Args:
            doc: Raw document dictionary

        Returns:
            Preprocessed document
        """
        # Make a copy to avoid modifying the original
        processed = doc.copy()

        # Normalize whitespace in text content
        text = processed["text_content"]
        text = " ".join(text.split())  # Collapse runs of whitespace

        # Store truncated version for the TextContent field (10000 chars max)
        processed["text_content_truncated"] = text[:10000]

        # Keep full text for embedding generation
        processed["text_content"] = text

        return processed

    def vectorize(
        self,
        input_file: str,
        batch_size: int = 50,
        resume: bool = False,
        show_progress: bool = True
    ) -> Dict[str, Any]:
        """
        Vectorize clinical notes from input file.

        Args:
            input_file: Path to JSON file with clinical notes
            batch_size: Number of documents per batch
            resume: Whether to resume from checkpoint
            show_progress: Whether to show progress updates

        Returns:
            Processing statistics dictionary
        """
        # Load documents
        documents = self.load_documents(input_file)

        # Validate and preprocess
        valid_documents = []
        validation_errors = []

        logger.info("Validating and preprocessing documents...")

        for doc in documents:
            # Validate
            error = self.validate_document(doc)
            if error:
                validation_errors.append({
                    "resource_id": doc.get("resource_id", "unknown"),
                    "error": error
                })
                self.stats["validation_errors"] += 1
                continue

            # Preprocess
            processed = self.preprocess_document(doc)
            valid_documents.append(processed)

        # Log validation errors to file
        if validation_errors:
            self._log_validation_errors(validation_errors)
            logger.warning(
                f"! {len(validation_errors)} validation errors (see {self.error_log})"
            )

        logger.info(f"✓ {len(valid_documents):,} valid documents ready for vectorization")

        # Initialize batch processor
        with BatchProcessor(
            embedding_client=self.embedding_client,
            vector_db_client=self.vector_db_client,
            checkpoint_db=self.checkpoint_db
        ) as processor:
            self.stats["start_time"] = datetime.utcnow()

            # Process or resume
            if resume:
                logger.info("Resuming from checkpoint...")
                result_stats = processor.resume(
                    documents=valid_documents,
                    batch_size=batch_size,
                    show_progress=show_progress
                )
            else:
                result_stats = processor.process_documents(
                    documents=valid_documents,
                    batch_size=batch_size,
                    show_progress=show_progress,
                    on_batch_complete=self._on_batch_complete if show_progress else None
                )

            self.stats["end_time"] = datetime.utcnow()
            self.stats["processed"] = result_stats["total_processed"]
            self.stats["successful"] = result_stats["successful"]
            self.stats["failed"] = result_stats["failed"]

        return self.stats
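    # Resume support relies on the SQLite checkpoint database maintained by
    # BatchProcessor; the real schema lives in batch_processor.py. A minimal
    # sketch of what such a checkpoint table could look like (hypothetical,
    # for orientation only):
    #
    #   CREATE TABLE IF NOT EXISTS processed_documents (
    #       resource_id TEXT PRIMARY KEY,   -- document already vectorized
    #       status      TEXT,               -- e.g. 'success' or 'failed'
    #       updated_at  TEXT                -- ISO-8601 timestamp
    #   );
    #
    # Per the --resume help text, already-processed documents are skipped on
    # resume; how they are keyed is up to the BatchProcessor implementation.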
    def _on_batch_complete(
        self,
        batch_num: int,
        total_batches: int,
        stats: Dict[str, Any]
    ) -> None:
        """
        Callback invoked after each batch completes.

        Args:
            batch_num: Current batch number
            total_batches: Total number of batches
            stats: Current processing statistics
        """
        # Calculate throughput
        if stats.get("start_time"):
            elapsed = (datetime.utcnow() - stats["start_time"]).total_seconds()
            if elapsed > 0:
                docs_per_min = (stats["successful"] / elapsed) * 60

                # Estimate time remaining
                remaining_batches = total_batches - batch_num
                remaining_time = (remaining_batches / batch_num) * elapsed if batch_num > 0 else 0

                logger.info(
                    f"Progress: {batch_num}/{total_batches} batches | "
                    f"{stats['successful']:,} successful | "
                    f"{stats['failed']} failed | "
                    f"{docs_per_min:.1f} docs/min | "
                    f"ETA: {remaining_time / 60:.1f} min"
                )

    def _log_validation_errors(self, errors: List[Dict[str, str]]) -> None:
        """
        Log validation errors to error log file.

        Args:
            errors: List of validation error dictionaries
        """
        error_log_path = Path(self.error_log)

        with open(error_log_path, 'a', encoding='utf-8') as f:
            f.write(f"\n{'=' * 80}\n")
            f.write(f"Validation Errors - {datetime.utcnow().isoformat()}\n")
            f.write(f"{'=' * 80}\n")

            for error in errors:
                f.write(f"Resource ID: {error['resource_id']}\n")
                f.write(f"Error: {error['error']}\n")
                f.write("-" * 80 + "\n")

    def test_search(self, query: str = "diabetes", top_k: int = 3) -> None:
        """
        Test vector similarity search with sample query.

        Args:
            query: Search query text
            top_k: Number of results to return
        """
        logger.info(f"\nTesting vector search: '{query}'")

        # Generate query embedding
        query_embedding = self.embedding_client.embed(query)

        # Search
        results = self.vector_db_client.search_similar(
            query_vector=query_embedding,
            top_k=top_k
        )

        # Display results
        logger.info(f"\nTop {len(results)} results:")
        for i, result in enumerate(results, 1):
            logger.info(f"\n{i}. Similarity: {result['similarity']:.3f}")
            logger.info(f"   Patient ID: {result['patient_id']}")
            logger.info(f"   Doc Type: {result['document_type']}")
            logger.info(f"   Content: {result['text_content'][:200]}...")

    def print_summary(self) -> None:
        """Print final processing summary."""
        elapsed = 0
        if self.stats["start_time"] and self.stats["end_time"]:
            elapsed = (self.stats["end_time"] - self.stats["start_time"]).total_seconds()

        logger.info("\n" + "=" * 80)
        logger.info("Vectorization Summary")
        logger.info("=" * 80)
        logger.info(f"Total documents:   {self.stats['total_documents']:,}")
        logger.info(f"Validation errors: {self.stats['validation_errors']:,}")
        logger.info(f"Processed:         {self.stats['processed']:,}")
        logger.info(f"Successful:        {self.stats['successful']:,}")
        logger.info(f"Failed:            {self.stats['failed']:,}")
        logger.info(f"Elapsed time:      {elapsed:.1f}s ({elapsed / 60:.1f} min)")

        if elapsed > 0 and self.stats["successful"] > 0:
            docs_per_min = (self.stats["successful"] / elapsed) * 60
            logger.info(f"Throughput:        {docs_per_min:.1f} docs/min")

        logger.info("=" * 80)


def parse_arguments() -> argparse.Namespace:
    """Parse command-line arguments."""
    parser = argparse.ArgumentParser(
        description="Vectorize clinical notes into IRIS vector database",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Initial vectorization
  %(prog)s --input synthea_clinical_notes.json --batch-size 50

  # Resume from checkpoint
  %(prog)s --input synthea_clinical_notes.json --resume

  # Test search after vectorization
  %(prog)s --input synthea_clinical_notes.json --test-search "diabetes"
"""
    )

    parser.add_argument(
        "--input",
        type=str,
        required=True,
        help="Path to JSON file containing clinical notes"
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=50,
        help="Number of documents to process per batch (default: 50)"
    )
    parser.add_argument(
        "--resume",
        action="store_true",
        help="Resume from checkpoint (skip already processed documents)"
    )
    parser.add_argument(
        "--test-search",
        type=str,
        nargs='?',
        const="diabetes",
        help="Test vector search after vectorization with optional query (default: 'diabetes')"
    )
    parser.add_argument(
        "--checkpoint-db",
        type=str,
        default="vectorization_state.db",
        help="Path to SQLite checkpoint database (default: vectorization_state.db)"
    )
    parser.add_argument(
        "--error-log",
        type=str,
        default="vectorization_errors.log",
        help="Path to error log file (default: vectorization_errors.log)"
    )
    parser.add_argument(
        "--iris-host",
        type=str,
        default=os.getenv("IRIS_HOST", "localhost"),
        help="IRIS database host (default: $IRIS_HOST or localhost)"
    )
    parser.add_argument(
        "--iris-port",
        type=int,
        default=int(os.getenv("IRIS_PORT", "1972")),
        help="IRIS SQL port (default: $IRIS_PORT or 1972)"
    )
    parser.add_argument(
        "--iris-namespace",
        type=str,
        default=os.getenv("IRIS_NAMESPACE", "DEMO"),
        help="IRIS namespace (default: $IRIS_NAMESPACE or DEMO)"
    )
    parser.add_argument(
        "--iris-username",
        type=str,
        default=os.getenv("IRIS_USERNAME", "_SYSTEM"),
        help="IRIS username (default: $IRIS_USERNAME or _SYSTEM)"
    )
    parser.add_argument(
        "--iris-password",
        type=str,
        default=os.getenv("IRIS_PASSWORD", "SYS"),
        help="IRIS password (default: $IRIS_PASSWORD or SYS)"
    )

    return parser.parse_args()
def main() -> int:
    """
    Main entry point for vectorization pipeline.

    Returns:
        Exit code (0 for success, 1 for error)
    """
    args = parse_arguments()

    try:
        # Initialize clients
        logger.info("Initializing NVIDIA NIM embeddings client...")
        embedding_client = NVIDIAEmbeddingsClient()

        logger.info("Initializing IRIS vector database client...")
        vector_db_client = IRISVectorDBClient(
            host=args.iris_host,
            port=args.iris_port,
            namespace=args.iris_namespace,
            username=args.iris_username,
            password=args.iris_password,
            vector_dimension=1024  # NV-EmbedQA-E5-V5 dimension
        )

        # Connect to IRIS
        vector_db_client.connect()

        # Initialize vectorizer
        vectorizer = ClinicalNoteVectorizer(
            embedding_client=embedding_client,
            vector_db_client=vector_db_client,
            checkpoint_db=args.checkpoint_db,
            error_log=args.error_log
        )

        # Vectorize documents
        logger.info("\nStarting vectorization pipeline...")
        logger.info(f"Input file: {args.input}")
        logger.info(f"Batch size: {args.batch_size}")
        logger.info(f"Resume mode: {args.resume}")
        logger.info("")

        vectorizer.vectorize(
            input_file=args.input,
            batch_size=args.batch_size,
            resume=args.resume,
            show_progress=True
        )

        # Print summary
        vectorizer.print_summary()

        # Test search if requested
        if args.test_search:
            vectorizer.test_search(query=args.test_search, top_k=3)

        # Cleanup
        vector_db_client.disconnect()
        embedding_client.close()

        logger.info("\n✅ Vectorization complete!")
        return 0

    except KeyboardInterrupt:
        logger.warning("\n⚠ Interrupted by user")
        return 1

    except Exception as e:
        logger.error(f"\n✗ Vectorization failed: {e}", exc_info=True)
        return 1


if __name__ == "__main__":
    sys.exit(main())
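
The three imported modules are not shown on this page. For orientation, here is a minimal sketch of the interfaces this script assumes, inferred solely from its call sites above; the names, signatures, and behavior of the real implementations in src/vectorization/ may differ.

# interface_sketch.py - hypothetical stubs inferred from text_vectorizer.py call sites
from typing import Any, Callable, Dict, List, Optional


class NVIDIAEmbeddingsClient:
    """Assumed interface: embeds text via NVIDIA NIM. embed() and close()
    are the only methods text_vectorizer.py calls."""

    def embed(self, text: str) -> List[float]:
        raise NotImplementedError  # real client would call the NIM embeddings API

    def close(self) -> None:
        ...  # release network resources


class IRISVectorDBClient:
    """Assumed interface: stores and searches 1024-dim vectors in IRIS."""

    def __init__(self, host: str, port: int, namespace: str,
                 username: str, password: str, vector_dimension: int):
        ...

    def connect(self) -> None: ...
    def disconnect(self) -> None: ...

    def search_similar(self, query_vector: List[float], top_k: int) -> List[Dict[str, Any]]:
        # Each result dict must carry at least the keys test_search() reads:
        # 'similarity', 'patient_id', 'document_type', 'text_content'
        raise NotImplementedError


class BatchProcessor:
    """Assumed interface: context manager providing checkpointed batch
    processing. Both processing methods return a stats dict with the keys
    vectorize() reads: 'total_processed', 'successful', 'failed'."""

    def __init__(self, embedding_client, vector_db_client, checkpoint_db: str): ...
    def __enter__(self) -> "BatchProcessor": return self
    def __exit__(self, exc_type, exc, tb) -> None: ...

    def process_documents(self, documents: List[Dict[str, Any]], batch_size: int,
                          show_progress: bool,
                          on_batch_complete: Optional[Callable] = None) -> Dict[str, Any]:
        raise NotImplementedError

    def resume(self, documents: List[Dict[str, Any]], batch_size: int,
               show_progress: bool) -> Dict[str, Any]:
        raise NotImplementedError

Checking a new implementation of any of the three modules against these stubs is a quick way to confirm it stays drop-in compatible with this entry point.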
