"""
Test suite for the MCP system.
Covers unit tests, integration tests, and performance-monitoring tests.
"""
import asyncio
import pytest
import logging
import time
import tempfile
from typing import Dict, Any, List, Optional
from unittest.mock import Mock, AsyncMock, patch
from sqlalchemy import create_engine
from datetime import datetime, timedelta
import json
# Import MCP components for testing
from repositories.postgres_repository import PostgresRepository
from repositories.vector_repository import VectorRepository
from services.schema_service import SchemaService
from services.sql_service import SQLService
from services.semantic_service import SemanticService
from services.synthesis_service import SynthesisService
from services.smart_search_service import SmartSearchService
from caching.cache_manager import CacheManager, CacheBackend
from security.security_manager import SecurityManager, Permission
from monitoring.performance_monitor import PerformanceMonitor
# Module-level logger for this test module; not referenced by any test visible in this file.
logger = logging.getLogger(__name__)
class MockDatabase:
    """In-memory stand-in for a database, serving canned rows to tests.

    ``tables`` maps a table name to its column metadata (``columns``) and its
    rows (``data``). ``execute_query`` does not parse SQL: it looks at the query
    prefix and at the first table name (in definition order) that appears in
    the lower-cased query text.
    """

    def __init__(self):
        self.tables = {
            'users': {
                'columns': [
                    {'name': 'id', 'type': 'INTEGER', 'nullable': False},
                    {'name': 'email', 'type': 'VARCHAR(255)', 'nullable': False},
                    {'name': 'name', 'type': 'VARCHAR(100)', 'nullable': True},
                    {'name': 'created_at', 'type': 'TIMESTAMP', 'nullable': False},
                ],
                'data': [
                    {'id': 1, 'email': 'user1@example.com', 'name': 'User One', 'created_at': '2024-01-01'},
                    {'id': 2, 'email': 'user2@example.com', 'name': 'User Two', 'created_at': '2024-01-02'},
                ],
            },
            'orders': {
                'columns': [
                    {'name': 'id', 'type': 'INTEGER', 'nullable': False},
                    {'name': 'user_id', 'type': 'INTEGER', 'nullable': False},
                    {'name': 'amount', 'type': 'DECIMAL(10,2)', 'nullable': False},
                    {'name': 'status', 'type': 'VARCHAR(50)', 'nullable': False},
                ],
                'data': [
                    {'id': 1, 'user_id': 1, 'amount': 99.99, 'status': 'completed'},
                    {'id': 2, 'user_id': 2, 'amount': 149.99, 'status': 'pending'},
                ],
            },
        }

    def _find_table(self, sql_lower: str) -> Optional[str]:
        """Return the first table name (in definition order) found in ``sql_lower``, else None."""
        for table_name in self.tables:
            if table_name in sql_lower:
                return table_name
        return None

    def execute_query(self, sql: str):
        """Return a canned result dict for ``sql``.

        - ``SELECT COUNT(*) ... <table>`` -> ``[{'count': <row count of table>}]``
        - ``SELECT ... <table>``          -> copies of the table's rows
        - anything else                   -> empty ``data``

        ``success`` is always True. The count is derived from the fixture rows,
        so it stays correct if rows are added or removed. Rows are returned as
        copies so a test that mutates a result cannot corrupt the fixture data.
        """
        sql_lower = sql.lower().strip()
        if not sql_lower.startswith('select'):
            return {'success': True, 'data': []}

        table_name = self._find_table(sql_lower)
        if table_name is None:
            return {'success': True, 'data': []}

        rows = self.tables[table_name]['data']
        if sql_lower.startswith('select count(*)'):
            return {'success': True, 'data': [{'count': len(rows)}]}
        return {'success': True, 'data': [dict(row) for row in rows]}
