Skip to main content
Glama

Mutation Clinical Trial Matching MCP

by pickleton89
test_shared_utilities.py (25.7 kB)
"""Tests for the shared utilities module.

Covers the helpers imported from ``utils.shared``: input validation, HTTP
exception mapping, the request-timing decorator, response processing,
service configuration lookup and ``SessionManager``.

Metric emission is checked by patching ``utils.metrics`` functions and
asserting on the recorded calls. ``assert_called_with`` only inspects the
*most recent* call of the mock, so each test asserts the last metric emitted.
"""

import asyncio
import time
from unittest.mock import AsyncMock, Mock, patch

import httpx
import pytest
import requests
import requests.exceptions

from utils.shared import (
    SessionManager,
    extract_studies_from_response,
    get_service_config,
    map_http_exception_to_error_response,
    process_json_response,
    time_request,
    validate_llm_input,
    validate_mutation_input,
)


class TestValidationFunctions:
    """Test input validation functions."""

    def test_validate_mutation_input_valid(self):
        """Test valid mutation input."""
        result = validate_mutation_input("BRAF V600E", min_rank=1, max_rank=100)

        assert result["valid"] is True
        assert result["error"] is None
        assert result["mutation"] == "BRAF V600E"
        assert result["min_rank"] == 1
        assert result["max_rank"] == 100
        assert len(result["warnings"]) == 0

    def test_validate_mutation_input_empty_mutation(self):
        """Test empty mutation input."""
        with patch('utils.metrics.increment') as mock_increment:
            result = validate_mutation_input("")

        assert result["valid"] is False
        assert "non-empty string" in result["error"]
        mock_increment.assert_called_with(
            "api_validation_errors",
            tags={"error_type": "invalid_mutation"},
        )

    def test_validate_mutation_input_none_mutation(self):
        """Test None mutation input."""
        with patch('utils.metrics.increment') as mock_increment:
            result = validate_mutation_input(None)  # type: ignore[arg-type]

        assert result["valid"] is False
        assert "non-empty string" in result["error"]
        mock_increment.assert_called_with(
            "api_validation_errors",
            tags={"error_type": "invalid_mutation"},
        )

    def test_validate_mutation_input_invalid_min_rank(self):
        """Test invalid min_rank correction."""
        with patch('utils.metrics.increment') as mock_increment:
            result = validate_mutation_input("BRAF V600E", min_rank=0)

        assert result["valid"] is True
        assert result["min_rank"] == 1
        assert "Invalid min_rank 0, corrected to 1" in result["warnings"]
        mock_increment.assert_called_with(
            "api_validation_warnings",
            tags={"warning_type": "invalid_min_rank"},
        )

    def test_validate_mutation_input_invalid_max_rank(self):
        """Test invalid max_rank correction."""
        with patch('utils.metrics.increment') as mock_increment:
            result = validate_mutation_input("BRAF V600E", max_rank=-5)

        assert result["valid"] is True
        assert result["max_rank"] is None
        assert "Invalid max_rank -5, corrected to unlimited" in result["warnings"]
        mock_increment.assert_called_with(
            "api_validation_warnings",
            tags={"warning_type": "invalid_max_rank"},
        )

    def test_validate_mutation_input_rank_order_correction(self):
        """Test rank order correction."""
        with patch('utils.metrics.increment') as mock_increment:
            result = validate_mutation_input("BRAF V600E", min_rank=100, max_rank=50)

        assert result["valid"] is True
        assert result["min_rank"] == 50
        assert result["max_rank"] == 100
        assert "min_rank and max_rank were swapped" in result["warnings"][0]
        mock_increment.assert_called_with(
            "api_validation_warnings",
            tags={"warning_type": "rank_order_corrected"},
        )

    def test_validate_mutation_input_whitespace_trimming(self):
        """Test whitespace trimming."""
        result = validate_mutation_input(" BRAF V600E ")

        assert result["valid"] is True
        assert result["mutation"] == "BRAF V600E"

    def test_validate_llm_input_valid(self):
        """Test valid LLM input."""
        messages = [
            {"role": "user", "content": "Hello"},
            {"role": "assistant", "content": "Hi there!"},
        ]

        result = validate_llm_input(
            messages,
            model="claude-3",
            max_tokens=1000,
            temperature=0.7,
        )

        assert result["valid"] is True
        assert result["error"] is None
        assert result["messages"] == messages
        assert result["model"] == "claude-3"
        assert result["max_tokens"] == 1000
        assert result["temperature"] == 0.7
        assert len(result["warnings"]) == 0

    def test_validate_llm_input_empty_messages(self):
        """Test empty messages list."""
        with patch('utils.metrics.increment') as mock_increment:
            result = validate_llm_input([])

        assert result["valid"] is False
        assert "non-empty list" in result["error"]
        mock_increment.assert_called_with(
            "llm_validation_errors",
            tags={"error_type": "invalid_messages"},
        )

    def test_validate_llm_input_invalid_message_structure(self):
        """Test invalid message structure."""
        messages = [
            {"role": "user", "content": "Hello"},
            "invalid message",
        ]

        with patch('utils.metrics.increment') as mock_increment:
            result = validate_llm_input(messages)

        assert result["valid"] is False
        assert "Message 1 must be a dictionary" in result["error"]
        mock_increment.assert_called_with(
            "llm_validation_errors",
            tags={"error_type": "invalid_message_structure"},
        )

    def test_validate_llm_input_missing_fields(self):
        """Test missing required fields in message."""
        messages = [
            {"role": "user"},  # Missing content
        ]

        with patch('utils.metrics.increment') as mock_increment:
            result = validate_llm_input(messages)

        assert result["valid"] is False
        assert "must have 'role' and 'content' fields" in result["error"]
        mock_increment.assert_called_with(
            "llm_validation_errors",
            tags={"error_type": "missing_message_fields"},
        )

    def test_validate_llm_input_unusual_role(self):
        """Test unusual role warning."""
        messages = [
            {"role": "moderator", "content": "Hello"},
        ]

        with patch('utils.metrics.increment') as mock_increment:
            result = validate_llm_input(messages)

        assert result["valid"] is True
        assert "unusual role: moderator" in result["warnings"][0]
        mock_increment.assert_called_with(
            "llm_validation_warnings",
            tags={"warning_type": "unusual_role"},
        )

    def test_validate_llm_input_invalid_max_tokens(self):
        """Test invalid max_tokens correction."""
        messages = [{"role": "user", "content": "Hello"}]

        with patch('utils.metrics.increment') as mock_increment:
            result = validate_llm_input(messages, max_tokens=-100)

        assert result["valid"] is True
        assert result["max_tokens"] == 1000
        assert "Invalid max_tokens -100, corrected to 1000" in result["warnings"]
        mock_increment.assert_called_with(
            "llm_validation_warnings",
            tags={"warning_type": "invalid_max_tokens"},
        )

    def test_validate_llm_input_invalid_temperature(self):
        """Test invalid temperature correction."""
        messages = [{"role": "user", "content": "Hello"}]

        with patch('utils.metrics.increment') as mock_increment:
            result = validate_llm_input(messages, temperature=5.0)

        assert result["valid"] is True
        assert result["temperature"] == 0.7
        assert "Invalid temperature 5.0, corrected to 0.7" in result["warnings"]
        mock_increment.assert_called_with(
            "llm_validation_warnings",
            tags={"warning_type": "invalid_temperature"},
        )


