# test_enterprise_features.py (18.3 kB)
"""Comprehensive tests for enterprise features."""
import pytest
import asyncio
from unittest.mock import Mock, patch, AsyncMock
from datetime import datetime, timedelta
from gcp_mcp.config import Config
from gcp_mcp.auth import GCPAuthenticator
from gcp_mcp.tools.enterprise_logging_tools import EnterpriseLoggingTools
from gcp_mcp.tools.enterprise_monitoring_tools import EnterpriseMonitoringTools
from gcp_mcp.validation import (
ProjectValidator, TimeRangeValidator, MetricValidator,
FilterValidator, SecurityValidator, validate_tool_arguments
)
from gcp_mcp.cache import EnterpriseCache, RateLimiter, cached, rate_limited
from gcp_mcp.exceptions import ValidationError, GCPServiceError
class TestProjectValidator:
    """Exercise ``ProjectValidator`` against well-formed and malformed project IDs."""

    def test_valid_project_ids(self):
        """Every well-formed ID is returned unchanged."""
        good_ids = (
            "my-project-123",
            "test-env-prod",
            "a12345",
            "enterprise-logging-prod-01",
        )
        for candidate in good_ids:
            assert ProjectValidator.validate_project_id(candidate) == candidate

    def test_invalid_project_ids(self):
        """Every malformed ID is rejected with ``ValidationError``."""
        bad_ids = (
            "",              # empty string
            "1project",      # leading digit
            "project_name",  # underscore is not permitted
            "PROJECT-NAME",  # uppercase is not permitted
            "project-",      # trailing hyphen
            "a",             # too short
            "a" * 31,        # 31 characters: too long
        )
        for candidate in bad_ids:
            with pytest.raises(ValidationError):
                ProjectValidator.validate_project_id(candidate)

    def test_project_list_validation(self):
        """A short list round-trips; empty and 51-entry lists are rejected."""
        projects = ["project-1", "project-2", "project-3"]
        assert ProjectValidator.validate_project_list(projects) == projects

        # An empty list must be refused.
        with pytest.raises(ValidationError):
            ProjectValidator.validate_project_list([])

        # A list of 51 projects must also be refused.
        oversized = [f"project-{n}" for n in range(51)]
        with pytest.raises(ValidationError):
            ProjectValidator.validate_project_list(oversized)
class TestTimeRangeValidator:
    """Exercise ``TimeRangeValidator`` duration and time-range parsing."""

    def test_valid_durations(self):
        """Durations with a recognised unit suffix are echoed back unchanged."""
        for span in ("30s", "5m", "2h", "7d", "1w"):
            assert TimeRangeValidator.validate_duration(span) == span

    def test_invalid_durations(self):
        """Malformed or out-of-range durations raise ``ValidationError``."""
        rejected = (
            "",        # empty string
            "30",      # missing unit suffix
            "30x",     # unknown unit suffix
            "abc",     # not numeric
            "99999h",  # magnitude too large
        )
        for span in rejected:
            with pytest.raises(ValidationError):
                TimeRangeValidator.validate_duration(span)

    def test_time_range_validation(self):
        """Relative and ISO-8601 inputs populate ``start``; garbage is rejected."""
        relative = TimeRangeValidator.validate_time_range("1h")
        assert relative["start"] == "1h"

        timestamp = "2024-01-01T00:00:00Z"
        absolute = TimeRangeValidator.validate_time_range(timestamp)
        assert absolute["start"] == timestamp

        with pytest.raises(ValidationError):
            TimeRangeValidator.validate_time_range("invalid-time")
