"""
单元测试:AI Router 模块测试
"""
import pytest
import asyncio
from unittest.mock import Mock, AsyncMock, patch, MagicMock
from typing import Dict, Any, List
from src.percepta_mcp.ai_router import (
AIResponse,
AIProvider,
OpenAIProvider,
AnthropicProvider,
GoogleProvider,
OllamaProvider,
AIRouter,
get_ai_router,
reset_ai_router,
AIProviderError
)
from src.percepta_mcp.config import AIProviderConfig, Settings
class TestAIResponse:
    """Tests for the AIResponse data class."""

    def test_ai_response_creation(self) -> None:
        """A fully-populated response exposes every field and carries no error."""
        resp = AIResponse(
            content="Test response",
            provider="test-provider",
            model="test-model",
            tokens_used=100,
            cost=0.05,
            response_time=1.5,
        )
        expected_fields = {
            "content": "Test response",
            "provider": "test-provider",
            "model": "test-model",
            "tokens_used": 100,
            "cost": 0.05,
            "response_time": 1.5,
        }
        for attr, value in expected_fields.items():
            assert getattr(resp, attr) == value
        assert resp.error is None

    def test_ai_response_with_error(self) -> None:
        """An error response keeps empty content and zeroed usage/cost defaults."""
        resp = AIResponse(
            content="",
            provider="test-provider",
            model="test-model",
            error="Test error",
        )
        assert resp.content == ""
        assert resp.error == "Test error"
        assert resp.tokens_used == 0
        assert resp.cost == 0.0
class TestAIProvider:
    """Tests for the AIProvider base class."""

    def test_ai_provider_initialization(self) -> None:
        """A new provider stores its config and starts with zeroed counters."""
        cfg = AIProviderConfig(
            name="test-provider",
            type="test",
            model="test-model",
            rate_limit=100,
        )
        provider = AIProvider(cfg)
        assert provider.config == cfg
        assert provider.request_count == 0
        assert provider.last_request_time == 0.0

    def test_rate_limiting(self) -> None:
        """Requests are accepted until the configured rate limit is reached."""
        cfg = AIProviderConfig(
            name="test-provider",
            type="test",
            model="test-model",
            rate_limit=2,
        )
        provider = AIProvider(cfg)
        # A fresh provider must accept requests.
        assert provider.can_handle_request() is True
        # Expected (count, availability) after each recorded request;
        # the limit of 2 is exhausted on the second request.
        for expected_count, expected_ok in ((1, True), (2, False), (3, False)):
            provider.record_request()
            assert provider.request_count == expected_count
            assert provider.can_handle_request() is expected_ok

    @pytest.mark.asyncio
    async def test_generate_not_implemented(self) -> None:
        """The base class refuses to generate: subclasses must override."""
        cfg = AIProviderConfig(
            name="test-provider",
            type="test",
            model="test-model",
        )
        provider = AIProvider(cfg)
        with pytest.raises(NotImplementedError):
            await provider.generate("test prompt")