class TestErrorHandling:
    """Test error handling functions."""

    def test_map_requests_timeout_exception(self):
        """Test mapping requests timeout exception."""
        exception = requests.exceptions.Timeout("Request timeout")

        with patch('utils.metrics.increment') as mock_increment:
            result = map_http_exception_to_error_response(exception, "test_service")

        assert result["error"] == "Request timed out"
        assert result["retry_after"] == 30
        assert result["error_type"] == "Timeout"
        assert result["studies"] == []
        mock_increment.assert_called_with(
            "api_errors",
            tags={"service": "test_service", "error_type": "timeout"},
        )

    def test_map_requests_connection_error(self):
        """Test mapping requests connection error."""
        exception = requests.exceptions.ConnectionError("Connection failed")

        with patch('utils.metrics.increment') as mock_increment:
            result = map_http_exception_to_error_response(exception, "test_service")

        assert result["error"] == "Connection failed"
        assert result["retry_after"] == 60
        assert result["error_type"] == "ConnectionError"
        mock_increment.assert_called_with(
            "api_errors",
            tags={"service": "test_service", "error_type": "connection_error"},
        )

    def test_map_requests_http_error_429(self):
        """Test mapping requests HTTP 429 error."""
        mock_response = Mock()
        mock_response.status_code = 429
        exception = requests.exceptions.HTTPError("Rate limit")
        exception.response = mock_response

        with patch('utils.metrics.increment') as mock_increment:
            result = map_http_exception_to_error_response(exception, "test_service")

        assert result["error"] == "Rate limit exceeded"
        assert result["retry_after"] == 60
        mock_increment.assert_called_with(
            "api_errors",
            tags={"service": "test_service", "error_type": "rate_limit"},
        )

    def test_map_requests_http_error_500(self):
        """Test mapping requests HTTP 500 error."""
        mock_response = Mock()
        mock_response.status_code = 500
        exception = requests.exceptions.HTTPError("Server error")
        exception.response = mock_response

        with patch('utils.metrics.increment') as mock_increment:
            result = map_http_exception_to_error_response(exception, "test_service")

        assert result["error"] == "Server error"
        assert result["retry_after"] == 120
        mock_increment.assert_called_with(
            "api_errors",
            tags={"service": "test_service", "error_type": "server_error"},
        )

    def test_map_httpx_timeout_exception(self):
        """Test mapping httpx timeout exception."""
        exception = httpx.TimeoutException("Request timeout")

        with patch('utils.metrics.increment') as mock_increment:
            result = map_http_exception_to_error_response(exception, "test_service")

        assert result["error"] == "Request timed out"
        assert result["retry_after"] == 30
        assert result["error_type"] == "TimeoutException"
        mock_increment.assert_called_with(
            "api_errors",
            tags={"service": "test_service", "error_type": "timeout"},
        )

    def test_map_httpx_connect_error(self):
        """Test mapping httpx connect error."""
        exception = httpx.ConnectError("Connection failed")

        with patch('utils.metrics.increment') as mock_increment:
            result = map_http_exception_to_error_response(exception, "test_service")

        assert result["error"] == "Connection failed"
        assert result["retry_after"] == 60
        assert result["error_type"] == "ConnectError"
        mock_increment.assert_called_with(
            "api_errors",
            tags={"service": "test_service", "error_type": "connection_error"},
        )

    def test_map_httpx_http_status_error(self):
        """Test mapping httpx HTTP status error."""
        mock_response = Mock()
        mock_response.status_code = 429
        exception = httpx.HTTPStatusError(
            "Rate limit",
            request=None,
            response=mock_response,
        )

        with patch('utils.metrics.increment') as mock_increment:
            result = map_http_exception_to_error_response(exception, "test_service")

        assert result["error"] == "Rate limit exceeded"
        assert result["retry_after"] == 60
        mock_increment.assert_called_with(
            "api_errors",
            tags={"service": "test_service", "error_type": "rate_limit"},
        )

    def test_map_json_error(self):
        """Test mapping JSON parsing error."""
        exception = ValueError("Invalid JSON format")

        with patch('utils.metrics.increment') as mock_increment:
            result = map_http_exception_to_error_response(exception, "test_service")

        assert result["error"] == "Invalid JSON response"
        assert result["error_type"] == "ValueError"
        mock_increment.assert_called_with(
            "api_errors",
            tags={"service": "test_service", "error_type": "json_error"},
        )

    def test_map_unknown_exception(self):
        """Test mapping unknown exception."""
        exception = RuntimeError("Unknown error")

        with patch('utils.metrics.increment') as mock_increment:
            result = map_http_exception_to_error_response(exception, "test_service")

        assert result["error"] == "Request failed"
        assert result["error_type"] == "RuntimeError"
        assert result["error_details"] == "Unknown error"
        mock_increment.assert_called_with(
            "api_errors",
            tags={"service": "test_service", "error_type": "unknown"},
        )