class TestRepositories:
    """Tests for the repository layer (PostgresRepository, VectorRepository)."""

    @pytest.fixture
    def mock_engine(self):
        """Provide a bare ``Mock`` standing in for a database engine."""
        return Mock()

    def test_postgres_repository_initialization(self, mock_engine):
        """PostgresRepository stores the engine it is constructed with."""
        repo = PostgresRepository(mock_engine)
        assert repo.engine == mock_engine

    def test_postgres_repository_execute_query(self, mock_engine):
        """execute_query exposes the raw dict from _execute_with_connection as attributes.

        ``execute_query`` is called synchronously (no ``await``), so this is a
        plain test; it does not need to be a coroutine or depend on pytest-asyncio.
        """
        repo = PostgresRepository(mock_engine)
        # Bypass the real connection; only the result wrapping is under test.
        with patch.object(repo, '_execute_with_connection') as mock_execute:
            mock_execute.return_value = {
                'success': True,
                'data': [{'count': 5}],
                'rows_affected': 0,
                'execution_time': 0.1,
            }
            result = repo.execute_query("SELECT COUNT(*) FROM users")
            assert result.success is True
            assert result.data == [{'count': 5}]
            assert result.execution_time == 0.1

    def test_vector_repository_initialization(self, mock_engine):
        """VectorRepository stores the engine it is constructed with."""
        repo = VectorRepository(mock_engine)
        assert repo.engine == mock_engine

    def test_vector_repository_search_fallback(self, mock_engine):
        """text_search_fallback returns the rows produced by _execute_text_search."""
        repo = VectorRepository(mock_engine)
        with patch.object(repo, '_execute_text_search') as mock_search:
            mock_search.return_value = [
                {'content': 'Test result', 'score': 0.8, 'metadata': {}},
            ]
            results = repo.text_search_fallback("test query")
            assert len(results) == 1
            assert results[0]['content'] == 'Test result'
class TestServices:
    """Tests for the service layer, built on mocked repositories."""

    @pytest.fixture
    def mock_postgres_repo(self):
        """Provide a PostgresRepository mock with two tables and a canned query result."""
        repo = Mock(spec=PostgresRepository)
        repo.get_all_table_names.return_value = ['users', 'orders']
        repo.execute_query.return_value = Mock(
            success=True,
            data=[{'count': 10}],
            error=None,
        )
        return repo

    @pytest.fixture
    def mock_vector_repo(self):
        """Provide a VectorRepository mock whose has_vector_extension() returns True."""
        repo = Mock(spec=VectorRepository)
        repo.has_vector_extension.return_value = True
        return repo

    def test_schema_service_initialization(self, mock_postgres_repo):
        """SchemaService stores the repository it is given."""
        service = SchemaService(mock_postgres_repo)
        assert service.postgres_repo == mock_postgres_repo

    def test_schema_service_get_schema_info(self, mock_postgres_repo):
        """get_schema_info surfaces the tables reported by _introspect_database."""
        service = SchemaService(mock_postgres_repo)
        # Replace live introspection with a fixed two-table schema.
        with patch.object(service, '_introspect_database') as mock_introspect:
            mock_introspect.return_value = {
                'tables': [
                    {'table_name': 'users', 'columns': []},
                    {'table_name': 'orders', 'columns': []},
                ],
                'relationships': [],
                'summary': 'Test database with 2 tables',
            }
            schema_info = service.get_schema_info()
            assert len(schema_info.tables) == 2
            assert 'users' in [t['table_name'] for t in schema_info.tables]

    def test_sql_service_initialization(self, mock_postgres_repo):
        """SQLService stores both the repository and the schema service."""
        schema_service = SchemaService(mock_postgres_repo)
        sql_service = SQLService(mock_postgres_repo, schema_service)
        assert sql_service.postgres_repo == mock_postgres_repo
        assert sql_service.schema_service == schema_service

    def test_sql_service_execute_safe(self, mock_postgres_repo):
        """A plain SELECT passes execute_safe and reports success."""
        schema_service = SchemaService(mock_postgres_repo)
        sql_service = SQLService(mock_postgres_repo, schema_service)
        result = sql_service.execute_safe("SELECT COUNT(*) FROM users")
        assert result.success is True

    def test_semantic_service_initialization(self, mock_vector_repo):
        """SemanticService stores the vector repository it is given."""
        service = SemanticService(mock_vector_repo)
        assert service.vector_repo == mock_vector_repo

    def test_semantic_service_search(self, mock_vector_repo):
        """search returns the results of the repository's semantic search.

        ``service.search`` is called synchronously (no ``await``), so this is a
        plain test rather than an asyncio-marked coroutine.
        """
        service = SemanticService(mock_vector_repo)
        with patch.object(service.vector_repo, 'semantic_search_with_fallback') as mock_search:
            mock_search.return_value = [
                Mock(content='Test result', score=0.8, metadata={}),
            ]
            results = service.search("test query")
            assert len(results) == 1
            assert results[0].content == 'Test result'