class TestOpenAIProvider:
    """Tests for the OpenAI-backed provider."""

    def test_openai_provider_initialization(self) -> None:
        """Constructing the provider builds an AsyncOpenAI client from the config."""
        cfg = AIProviderConfig(
            name="openai-test",
            type="openai",
            model="gpt-4",
            api_key="test-key",
        )
        with patch('openai.AsyncOpenAI') as client_cls:
            provider = OpenAIProvider(cfg)
            assert provider.config == cfg
            client_cls.assert_called_once_with(
                api_key="test-key",
                base_url=None,
            )

    @pytest.mark.asyncio
    async def test_openai_generate_success(self) -> None:
        """A successful completion is mapped onto an AIResponse with cost applied."""
        cfg = AIProviderConfig(
            name="openai-test",
            type="openai",
            model="gpt-4",
            api_key="test-key",
            cost_per_token=0.001,
        )
        # Fake the SDK response shape: choices[0].message.content + usage.total_tokens.
        fake_completion = Mock()
        fake_completion.choices = [Mock()]
        fake_completion.choices[0].message.content = "Test response"
        fake_completion.usage.total_tokens = 100
        fake_client = AsyncMock()
        fake_client.chat.completions.create.return_value = fake_completion
        with patch('openai.AsyncOpenAI', return_value=fake_client):
            provider = OpenAIProvider(cfg)
            result = await provider.generate("test prompt")
        assert result.content == "Test response"
        assert result.provider == "openai-test"
        assert result.model == "gpt-4"
        assert result.tokens_used == 100
        assert result.cost == 0.1  # 100 tokens * 0.001 per token
        assert result.error is None

    @pytest.mark.asyncio
    async def test_openai_generate_error(self) -> None:
        """SDK exceptions are caught and surfaced via the error field."""
        cfg = AIProviderConfig(
            name="openai-test",
            type="openai",
            model="gpt-4",
            api_key="test-key",
        )
        fake_client = AsyncMock()
        fake_client.chat.completions.create.side_effect = Exception("API Error")
        with patch('openai.AsyncOpenAI', return_value=fake_client):
            provider = OpenAIProvider(cfg)
            result = await provider.generate("test prompt")
        assert result.content == ""
        assert result.provider == "openai-test"
        assert result.model == "gpt-4"
        assert result.error == "API Error"
class TestAnthropicProvider:
    """Tests for the Anthropic-backed provider."""

    @pytest.mark.asyncio
    async def test_anthropic_generate_success(self) -> None:
        """Input and output token counts are summed and billed at cost_per_token."""
        cfg = AIProviderConfig(
            name="anthropic-test",
            type="anthropic",
            model="claude-3-sonnet",
            api_key="test-key",
            cost_per_token=0.003,
        )
        # SDK shape: content[0].text plus separate input/output token usage.
        fake_message = Mock()
        fake_message.content = [Mock()]
        fake_message.content[0].text = "Test response"
        fake_message.usage.input_tokens = 50
        fake_message.usage.output_tokens = 50
        fake_client = AsyncMock()
        fake_client.messages.create.return_value = fake_message
        with patch('anthropic.AsyncAnthropic', return_value=fake_client):
            provider = AnthropicProvider(cfg)
            result = await provider.generate("test prompt")
        assert result.content == "Test response"
        assert result.provider == "anthropic-test"
        assert result.model == "claude-3-sonnet"
        assert result.tokens_used == 100  # 50 input + 50 output
        assert result.cost == 0.3  # 100 * 0.003
        assert result.error is None
class TestGoogleProvider:
    """Tests for the Google Gemini-backed provider."""

    @pytest.mark.asyncio
    async def test_google_generate_success(self) -> None:
        """The reply text and total_token_count are mapped into an AIResponse."""
        cfg = AIProviderConfig(
            name="google-test",
            type="google",
            model="gemini-pro",
            api_key="test-key",
            cost_per_token=0.0005,
        )
        fake_reply = Mock()
        fake_reply.text = "Test response"
        fake_reply.usage_metadata.total_token_count = 80
        fake_model = AsyncMock()
        fake_model.generate_content_async.return_value = fake_reply
        # Patch both the configure() call and the model factory the provider uses.
        with patch('google.generativeai.configure'), \
             patch('google.generativeai.GenerativeModel', return_value=fake_model):
            provider = GoogleProvider(cfg)
            result = await provider.generate("test prompt")
        assert result.content == "Test response"
        assert result.provider == "google-test"
        assert result.model == "gemini-pro"
        assert result.tokens_used == 80
        assert result.cost == 0.04  # 80 * 0.0005
        assert result.error is None
class TestOllamaProvider:
    """Tests for the local Ollama provider."""

    @pytest.mark.asyncio
    async def test_ollama_generate_success(self) -> None:
        """Ollama replies carry no token usage and incur no cost."""
        cfg = AIProviderConfig(
            name="ollama-test",
            type="ollama",
            model="llama2",
            base_url="http://localhost:11434",
        )
        fake_client = AsyncMock()
        fake_client.generate.return_value = {'response': 'Test response'}
        with patch('ollama.AsyncClient', return_value=fake_client):
            provider = OllamaProvider(cfg)
            result = await provider.generate("test prompt")
        assert result.content == "Test response"
        assert result.provider == "ollama-test"
        assert result.model == "llama2"
        assert result.tokens_used == 0  # Ollama doesn't provide token counts
        assert result.cost == 0.0  # Local model, no cost
        assert result.error is None