class TestTimingDecorator:
    """Test request timing decorator."""

    # NOTE: stacked @patch decorators inject mocks bottom-up, so the decorator
    # closest to the function supplies the first mock argument.

    @patch('utils.metrics.increment')
    @patch('utils.metrics.histogram')
    @patch('utils.metrics.gauge')
    def test_time_request_sync_success(self, mock_gauge, mock_histogram, mock_increment):
        """Test timing decorator for successful sync function."""

        @time_request("test_service", "test_operation")
        def test_function():
            time.sleep(0.01)  # Small delay for timing
            return "success"

        result = test_function()

        assert result == "success"
        mock_increment.assert_called_with(
            "api_requests_total",
            tags={
                "service": "test_service",
                "operation": "test_operation",
                "status": "success",
            },
        )
        mock_histogram.assert_called()
        mock_gauge.assert_called()

    @patch('utils.metrics.increment')
    @patch('utils.metrics.histogram')
    def test_time_request_sync_error(self, mock_histogram, mock_increment):
        """Test timing decorator for failed sync function."""

        @time_request("test_service", "test_operation")
        def test_function():
            raise ValueError("Test error")

        with pytest.raises(ValueError):
            test_function()

        # These checks must sit outside the ``pytest.raises`` block: anything
        # after the raising call inside it would never execute.
        mock_increment.assert_called_with(
            "api_requests_total",
            tags={
                "service": "test_service",
                "operation": "test_operation",
                "status": "error",
            },
        )
        mock_histogram.assert_called()

    @patch('utils.metrics.increment')
    @patch('utils.metrics.histogram')
    @patch('utils.metrics.gauge')
    @pytest.mark.asyncio
    async def test_time_request_async_success(self, mock_gauge, mock_histogram, mock_increment):
        """Test timing decorator for successful async function."""

        @time_request("test_service", "test_operation")
        async def test_function():
            await asyncio.sleep(0.01)  # Small delay for timing
            return "success"

        result = await test_function()

        assert result == "success"
        mock_increment.assert_called_with(
            "api_requests_total",
            tags={
                "service": "test_service",
                "operation": "test_operation",
                "status": "success",
            },
        )
        mock_histogram.assert_called()
        mock_gauge.assert_called()

    @patch('utils.metrics.increment')
    @patch('utils.metrics.histogram')
    @pytest.mark.asyncio
    async def test_time_request_async_error(self, mock_histogram, mock_increment):
        """Test timing decorator for failed async function."""

        @time_request("test_service", "test_operation")
        async def test_function():
            raise ValueError("Test error")

        with pytest.raises(ValueError):
            await test_function()

        # Asserted after the ``pytest.raises`` block so the checks actually run.
        mock_increment.assert_called_with(
            "api_requests_total",
            tags={
                "service": "test_service",
                "operation": "test_operation",
                "status": "error",
            },
        )
        mock_histogram.assert_called()