class TestSmartSearch:
    """Tests for SmartSearchService orchestration over fully mocked services."""

    @pytest.fixture
    def mock_services(self):
        """Provide mocks for the schema, SQL, semantic and synthesis services, keyed by role."""
        schema_service = Mock()
        sql_service = Mock()
        semantic_service = Mock()
        synthesis_service = Mock()

        schema_service.get_schema_info.return_value = Mock(
            tables=[{'table_name': 'users'}],
            summary='Test database',
        )
        sql_service.get_suggested_queries.return_value = [
            Mock(sql="SELECT COUNT(*) FROM users", description="Count users"),
        ]
        sql_service.execute_safe.return_value = Mock(
            success=True,
            data=[{'count': 10}],
        )
        semantic_service.search.return_value = [
            Mock(content='Semantic result', score=0.9),
        ]
        synthesis_service.synthesize_response.return_value = "Test response"

        return {
            'schema': schema_service,
            'sql': sql_service,
            'semantic': semantic_service,
            'synthesis': synthesis_service,
        }

    @staticmethod
    def _build_smart_search(services):
        """Construct a SmartSearchService from the ``mock_services`` mapping."""
        return SmartSearchService(
            services['schema'],
            services['sql'],
            services['semantic'],
            services['synthesis'],
        )

    # NOTE(review): these tests stay asyncio-marked so SmartSearchService is
    # constructed inside a running event loop, as in the original; no sync test
    # here shows the constructor is safe outside one.
    @pytest.mark.asyncio
    async def test_smart_search_initialization(self, mock_services):
        """The constructor stores the schema and SQL services it is given."""
        smart_search = self._build_smart_search(mock_services)
        assert smart_search.schema_service == mock_services['schema']
        assert smart_search.sql_service == mock_services['sql']

    @pytest.mark.asyncio
    async def test_smart_search_question_classification(self, mock_services):
        """Counting questions lean SQL; conceptual questions lean semantic."""
        smart_search = self._build_smart_search(mock_services)

        classification = smart_search._classify_question("How many users are there?")
        assert classification.strategy.value in ['sql_only', 'hybrid']

        classification = smart_search._classify_question("What is a database?")
        assert classification.strategy.value in ['semantic_only', 'hybrid']

    @pytest.mark.asyncio
    async def test_smart_search_full_search(self, mock_services):
        """search() returns a successful result with a response and the strategy used."""
        smart_search = self._build_smart_search(mock_services)
        result = await smart_search.search("How many users are there?")
        assert result['success'] is True
        assert 'response' in result
        assert 'strategy_used' in result
class TestCaching:
    """Tests for CacheManager: get/set, TTL expiry, the cache decorator, and tag invalidation."""

    @pytest.mark.asyncio
    async def test_memory_cache_backend(self):
        """A stored value is readable until its TTL lapses, after which get() returns None."""
        cache = CacheManager(CacheBackend.MEMORY, {'max_size': 100})

        await cache.set('test_key', 'test_value', ttl=60)
        assert await cache.get('test_key') == 'test_value'

        # NOTE(review): relies on ttl=0 meaning "expire immediately"; some cache
        # APIs treat 0 as "never expire" -- confirm against CacheManager.
        await cache.set('expire_key', 'expire_value', ttl=0)
        await asyncio.sleep(0.1)
        assert await cache.get('expire_key') is None

    @pytest.mark.asyncio
    async def test_cache_decorator(self):
        """Repeat calls with the same arguments are served from the cache.

        The wrapped function records every real invocation, so the test checks
        directly whether the body ran. The previous ``time.time()``-based check
        could pass without any caching on coarse-resolution clocks.
        """
        cache = CacheManager()
        calls = []

        @cache.cache_decorator('test_func', ttl=60)
        def test_function(arg1, arg2):
            calls.append((arg1, arg2))
            # Include the invocation count so every real call yields a distinct value.
            return f"{arg1}_{arg2}_{len(calls)}"

        result1 = test_function('a', 'b')
        result2 = test_function('a', 'b')
        assert result1 == result2
        assert calls == [('a', 'b')]  # second call did not execute the body

        # Different arguments form a different cache key.
        result3 = test_function('c', 'd')
        assert result3 != result1
        assert calls == [('a', 'b'), ('c', 'd')]

    @pytest.mark.asyncio
    async def test_cache_invalidation(self):
        """invalidate_by_tag removes exactly the entries carrying that tag."""
        cache = CacheManager()
        await cache.set('key1', 'value1', tags=['tag1'])
        await cache.set('key2', 'value2', tags=['tag1'])
        await cache.set('key3', 'value3', tags=['tag2'])

        deleted_count = await cache.invalidate_by_tag('tag1')
        assert deleted_count == 2

        assert await cache.get('key1') is None
        assert await cache.get('key2') is None
        assert await cache.get('key3') == 'value3'
