Skip to main content
Glama
test_batch_processor.py•23.5 kB
"""
Unit Tests for BatchProcessor

Tests resumable batch processing with SQLite checkpointing. Validates
checkpoint creation, state transitions, and resume functionality.

Usage:
    pytest tests/unit/test_batch_processor.py -v
    pytest tests/unit/test_batch_processor.py::TestBatchProcessor::test_register_documents -v

Dependencies: pytest, unittest.mock
"""

import pytest
import sys
import sqlite3
import tempfile
import os
from pathlib import Path
from unittest.mock import Mock, MagicMock, patch, call

# Add src to path for imports
sys.path.insert(0, str(Path(__file__).parent.parent.parent / "src"))

from vectorization.batch_processor import BatchProcessor


class TestBatchProcessor:
    """Test suite for BatchProcessor class."""

    @pytest.fixture
    def mock_clients(self):
        """Create mock embedding and vector DB clients."""
        embedding_client = Mock()
        embedding_client.model = "nvidia/nv-embedqa-e5-v5"
        embedding_client.embed_batch.return_value = [[0.1] * 1024, [0.2] * 1024]

        vector_db_client = Mock()
        vector_db_client.insert_vector = Mock()

        return embedding_client, vector_db_client

    @pytest.fixture
    def temp_db(self):
        """Create a temporary SQLite database for testing."""
        fd, path = tempfile.mkstemp(suffix=".db")
        os.close(fd)
        yield path
        # Cleanup
        if os.path.exists(path):
            os.remove(path)

    @pytest.fixture
    def processor(self, mock_clients, temp_db):
        """Create a test BatchProcessor instance."""
        embedding_client, vector_db_client = mock_clients
        return BatchProcessor(
            embedding_client=embedding_client,
            vector_db_client=vector_db_client,
            checkpoint_db=temp_db,
            auto_commit_interval=2,
        )

    def test_initialization(self, processor, temp_db):
        """Test processor initialization."""
        assert processor.checkpoint_db == temp_db
        assert processor.auto_commit_interval == 2
        assert processor.connection is None
        assert processor.cursor is None
        assert processor.stats["total_processed"] == 0

    def test_connect_creates_database(self, processor):
        """Test connect creates checkpoint database."""
        processor.connect()

        assert processor.connection is not None
        assert processor.cursor is not None

        # Verify database file exists
        assert os.path.exists(processor.checkpoint_db)

        processor.disconnect()

    def test_create_checkpoint_table(self, processor):
        """Test checkpoint table creation with correct schema."""
        processor.connect()
        processor._create_checkpoint_table()

        # Verify table was created
        processor.cursor.execute(
            "SELECT name FROM sqlite_master WHERE type='table' AND name='VectorizationState'"
        )
        result = processor.cursor.fetchone()
        assert result is not None

        # Verify columns
        processor.cursor.execute("PRAGMA table_info(VectorizationState)")
        columns = {row[1] for row in processor.cursor.fetchall()}
        expected_columns = {
            "DocumentID", "DocumentType", "Status",
            "ProcessingStartedAt", "ProcessingCompletedAt",
            "ErrorMessage", "RetryCount",
        }
        assert expected_columns.issubset(columns)

        processor.disconnect()

    def test_context_manager(self, mock_clients, temp_db):
        """Test using processor as context manager."""
        embedding_client, vector_db_client = mock_clients

        with BatchProcessor(embedding_client, vector_db_client, temp_db) as proc:
            assert proc.connection is not None
            # Table should be created
            proc.cursor.execute(
                "SELECT name FROM sqlite_master WHERE type='table' AND name='VectorizationState'"
            )
            assert proc.cursor.fetchone() is not None

        # Connection should be closed
        # Attempting to use cursor after disconnect should fail
        with pytest.raises(sqlite3.ProgrammingError):
            proc.cursor.execute("SELECT 1")

    def test_register_documents_new(self, processor):
        """Test registering new documents."""
        with processor:
            documents = [
                {"resource_id": "doc-001", "document_type": "Progress Note"},
                {"resource_id": "doc-002", "document_type": "Discharge Summary"},
                {"resource_id": "doc-003", "document_type": "Consultation"},
            ]

            new_count = processor.register_documents(documents, "clinical_note")

            assert new_count == 3

            # Verify documents were inserted
            processor.cursor.execute(
                "SELECT DocumentID, DocumentType, Status FROM VectorizationState ORDER BY DocumentID"
            )
            rows = processor.cursor.fetchall()
            assert len(rows) == 3
            assert rows[0] == ("doc-001", "clinical_note", "pending")
            assert rows[1] == ("doc-002", "clinical_note", "pending")
            assert rows[2] == ("doc-003", "clinical_note", "pending")

    def test_register_documents_duplicates_ignored(self, processor):
        """Test that duplicate documents are ignored."""
        with processor:
            documents = [
                {"resource_id": "doc-001", "document_type": "Progress Note"}
            ]

            # Register first time
            count1 = processor.register_documents(documents, "clinical_note")
            assert count1 == 1

            # Register again (duplicate)
            count2 = processor.register_documents(documents, "clinical_note")
            assert count2 == 0  # No new documents

            # Verify only one record exists
            processor.cursor.execute("SELECT COUNT(*) FROM VectorizationState")
            total = processor.cursor.fetchone()[0]
            assert total == 1

    def test_get_pending_documents(self, processor):
        """Test retrieving pending documents."""
        with processor:
            # Register documents doc-000 .. doc-004
            documents = [
                {"resource_id": f"doc-{i:03d}", "document_type": "Note"}
                for i in range(5)
            ]
            processor.register_documents(documents, "clinical_note")

            # Mark some as completed
            processor.mark_processing("doc-001")
            processor.mark_completed("doc-001")
            processor.mark_processing("doc-002")
            processor.mark_completed("doc-002")

            # Get pending
            pending = processor.get_pending_documents("clinical_note")

            assert len(pending) == 3
            # BUGFIX: range(5) registers doc-000..doc-004, so the remaining
            # pending IDs are doc-000, doc-003, doc-004 (doc-005 never existed).
            assert "doc-000" in pending
            assert "doc-003" in pending
            assert "doc-004" in pending
            assert "doc-001" not in pending
            assert "doc-002" not in pending

    def test_get_pending_documents_with_limit(self, processor):
        """Test retrieving pending documents with limit."""
        with processor:
            documents = [
                {"resource_id": f"doc-{i:03d}", "document_type": "Note"}
                for i in range(10)
            ]
            processor.register_documents(documents, "clinical_note")

            pending = processor.get_pending_documents("clinical_note", limit=3)
            assert len(pending) == 3

    def test_mark_processing(self, processor):
        """Test marking document as processing."""
        with processor:
            processor.register_documents(
                [{"resource_id": "doc-001", "document_type": "Note"}],
                "clinical_note"
            )

            processor.mark_processing("doc-001")
            processor.connection.commit()

            # Verify status changed
            processor.cursor.execute(
                "SELECT Status, ProcessingStartedAt FROM VectorizationState WHERE DocumentID = 'doc-001'"
            )
            row = processor.cursor.fetchone()
            assert row[0] == "processing"
            assert row[1] is not None  # timestamp should be set

    def test_mark_completed(self, processor):
        """Test marking document as completed."""
        with processor:
            processor.register_documents(
                [{"resource_id": "doc-001", "document_type": "Note"}],
                "clinical_note"
            )

            processor.mark_processing("doc-001")
            processor.mark_completed("doc-001")
            processor.connection.commit()

            # Verify status changed
            processor.cursor.execute(
                "SELECT Status, ProcessingCompletedAt, ErrorMessage FROM VectorizationState WHERE DocumentID = 'doc-001'"
            )
            row = processor.cursor.fetchone()
            assert row[0] == "completed"
            assert row[1] is not None  # completion timestamp
            assert row[2] is None  # no error message

    def test_mark_failed(self, processor):
        """Test marking document as failed."""
        with processor:
            processor.register_documents(
                [{"resource_id": "doc-001", "document_type": "Note"}],
                "clinical_note"
            )

            processor.mark_processing("doc-001")
            processor.mark_failed("doc-001", "API rate limit exceeded")
            processor.connection.commit()

            # Verify status and error
            processor.cursor.execute(
                "SELECT Status, ErrorMessage, RetryCount FROM VectorizationState WHERE DocumentID = 'doc-001'"
            )
            row = processor.cursor.fetchone()
            assert row[0] == "failed"
            assert "API rate limit" in row[1]
            assert row[2] == 1  # retry count incremented

    def test_process_documents_success(self, processor, mock_clients):
        """Test successful document processing."""
        embedding_client, vector_db_client = mock_clients

        # Mock successful embedding and insertion
        embedding_client.embed_batch.return_value = [
            [0.1] * 1024,
            [0.2] * 1024,
        ]

        with processor:
            documents = [
                {
                    "resource_id": "doc-001",
                    "patient_id": "patient-123",
                    "document_type": "Progress Note",
                    "text_content": "Patient presents with symptoms",
                    "source_bundle": "bundle-001.json"
                },
                {
                    "resource_id": "doc-002",
                    "patient_id": "patient-456",
                    "document_type": "Discharge Summary",
                    "text_content": "Patient discharged in stable condition",
                    "source_bundle": "bundle-002.json"
                },
            ]

            stats = processor.process_documents(
                documents,
                batch_size=2,
                document_type="clinical_note",
                show_progress=False
            )

            assert stats["successful"] == 2
            assert stats["failed"] == 0
            assert stats["total_processed"] == 2

            # Verify embeddings were generated
            embedding_client.embed_batch.assert_called_once()

            # Verify vectors were inserted
            assert vector_db_client.insert_vector.call_count == 2

    def test_process_documents_with_failures(self, processor, mock_clients):
        """Test document processing with some failures."""
        embedding_client, vector_db_client = mock_clients

        # Mock successful embeddings
        embedding_client.embed_batch.return_value = [
            [0.1] * 1024,
            [0.2] * 1024,
        ]

        # Mock vector insertion to fail on second document
        vector_db_client.insert_vector.side_effect = [
            None,  # First succeeds
            Exception("Database connection lost"),  # Second fails
        ]

        with processor:
            documents = [
                {
                    "resource_id": "doc-001",
                    "patient_id": "patient-123",
                    "document_type": "Note",
                    "text_content": "Content 1"
                },
                {
                    "resource_id": "doc-002",
                    "patient_id": "patient-456",
                    "document_type": "Note",
                    "text_content": "Content 2"
                },
            ]

            stats = processor.process_documents(
                documents,
                batch_size=2,
                document_type="clinical_note",
                show_progress=False
            )

            assert stats["successful"] == 1
            assert stats["failed"] == 1

            # Check database state
            processor.cursor.execute(
                "SELECT DocumentID, Status FROM VectorizationState WHERE Status='failed'"
            )
            failed = processor.cursor.fetchall()
            assert len(failed) == 1
            assert failed[0][0] == "doc-002"

    def test_process_documents_batch_embedding_failure(self, processor, mock_clients):
        """Test handling of batch embedding failure."""
        embedding_client, vector_db_client = mock_clients

        # Mock embedding to fail
        embedding_client.embed_batch.side_effect = Exception("API unavailable")

        with processor:
            documents = [
                {
                    "resource_id": "doc-001",
                    "patient_id": "patient-123",
                    "document_type": "Note",
                    "text_content": "Content"
                }
            ]

            stats = processor.process_documents(
                documents,
                batch_size=1,
                document_type="clinical_note",
                show_progress=False
            )

            # All documents should fail
            assert stats["successful"] == 0
            assert stats["failed"] == 1

            # Verify status in database
            processor.cursor.execute(
                "SELECT Status, ErrorMessage FROM VectorizationState WHERE DocumentID='doc-001'"
            )
            row = processor.cursor.fetchone()
            assert row[0] == "failed"
            assert "API unavailable" in row[1]

    def test_resume_from_checkpoint(self, processor, mock_clients):
        """Test resuming processing from checkpoint."""
        embedding_client, vector_db_client = mock_clients

        with processor:
            # Register documents
            documents = [
                {
                    "resource_id": f"doc-{i:03d}",
                    "patient_id": "patient-123",
                    "document_type": "Note",
                    "text_content": f"Content {i}"
                }
                for i in range(5)
            ]
            processor.register_documents(documents, "clinical_note")

            # Manually mark some as completed (simulating previous run)
            processor.mark_processing("doc-000")
            processor.mark_completed("doc-000")
            processor.mark_processing("doc-001")
            processor.mark_completed("doc-001")
            processor.connection.commit()

            # Mock embeddings for remaining documents
            embedding_client.embed_batch.return_value = [
                [0.1] * 1024,
                [0.2] * 1024,
                [0.3] * 1024,
            ]

            # Resume processing
            stats = processor.resume(
                documents,
                batch_size=3,
                document_type="clinical_note",
                show_progress=False
            )

            # Should only process 3 remaining documents
            assert stats["successful"] == 3
            assert stats["total_processed"] == 3

    def test_get_statistics(self, processor):
        """Test retrieving processing statistics."""
        with processor:
            # Create documents with different statuses
            documents = [
                {"resource_id": f"doc-{i:03d}", "document_type": "Note"}
                for i in range(10)
            ]
            processor.register_documents(documents, "clinical_note")

            # Mark various statuses
            processor.mark_processing("doc-000")
            processor.mark_completed("doc-000")
            processor.mark_processing("doc-001")
            processor.mark_completed("doc-001")
            processor.mark_processing("doc-002")
            processor.mark_failed("doc-002", "Error 1")
            processor.mark_processing("doc-003")
            processor.mark_failed("doc-003", "Error 2")
            processor.mark_failed("doc-003", "Error 2 retry")  # increment retry
            processor.connection.commit()

            stats = processor.get_statistics()

            assert stats["completed"]["count"] == 2
            assert stats["failed"]["count"] == 2
            assert stats["failed"]["retries"] == 2  # One document retried
            assert stats["pending"]["count"] == 6

    def test_reset_failed_documents(self, processor):
        """Test resetting failed documents to pending."""
        with processor:
            documents = [
                {"resource_id": f"doc-{i:03d}", "document_type": "Note"}
                for i in range(5)
            ]
            processor.register_documents(documents, "clinical_note")

            # Mark all as failed
            for doc in documents:
                processor.mark_processing(doc["resource_id"])
                processor.mark_failed(doc["resource_id"], "Temporary error")
            processor.connection.commit()

            # Reset failed documents
            reset_count = processor.reset_failed(max_retries=3)
            assert reset_count == 5

            # Verify all are now pending
            processor.cursor.execute(
                "SELECT COUNT(*) FROM VectorizationState WHERE Status='pending'"
            )
            pending_count = processor.cursor.fetchone()[0]
            assert pending_count == 5

    def test_reset_failed_respects_max_retries(self, processor):
        """Test that reset_failed respects max retry limit."""
        with processor:
            processor.register_documents(
                [{"resource_id": "doc-001", "document_type": "Note"}],
                "clinical_note"
            )

            # Fail document multiple times
            for i in range(4):
                processor.mark_processing("doc-001")
                processor.mark_failed("doc-001", f"Error {i}")
            processor.connection.commit()

            # Verify retry count is 4
            processor.cursor.execute(
                "SELECT RetryCount FROM VectorizationState WHERE DocumentID='doc-001'"
            )
            assert processor.cursor.fetchone()[0] == 4

            # Try to reset with max_retries=3
            reset_count = processor.reset_failed(max_retries=3)
            assert reset_count == 0  # Should not reset (retry count >= 3)

            # With max_retries=5, should reset
            reset_count = processor.reset_failed(max_retries=5)
            assert reset_count == 1

    def test_clear_checkpoint(self, processor):
        """Test clearing checkpoint database."""
        with processor:
            # Register documents
            documents = [
                {"resource_id": f"doc-{i:03d}", "document_type": "Note"}
                for i in range(5)
            ]
            processor.register_documents(documents, "clinical_note")

            # Clear all
            deleted_count = processor.clear_checkpoint()
            assert deleted_count == 5

            # Verify database is empty
            processor.cursor.execute("SELECT COUNT(*) FROM VectorizationState")
            assert processor.cursor.fetchone()[0] == 0

    def test_clear_checkpoint_by_type(self, processor):
        """Test clearing checkpoint for specific document type."""
        with processor:
            # Register different types
            processor.register_documents(
                [{"resource_id": "doc-001", "document_type": "Note"}],
                "clinical_note"
            )
            processor.register_documents(
                [{"resource_id": "img-001", "document_type": "Image"}],
                "medical_image"
            )
            processor.connection.commit()

            # Clear only clinical notes
            deleted_count = processor.clear_checkpoint("clinical_note")
            assert deleted_count == 1

            # Verify only medical_image remains
            processor.cursor.execute(
                "SELECT DocumentID, DocumentType FROM VectorizationState"
            )
            rows = processor.cursor.fetchall()
            assert len(rows) == 1
            assert rows[0] == ("img-001", "medical_image")