class TestResponseProcessing:
    """Test response processing utilities."""

    def test_extract_studies_from_response_studies_key(self):
        """Test extracting studies from response with 'studies' key."""
        response_data = {
            "studies": [
                {"nct_id": "NCT12345", "title": "Study 1"},
                {"nct_id": "NCT67890", "title": "Study 2"},
            ]
        }

        with patch('utils.metrics.gauge') as mock_gauge:
            studies = extract_studies_from_response(response_data)

        assert len(studies) == 2
        assert studies[0]["nct_id"] == "NCT12345"
        assert studies[1]["nct_id"] == "NCT67890"
        mock_gauge.assert_called_with(
            "api_studies_returned",
            2,
            tags={"service": "clinicaltrials"},
        )

    def test_extract_studies_from_response_study_key(self):
        """Test extracting studies from response with 'Study' key."""
        response_data = {
            "Study": [
                {"nct_id": "NCT12345", "title": "Study 1"},
            ]
        }

        with patch('utils.metrics.gauge') as mock_gauge:
            studies = extract_studies_from_response(response_data)

        assert len(studies) == 1
        assert studies[0]["nct_id"] == "NCT12345"
        mock_gauge.assert_called_with(
            "api_studies_returned",
            1,
            tags={"service": "clinicaltrials"},
        )

    def test_extract_studies_from_response_single_study(self):
        """Test extracting single study (not in list)."""
        response_data = {
            "studies": {"nct_id": "NCT12345", "title": "Study 1"},
        }

        with patch('utils.metrics.gauge'):
            studies = extract_studies_from_response(response_data)

        assert len(studies) == 1
        assert studies[0]["nct_id"] == "NCT12345"

    def test_extract_studies_from_response_no_studies(self):
        """Test extracting from response with no studies."""
        response_data = {"other_data": "value"}

        with patch('utils.metrics.gauge') as mock_gauge:
            studies = extract_studies_from_response(response_data)

        assert len(studies) == 0
        mock_gauge.assert_called_with(
            "api_studies_returned",
            0,
            tags={"service": "clinicaltrials"},
        )

    def test_extract_studies_from_response_error(self):
        """Test extracting studies with error."""
        response_data = None

        with patch('utils.metrics.increment') as mock_increment, \
                patch('utils.metrics.gauge'):
            studies = extract_studies_from_response(response_data)  # type: ignore[arg-type]

        assert len(studies) == 0
        mock_increment.assert_called_with(
            "response_processing_errors",
            tags={"error_type": "studies_extraction"},
        )

    def test_process_json_response_valid(self):
        """Test processing valid JSON response."""
        response_text = '{"success": true, "data": "test"}'

        with patch('utils.metrics.gauge') as mock_gauge:
            result = process_json_response(response_text, "test_service")

        assert result["success"] is True
        assert result["data"] == "test"
        mock_gauge.assert_called_with(
            "api_response_size",
            len(response_text),
            tags={"service": "test_service"},
        )

    def test_process_json_response_invalid(self):
        """Test processing invalid JSON response."""
        response_text = "invalid json"

        with patch('utils.metrics.increment') as mock_increment, \
                patch('utils.metrics.gauge'):
            result = process_json_response(response_text, "test_service")

        assert result["error"] == "Invalid JSON response"
        assert result["studies"] == []
        mock_increment.assert_called_with(
            "response_processing_errors",
            tags={
                "service": "test_service",
                "error_type": "json_parsing",
            },
        )

    def test_process_json_response_missing_fields(self):
        """Test processing JSON response with missing expected fields."""
        response_text = '{"data": "test"}'
        expected_fields = ["success", "message"]

        # The gauge is patched as well (as in the invalid-JSON test) so the
        # real ``utils.metrics.gauge`` is not invoked while this test runs.
        with patch('utils.metrics.increment') as mock_increment, \
                patch('utils.metrics.gauge'):
            result = process_json_response(response_text, "test_service", expected_fields)

        assert result["data"] == "test"
        mock_increment.assert_called_with(
            "response_validation_warnings",
            tags={
                "service": "test_service",
                "warning_type": "missing_fields",
            },
        )