class TestMetricValidator:
    """Exercise ``MetricValidator`` metric-type and aggregation checks."""

    def test_valid_metric_types(self):
        """Fully qualified metric types are accepted and returned verbatim."""
        accepted = (
            "compute.googleapis.com/instance/cpu/utilization",
            "logging.googleapis.com/log_entry_count",
            "storage.googleapis.com/storage/total_bytes",
        )
        for metric in accepted:
            assert MetricValidator.validate_metric_type(metric) == metric

    def test_invalid_metric_types(self):
        """Malformed metric types raise ``ValidationError``."""
        rejected = (
            "",                            # empty string
            "invalid",                     # no dots
            "compute.cpu",                 # too few parts
            "COMPUTE.GOOGLEAPIS.COM/CPU",  # uppercase
        )
        for metric in rejected:
            with pytest.raises(ValidationError):
                MetricValidator.validate_metric_type(metric)

    def test_aggregation_config_validation(self):
        """A well-formed config round-trips; a bad alignment period is rejected."""
        config = dict(
            alignment_period="60s",
            per_series_aligner="ALIGN_MEAN",
            cross_series_reducer="REDUCE_MEAN",
        )
        assert MetricValidator.validate_aggregation_config(config) == config

        bad_config = {"alignment_period": "invalid"}
        with pytest.raises(ValidationError):
            MetricValidator.validate_aggregation_config(bad_config)
class TestSecurityValidator:
    """Exercise ``SecurityValidator`` search-term and compliance checks."""

    def test_search_term_validation(self):
        """Benign terms pass; terms that look like credentials are rejected."""
        benign = ["error", "warning", "timeout", "database connection"]
        assert SecurityValidator.validate_search_terms(benign) == benign

        # Each suspicious term is submitted on its own, so one rejection
        # cannot mask a missing rejection for another.
        suspicious = ("password=secret123", "api_key=abc123", "bearer token123")
        for term in suspicious:
            with pytest.raises(ValidationError):
                SecurityValidator.validate_search_terms([term])

    def test_compliance_framework_validation(self):
        """Known frameworks come back upper-cased; unknown names are rejected."""
        for name in ("SOC2", "PCI-DSS", "HIPAA", "GDPR"):
            normalised = SecurityValidator.validate_compliance_framework(name)
            assert normalised == name.upper()

        with pytest.raises(ValidationError):
            SecurityValidator.validate_compliance_framework("INVALID_FRAMEWORK")
class TestEnterpriseCache:
    """Exercise ``EnterpriseCache`` storage, expiry, eviction and statistics."""

    @pytest.fixture
    def cache(self):
        """Provide a fresh cache capped at 1 MB with a default TTL of 60."""
        return EnterpriseCache(max_size_mb=1, default_ttl=60)

    @pytest.mark.asyncio
    async def test_basic_cache_operations(self, cache):
        """A stored value can be read back; an unknown key yields ``None``."""
        await cache.set("test_key", "test_value")
        assert await cache.get("test_key") == "test_value"
        assert await cache.get("non_existent") is None

    @pytest.mark.asyncio
    async def test_cache_expiration(self, cache):
        """An entry stored with ``ttl=1`` is gone after sleeping 1.1 s."""
        await cache.set("expiring_key", "value", ttl=1)
        assert await cache.get("expiring_key") == "value"

        # Sleep slightly longer than the TTL so the entry has lapsed.
        await asyncio.sleep(1.1)
        assert await cache.get("expiring_key") is None

    @pytest.mark.asyncio
    async def test_cache_lru_eviction(self, cache):
        """Exceeding the size budget evicts the oldest entry, keeping the newest."""
        # Ten ~100 KB entries roughly fill the 1 MB budget ...
        for idx in range(10):
            await cache.set(f"key_{idx}", "x" * 100000)
        # ... so an eleventh one should force an eviction.
        await cache.set("new_key", "x" * 100000)

        assert await cache.get("key_0") is None
        assert await cache.get("new_key") is not None

    @pytest.mark.asyncio
    async def test_cache_stats(self, cache):
        """One hit, one miss and one stored entry are reflected in the stats."""
        await cache.set("key1", "value1")
        await cache.get("key1")  # hit
        await cache.get("key2")  # miss

        stats = await cache.get_stats()
        assert (stats["hits"], stats["misses"], stats["entries_count"]) == (1, 1, 1)