class TestAIRouter:
    """Tests for the AIRouter orchestration layer: provider registration,
    availability filtering, priority selection, dispatch, and health checks."""

    def test_ai_router_initialization(self) -> None:
        """Router keeps the supplied settings and a dict-based provider registry."""
        settings = Settings(
            ai_providers=[
                AIProviderConfig(
                    name="test-provider",
                    type="openai",
                    model="gpt-4",
                    enabled=True
                )
            ]
        )
        # Patch the concrete provider class so no real OpenAI client is built.
        with patch('src.percepta_mcp.ai_router.OpenAIProvider'):
            router = AIRouter(settings)
            assert router.settings == settings
            assert isinstance(router.providers, dict)

    def test_get_available_providers(self) -> None:
        """Only providers whose can_handle_request() is True are listed."""
        settings = Settings(
            ai_providers=[
                AIProviderConfig(
                    name="provider1",
                    type="openai",
                    model="gpt-4",
                    enabled=True
                ),
                AIProviderConfig(
                    name="provider2",
                    type="anthropic",
                    model="claude-3",
                    enabled=True
                )
            ]
        )
        mock_provider1 = Mock()
        mock_provider1.can_handle_request.return_value = True
        mock_provider2 = Mock()
        mock_provider2.can_handle_request.return_value = False
        with patch('src.percepta_mcp.ai_router.OpenAIProvider', return_value=mock_provider1):
            with patch('src.percepta_mcp.ai_router.AnthropicProvider', return_value=mock_provider2):
                router = AIRouter(settings)
                available = router.get_available_providers()
                assert "provider1" in available
                assert "provider2" not in available

    def test_get_provider_by_priority(self) -> None:
        """When several providers can serve, the lowest priority number wins."""
        settings = Settings(
            ai_providers=[
                AIProviderConfig(
                    name="low-priority",
                    type="openai",
                    model="gpt-4",
                    priority=3,
                    enabled=True
                ),
                AIProviderConfig(
                    name="high-priority",
                    type="anthropic",
                    model="claude-3",
                    priority=1,
                    enabled=True
                )
            ]
        )
        mock_provider1 = Mock()
        mock_provider1.can_handle_request.return_value = True
        mock_provider2 = Mock()
        mock_provider2.can_handle_request.return_value = True
        with patch('src.percepta_mcp.ai_router.OpenAIProvider', return_value=mock_provider1):
            with patch('src.percepta_mcp.ai_router.AnthropicProvider', return_value=mock_provider2):
                router = AIRouter(settings)
                provider = router.get_provider_by_priority()
                # Should return the high-priority (priority=1) provider.
                assert provider == mock_provider2

    @pytest.mark.asyncio
    async def test_generate_with_specific_provider(self) -> None:
        """generate() honours an explicit provider_name and delegates to it once."""
        settings = Settings(
            ai_providers=[
                AIProviderConfig(
                    name="test-provider",
                    type="openai",
                    model="gpt-4",
                    enabled=True
                )
            ]
        )
        mock_response = AIResponse(
            content="Test response",
            provider="test-provider",
            model="gpt-4"
        )
        mock_provider = AsyncMock()
        mock_provider.generate.return_value = mock_response
        # retry_attempts must be a real int so the router's retry loop works.
        mock_provider.config.retry_attempts = 3
        with patch('src.percepta_mcp.ai_router.OpenAIProvider', return_value=mock_provider):
            router = AIRouter(settings)
            response = await router.generate("test prompt", provider_name="test-provider")
            assert response.content == "Test response"
            assert response.provider == "test-provider"
            mock_provider.generate.assert_called_once_with("test prompt")

    @pytest.mark.asyncio
    async def test_generate_provider_not_available(self) -> None:
        """Requesting an unknown provider yields an empty response with an error."""
        settings = Settings(ai_providers=[])
        router = AIRouter(settings)
        response = await router.generate("test prompt", provider_name="nonexistent")
        assert response.content == ""
        assert response.provider == "nonexistent"
        assert "not available" in response.error

    @pytest.mark.asyncio
    async def test_health_check(self) -> None:
        """health_check() reports per-provider status, timing and availability."""
        settings = Settings(
            ai_providers=[
                AIProviderConfig(
                    name="healthy-provider",
                    type="openai",
                    model="gpt-4",
                    enabled=True
                ),
                AIProviderConfig(
                    name="unhealthy-provider",
                    type="anthropic",
                    model="claude-3",
                    enabled=True
                )
            ]
        )
        # Mock healthy provider: successful generation, accepts requests.
        healthy_response = AIResponse(
            content="Hello",
            provider="healthy-provider",
            model="gpt-4",
            response_time=0.5
        )
        mock_healthy = AsyncMock()
        mock_healthy.generate.return_value = healthy_response
        # can_handle_request must be sync; AsyncMock would make it a coroutine.
        mock_healthy.can_handle_request = MagicMock(return_value=True)
        # Mock unhealthy provider: generation reports an error, rejects requests.
        unhealthy_response = AIResponse(
            content="",
            provider="unhealthy-provider",
            model="claude-3",
            error="API Error",
            response_time=1.0
        )
        mock_unhealthy = AsyncMock()
        mock_unhealthy.generate.return_value = unhealthy_response
        mock_unhealthy.can_handle_request = MagicMock(return_value=False)
        with patch('src.percepta_mcp.ai_router.OpenAIProvider', return_value=mock_healthy):
            with patch('src.percepta_mcp.ai_router.AnthropicProvider', return_value=mock_unhealthy):
                router = AIRouter(settings)
                health_status = await router.health_check()
                assert "healthy-provider" in health_status
                assert "unhealthy-provider" in health_status
                assert health_status["healthy-provider"]["status"] == "healthy"
                assert health_status["healthy-provider"]["response_time"] == 0.5
                assert health_status["healthy-provider"]["can_handle_request"] is True
                assert health_status["unhealthy-provider"]["status"] == "error"
                assert health_status["unhealthy-provider"]["error"] == "API Error"
                assert health_status["unhealthy-provider"]["can_handle_request"] is False