class TestBatchProcessorIntegration:
    """Integration-style tests for full workflows."""

    @pytest.fixture
    def mock_clients(self):
        """Create mock clients with realistic behavior."""
        embedding_client = Mock()
        embedding_client.model = "nvidia/nv-embedqa-e5-v5"

        vector_db_client = Mock()

        return embedding_client, vector_db_client

    @pytest.fixture
    def temp_db(self):
        """Create temporary database."""
        fd, path = tempfile.mkstemp(suffix=".db")
        os.close(fd)
        yield path
        if os.path.exists(path):
            os.remove(path)

    def test_full_workflow_with_interruption(self, mock_clients, temp_db):
        """Test complete workflow with simulated interruption and resume."""
        embedding_client, vector_db_client = mock_clients

        documents = [
            {
                "resource_id": f"doc-{i:03d}",
                "patient_id": "patient-123",
                "document_type": "Note",
                "text_content": f"Clinical note {i}"
            }
            for i in range(10)
        ]

        # First run: process first 5 documents
        with BatchProcessor(embedding_client, vector_db_client, temp_db) as proc:
            # Mock embeddings
            embedding_client.embed_batch.return_value = [[0.1] * 1024] * 5

            # Register all documents
            proc.register_documents(documents, "clinical_note")

            # Process only first 5 (simulate interruption)
            first_five = documents[:5]
            for doc in first_five:
                proc.mark_processing(doc["resource_id"])
                proc.mark_completed(doc["resource_id"])
            proc.connection.commit()

        # Second run: resume and process remaining
        with BatchProcessor(embedding_client, vector_db_client, temp_db) as proc:
            embedding_client.embed_batch.return_value = [[0.2] * 1024] * 5

            stats = proc.resume(
                documents,
                batch_size=5,
                document_type="clinical_note",
                show_progress=False
            )

            # Should only process remaining 5
            assert stats["successful"] == 5
            assert stats["total_processed"] == 5

        # Verify all documents are completed
        with BatchProcessor(embedding_client, vector_db_client, temp_db) as proc:
            proc.cursor.execute(
                "SELECT COUNT(*) FROM VectorizationState WHERE Status='completed'"
            )
            assert proc.cursor.fetchone()[0] == 10


# Run tests
if __name__ == "__main__":
    pytest.main([__file__, "-v", "--tb=short"])

MCP directory API

We provide all the information about MCP servers via our MCP API.

curl -X GET 'https://glama.ai/api/mcp/v1/servers/isc-tdyar/medical-graphrag-assistant'

If you have feedback or need assistance with the MCP directory API, please join our Discord server