class TestConfigurationHelpers:
    """Test configuration helper functions."""

    def test_get_service_config_basic(self):
        """Test basic service configuration retrieval."""
        config_dict = {
            "test_service": {
                "timeout": 15.0,
                "custom_option": "value",
            }
        }

        result = get_service_config("test_service", config_dict)

        assert result["timeout"] == 15.0
        assert result["custom_option"] == "value"
        assert result["max_retries"] == 3  # Default value
        assert result["circuit_breaker_threshold"] == 5  # Default value

    def test_get_service_config_clinicaltrials(self):
        """Test ClinicalTrials service configuration with defaults."""
        config_dict = {}

        result = get_service_config("clinicaltrials", config_dict)

        assert result["timeout"] == 10.0  # Service-specific default
        assert result["base_url"] == "https://clinicaltrials.gov/api/"
        assert result["max_retries"] == 3  # Common default

    def test_get_service_config_anthropic(self):
        """Test Anthropic service configuration with defaults."""
        config_dict = {}

        result = get_service_config("anthropic", config_dict)

        assert result["timeout"] == 60.0  # Service-specific default
        assert result["base_url"] == "https://api.anthropic.com/"
        assert result["model"] == "claude-3-5-sonnet-20241022"
        assert result["max_tokens"] == 1000

    def test_get_service_config_override(self):
        """Test service configuration with overrides."""
        config_dict = {
            "anthropic": {
                "timeout": 120.0,  # Override default
                "max_tokens": 2000,  # Override default
            }
        }

        result = get_service_config("anthropic", config_dict)

        assert result["timeout"] == 120.0  # Overridden
        assert result["max_tokens"] == 2000  # Overridden
        assert result["model"] == "claude-3-5-sonnet-20241022"  # Still default