class TestRateLimiter:
    """Exercise ``RateLimiter`` per-identifier and concurrency limits."""

    @pytest.fixture
    def rate_limiter(self):
        """Provide a limiter whose per-minute query budget is lowered to 5."""
        limiter = RateLimiter()
        # A deliberately small budget keeps the test short.
        limiter.global_limits["queries_per_minute"] = 5
        return limiter

    @pytest.mark.asyncio
    async def test_rate_limit_enforcement(self, rate_limiter):
        """Five requests pass; the sixth from the same identifier is refused."""
        user, bucket = "test_user", "queries_per_minute"
        for _ in range(5):
            assert await rate_limiter.check_rate_limit(user, bucket) is True
        assert await rate_limiter.check_rate_limit(user, bucket) is False

    @pytest.mark.asyncio
    async def test_concurrent_limit(self, rate_limiter):
        """Slots can be taken up to the limit and become available on release."""
        rate_limiter.global_limits["concurrent_requests"] = 2

        # Take both slots, confirming capacity before each acquisition.
        for _ in range(2):
            assert await rate_limiter.check_concurrent_limit() is True
            await rate_limiter.acquire_request_slot()

        # With both slots held, another request is refused ...
        assert await rate_limiter.check_concurrent_limit() is False

        # ... until one slot is released.
        await rate_limiter.release_request_slot()
        assert await rate_limiter.check_concurrent_limit() is True
class TestEnterpriseLoggingTools:
    """Exercise ``EnterpriseLoggingTools`` tool registration and querying."""

    @pytest.fixture
    def logging_tools(self):
        """Build the logging tools around a mocked ``GCPAuthenticator``."""
        config = Config()
        auth = Mock(spec=GCPAuthenticator)
        auth.get_project_id.return_value = "test-project"
        auth.logging_client = Mock()
        return EnterpriseLoggingTools(auth, config)

    @pytest.mark.asyncio
    async def test_get_tools(self, logging_tools):
        """Exactly six tools are returned, covering every expected name."""
        tools = await logging_tools.get_tools()
        assert len(tools) == 6

        advertised = {tool.name for tool in tools}
        required = {
            "advanced_log_query",
            "error_root_cause_analysis",
            "security_log_analysis",
            "performance_log_analysis",
            "log_pattern_discovery",
            "cross_service_trace_analysis",
        }
        assert required <= advertised

    @pytest.mark.asyncio
    async def test_advanced_log_query_validation(self, logging_tools):
        """A valid query yields one response whose text mentions ``query_summary``."""
        query_args = {
            "projects": ["test-project"],
            "time_range": {"start": "1h"},
            "filter_expression": "severity>=ERROR",
            "advanced_options": {"max_results": 100},
        }
        # Stub ``_execute_project_query`` so it returns no entries.
        with patch.object(logging_tools, "_execute_project_query", return_value=[]):
            response = await logging_tools._advanced_log_query(query_args)

        assert len(response) == 1
        assert "query_summary" in response[0].text
class TestEnterpriseMonitoringTools:
    """Exercise ``EnterpriseMonitoringTools`` tool registration."""

    @pytest.fixture
    def monitoring_tools(self):
        """Build the monitoring tools around a mocked ``GCPAuthenticator``."""
        config = Config()
        auth = Mock(spec=GCPAuthenticator)
        auth.get_project_id.return_value = "test-project"
        auth.monitoring_client = Mock()
        return EnterpriseMonitoringTools(auth, config)

    @pytest.mark.asyncio
    async def test_get_tools(self, monitoring_tools):
        """Exactly six tools are returned, covering every expected name."""
        tools = await monitoring_tools.get_tools()
        assert len(tools) == 6

        advertised = {tool.name for tool in tools}
        required = {
            "advanced_metrics_query",
            "sla_slo_analysis",
            "alert_policy_analysis",
            "resource_optimization_analysis",
            "custom_dashboard_metrics",
            "infrastructure_health_check",
        }
        assert required <= advertised
class TestToolArgumentValidation:
    """Test comprehensive tool argument validation via ``validate_tool_arguments``."""

    def test_project_validation_in_args(self):
        """Well-formed project IDs pass through ``validate_tool_arguments`` unchanged."""
        valid_args = {
            "projects": ["valid-project-1", "valid-project-2"],
            "time_range": {"start": "1h"},
        }
        # ``patch`` raises AttributeError when the target attribute does not
        # exist. A module-level ``self`` in ``gcp_mcp.validation`` is presumably
        # absent, so ``create=True`` is required for the patch to apply at all.
        # NOTE(review): if ``validate_tool_arguments`` really reads a global
        # ``self``, that is a bug in the validation module and should be fixed
        # there rather than papered over here.
        with patch("gcp_mcp.validation.self", create=True) as mock_self:
            mock_self.authenticator.get_project_id.return_value = "default-project"
            result = validate_tool_arguments("test_tool", valid_args)

        assert result["projects"] == valid_args["projects"]

    def test_time_range_validation_in_args(self):
        """Relative ``start``/``end`` durations are preserved in the validated args."""
        args = {"time_range": {"start": "2h", "end": "1h"}}
        # ``create=True`` for the same reason as above; the mock itself is unused.
        with patch("gcp_mcp.validation.self", create=True):
            result = validate_tool_arguments("test_tool", args)

        assert result["time_range"]["start"] == "2h"
        assert result["time_range"]["end"] == "1h"