class TestGlobalRouter:
    """Tests for the module-level router accessor helpers."""

    def test_get_ai_router_singleton(self) -> None:
        """Repeated calls return the same cached AIRouter instance."""
        reset_ai_router()  # start from a clean global state
        first = get_ai_router()
        second = get_ai_router()
        assert first is second

    def test_get_ai_router_with_settings(self) -> None:
        """Settings passed to the accessor are adopted by the returned router."""
        reset_ai_router()  # start from a clean global state
        custom_settings = Settings(
            ai_providers=[
                AIProviderConfig(
                    name="custom-provider",
                    type="openai",
                    model="gpt-4",
                    enabled=True,
                )
            ]
        )
        with patch('src.percepta_mcp.ai_router.OpenAIProvider'):
            router = get_ai_router(custom_settings)
            assert router.settings == custom_settings

    def test_reset_ai_router(self) -> None:
        """Resetting drops the cached instance so a fresh one is created."""
        before_reset = get_ai_router()
        reset_ai_router()
        after_reset = get_ai_router()
        # A different object must come back after the reset.
        assert before_reset is not after_reset
# Pytest fixtures
@pytest.fixture
def sample_ai_config() -> AIProviderConfig:
    """Provide a sample AI provider configuration."""
    kwargs: Dict[str, Any] = {
        "name": "test-provider",
        "type": "openai",
        "model": "gpt-4",
        "api_key": "test-key",
        "priority": 1,
        "enabled": True,
    }
    return AIProviderConfig(**kwargs)
@pytest.fixture
def sample_settings(sample_ai_config: AIProviderConfig) -> Settings:
    """Provide sample settings wired to the sample provider config."""
    return Settings(
        default_provider="test-provider",
        ai_providers=[sample_ai_config],
    )
if __name__ == "__main__":
pytest.main([__file__, "-v"])