class TestSecurity:
    """Tests for SecurityManager: SQL scanning, rate limiting, API keys and JWTs."""

    def test_security_manager_initialization(self):
        """A fresh SecurityManager has a SQL scanner and a rate limiter."""
        security = SecurityManager()
        assert security.sql_scanner is not None
        assert security.rate_limiter is not None

    def test_sql_security_scanner(self):
        """SELECT is considered safe; DROP TABLE is unsafe and high-risk."""
        security = SecurityManager()

        safe_result = security.sql_scanner.scan_sql("SELECT * FROM users")
        assert safe_result['is_safe'] is True

        dangerous_result = security.sql_scanner.scan_sql("DROP TABLE users")
        assert dangerous_result['is_safe'] is False
        assert dangerous_result['risk_level'] == 'high'

    @pytest.mark.asyncio
    async def test_rate_limiting(self):
        """A custom 2-per-window rule admits two requests and blocks the third."""
        from security.security_manager import RateLimitRule

        security = SecurityManager()

        general = await security.check_rate_limit('test_client', 'api_general')
        assert general['allowed'] is True

        security.rate_limiter.add_rule('test_rule', RateLimitRule(
            requests_per_window=2,
            window_seconds=60,
        ))
        # Use a client id that has made no prior requests, so the expected
        # counts hold whether the limiter keys on (client, rule) or on client only.
        rule_client = 'test_rule_client'
        first = await security.check_rate_limit(rule_client, 'test_rule')
        second = await security.check_rate_limit(rule_client, 'test_rule')
        assert first['allowed'] is True
        assert second['allowed'] is True

        third = await security.check_rate_limit(rule_client, 'test_rule')
        assert third['allowed'] is False

    def test_api_key_generation(self):
        """Generated API keys carry the 'mcp_' prefix and verify back to their record."""
        security = SecurityManager()
        api_key = security.generate_api_key(
            user_id='test_user',
            name='test_key',
            permissions={Permission.READ_SCHEMA, Permission.EXECUTE_SQL},
        )
        assert api_key.startswith('mcp_')

        key_record = security.verify_api_key(api_key)
        assert key_record is not None
        assert key_record.user_id == 'test_user'
        assert Permission.READ_SCHEMA in key_record.permissions

    def test_jwt_token_creation(self):
        """A created JWT is a string whose verified payload round-trips user and permissions."""
        security = SecurityManager()
        token = security.create_jwt_token(
            user_id='test_user',
            permissions=['read_schema', 'execute_sql'],
        )
        assert isinstance(token, str)

        payload = security.verify_jwt_token(token)
        assert payload is not None
        assert payload['user_id'] == 'test_user'
        assert 'read_schema' in payload['permissions']