class TestSessionManager:
    """Test the SessionManager class."""

    def test_session_manager_sync(self):
        """Test SessionManager in sync mode."""
        manager = SessionManager(async_mode=False)

        with patch('requests.Session') as mock_session_class:
            mock_session = Mock()
            mock_session_class.return_value = mock_session

            session = manager.get_session("test_service", headers={"Custom": "Header"})

        # Identity, not equality: the manager must hand back the very object
        # produced by the (patched) session factory.
        assert session is mock_session
        mock_session.headers.update.assert_called_with({"Custom": "Header"})

    def test_session_manager_async(self):
        """Test SessionManager in async mode."""
        manager = SessionManager(async_mode=True)

        with patch('httpx.AsyncClient') as mock_client_class:
            mock_client = Mock()
            mock_client_class.return_value = mock_client

            client = manager.get_session(
                "test_service",
                headers={"Custom": "Header"},
                timeout=30.0,
                base_url="https://api.example.com",
            )

        assert client is mock_client
        mock_client_class.assert_called_with(
            headers={"Custom": "Header"},
            timeout=30.0,
            base_url="https://api.example.com",
        )

    def test_session_manager_reuse(self):
        """Test that SessionManager reuses sessions."""
        manager = SessionManager(async_mode=False)

        with patch('requests.Session') as mock_session_class:
            mock_session = Mock()
            mock_session_class.return_value = mock_session

            session1 = manager.get_session("test_service")
            session2 = manager.get_session("test_service")

        assert session1 is session2
        assert mock_session_class.call_count == 1

    def test_session_manager_close_all_sync(self):
        """Test closing all sync sessions."""
        manager = SessionManager(async_mode=False)

        with patch('requests.Session') as mock_session_class:
            mock_session = Mock()
            mock_session_class.return_value = mock_session

            manager.get_session("test_service")
            manager.close_all()

        mock_session.close.assert_called_once()
        assert len(manager._sessions) == 0

    @pytest.mark.asyncio
    async def test_session_manager_close_all_async(self):
        """Test closing all async sessions."""
        manager = SessionManager(async_mode=True)

        with patch('httpx.AsyncClient') as mock_client_class:
            mock_client = Mock()
            mock_client.aclose = AsyncMock()
            mock_client_class.return_value = mock_client

            manager.get_session("test_service")
            await manager.aclose_all()

        # ``assert_awaited_once`` (rather than ``assert_called_once``) fails if
        # ``aclose()`` was merely called and its coroutine never awaited.
        mock_client.aclose.assert_awaited_once()
        assert len(manager._sessions) == 0

MCP directory API

We provide all the information about MCP servers via our MCP API.

curl -X GET 'https://glama.ai/api/mcp/v1/servers/pickleton89/mutation-clinical-trial-matching-mcp'

If you have feedback or need assistance with the MCP directory API, please join our Discord server.