class TestIntegration:
    """Integration tests for the ``cached`` and ``rate_limited`` decorators."""

    @pytest.mark.asyncio
    async def test_caching_decorator(self):
        """Repeat calls with identical arguments do not re-run the wrapped coroutine."""

        class CachedService:
            def __init__(self):
                # Presumably the decorator reads ``self.cache`` -- TODO confirm.
                self.cache = EnterpriseCache(max_size_mb=1, default_ttl=60)
                self.call_count = 0

            @cached(ttl=60, cache_key_prefix="test")
            async def expensive_operation(self, param1, param2=None):
                self.call_count += 1
                return f"result_{param1}_{param2}"

        service = CachedService()

        first = await service.expensive_operation("arg1", param2="arg2")
        assert service.call_count == 1
        assert first == "result_arg1_arg2"

        # Same arguments again: the call count must not move.
        second = await service.expensive_operation("arg1", param2="arg2")
        assert service.call_count == 1
        assert second == first

        # New arguments: the wrapped coroutine runs a second time.
        third = await service.expensive_operation("different", param2="args")
        assert service.call_count == 2
        assert third == "result_different_args"

    @pytest.mark.asyncio
    async def test_rate_limiting_decorator(self):
        """The call after the configured budget raises ``GCPServiceError``."""

        class LimitedService:
            def __init__(self):
                self.rate_limiter = RateLimiter()
                # Allow only two calls in the "default" bucket.
                self.rate_limiter.global_limits["default"] = 2

            @rate_limited(limit_type="default")
            async def limited_operation(self, param):
                return f"executed_{param}"

        service = LimitedService()

        for label in ("call1", "call2"):
            assert await service.limited_operation(label) == f"executed_{label}"

        with pytest.raises(GCPServiceError) as exc_info:
            await service.limited_operation("call3")
        assert "Rate limit exceeded" in str(exc_info.value)
class TestErrorHandling:
    """Test comprehensive error handling."""

    def test_validation_error_handling(self):
        """Validator failures carry a descriptive ``ValidationError`` message."""
        # Project ID validation error.
        with pytest.raises(ValidationError) as exc_info:
            ProjectValidator.validate_project_id("INVALID")
        assert "Invalid project ID format" in str(exc_info.value)

        # Duration validation error.
        with pytest.raises(ValidationError) as exc_info:
            TimeRangeValidator.validate_duration("invalid")
        assert "Invalid duration format" in str(exc_info.value)

    @pytest.mark.asyncio
    async def test_gcp_service_error_handling(self):
        """Expect ``handle_gcp_errors`` to surface Google API errors as ``GCPServiceError``.

        This test awaits the decorated coroutines, so it must itself be a
        coroutine function. An ``await`` inside a plain ``def`` is a
        SyntaxError that stops the whole module from importing.
        """
        from gcp_mcp.validation import handle_gcp_errors
        from google.cloud.exceptions import Forbidden, NotFound

        @handle_gcp_errors
        async def test_function_403():
            raise Forbidden("Access denied")

        @handle_gcp_errors
        async def test_function_404():
            raise NotFound("Resource not found")

        # 403 -> GCPServiceError, and the original message is preserved.
        with pytest.raises(GCPServiceError) as exc_info:
            await test_function_403()
        assert "Access denied" in str(exc_info.value)

        # 404 -> GCPServiceError, and the original message is preserved.
        with pytest.raises(GCPServiceError) as exc_info:
            await test_function_404()
        assert "Resource not found" in str(exc_info.value)