class TestPerformanceMonitoring:
    """Tests for PerformanceMonitor query tracking, statistics and reporting."""

    def test_performance_monitor_initialization(self):
        """A fresh monitor has query and system history containers."""
        monitor = PerformanceMonitor()
        assert monitor.query_history is not None
        assert monitor.system_history is not None

    def test_query_tracking(self):
        """start_query registers an active query; end_query moves it into history."""
        monitor = PerformanceMonitor()
        query_id = monitor.start_query('test_query', 'sql')
        assert query_id in monitor.active_queries

        monitor.end_query(query_id, 'sql', True, None, 5)
        assert query_id not in monitor.active_queries
        assert len(monitor.query_history) == 1

    def test_performance_statistics(self):
        """One success and one failure yield 2 total queries at a 50% success rate."""
        monitor = PerformanceMonitor()
        # end_query takes the id returned by start_query (see test_query_tracking);
        # do not assume that id equals the name passed to start_query.
        sql_query_id = monitor.start_query('query1', 'sql')
        monitor.end_query(sql_query_id, 'sql', True, None, 10)
        semantic_query_id = monitor.start_query('query2', 'semantic')
        monitor.end_query(semantic_query_id, 'semantic', False, 'timeout', 0)

        stats = monitor.get_query_statistics(hours=24)
        assert stats['total_queries'] == 2
        assert stats['success_rate'] == 50.0

    def test_performance_report(self):
        """The report includes a health score plus query and system statistics."""
        monitor = PerformanceMonitor()
        report = monitor.get_performance_report()
        assert 'health_score' in report
        assert 'query_statistics' in report
        assert 'system_statistics' in report
class TestIntegration:
    """Integration tests wiring the real services together over mocked repositories."""

    @pytest.mark.asyncio
    async def test_full_system_integration(self):
        """SmartSearchService answers a question end-to-end.

        Only the repositories and SchemaService.get_schema_info are mocked; the
        schema, SQL, semantic and synthesis services are real instances.
        """
        postgres_repo = Mock(spec=PostgresRepository)
        vector_repo = Mock(spec=VectorRepository)

        postgres_repo.get_all_table_names.return_value = ['users', 'orders']
        postgres_repo.execute_query.return_value = Mock(
            success=True,
            data=[{'count': 10}],
            error=None,
        )
        vector_repo.has_vector_extension.return_value = True
        vector_repo.semantic_search_with_fallback.return_value = []

        schema_service = SchemaService(postgres_repo)
        sql_service = SQLService(postgres_repo, schema_service)
        semantic_service = SemanticService(vector_repo)
        synthesis_service = SynthesisService({})

        # The search runs inside the patch so it sees the stubbed schema.
        with patch.object(schema_service, 'get_schema_info') as mock_schema:
            mock_schema.return_value = Mock(
                tables=[{'table_name': 'users'}],
                summary='Test database',
            )
            smart_search = SmartSearchService(
                schema_service, sql_service, semantic_service, synthesis_service
            )
            result = await smart_search.search("How many users?")

        assert result['success'] is True
# Test configuration
@pytest.fixture(scope="session")
def test_config():
    """Session-scoped configuration mapping shared by tests.

    Sections: ``database`` (an in-memory SQLite connection string), ``cache``
    (memory backend, max_size 100) and ``security`` (a fixed JWT secret and
    ``rate_limit_general`` of 100).
    """
    database_section = {'connection_string': 'sqlite:///:memory:'}
    cache_section = {'backend': 'memory', 'max_size': 100}
    security_section = {'jwt_secret': 'test_secret', 'rate_limit_general': 100}
    return {
        'database': database_section,
        'cache': cache_section,
        'security': security_section,
    }
# Performance test utilities
class PerformanceTest:
    """Helpers for ad-hoc performance measurements in tests."""

    @staticmethod
    async def measure_execution_time(coro):
        """Await ``coro`` and report how long it took.

        Args:
            coro: An awaitable (typically a coroutine object) to run to completion.

        Returns:
            ``(result, elapsed_seconds)``. Elapsed time uses ``time.perf_counter``,
            a monotonic high-resolution clock. Unlike ``time.time``, it cannot go
            backwards or jump when the system clock is adjusted.
        """
        start = time.perf_counter()
        result = await coro
        return result, time.perf_counter() - start

    @staticmethod
    def measure_memory_usage():
        """Return the resident set size (RSS) of the current process, in bytes.

        Requires the third-party ``psutil`` package, which is imported here
        rather than at module level.
        """
        import psutil
        return psutil.Process().memory_info().rss
# Test runner
def run_all_tests():
    """Run every test in this module under pytest.

    Returns:
        The exit code from ``pytest.main`` (0 when all tests pass), so callers
        can propagate failures.
    """
    return pytest.main([
        __file__,
        '-v',
        '--tb=short',
        '--disable-warnings',
    ])


if __name__ == "__main__":
    # Propagate pytest's exit status; otherwise a failing run exits with status 0.
    raise SystemExit(run_all_tests())