"""
Test per-tool model default selection functionality
"""
import json
import os
import shutil
import tempfile
from unittest.mock import MagicMock, patch
import pytest
from providers.registry import ModelProviderRegistry, ProviderType
from tools.analyze import AnalyzeTool
from tools.chat import ChatTool
from tools.codereview import CodeReviewTool
from tools.debug import DebugIssueTool
from tools.models import ToolModelCategory
from tools.precommit import PrecommitTool
from tools.shared.base_tool import BaseTool
from tools.thinkdeep import ThinkDeepTool
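

# The registry-reset sequence below (clear the cache, then unregister every
# provider type) is repeated inline throughout these tests. A shared helper
# such as this sketch could replace that duplication; it uses only registry
# calls already exercised elsewhere in this module.
def _reset_registry():
    """Clear cached state and unregister all providers for test isolation."""
    ModelProviderRegistry.clear_cache()
    for provider_type in list(ProviderType):
        ModelProviderRegistry.unregister_provider(provider_type)

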
class TestToolModelCategories:
"""Test that each tool returns the correct model category."""
def test_thinkdeep_category(self):
tool = ThinkDeepTool()
assert tool.get_model_category() == ToolModelCategory.EXTENDED_REASONING
def test_debug_category(self):
tool = DebugIssueTool()
assert tool.get_model_category() == ToolModelCategory.EXTENDED_REASONING
def test_analyze_category(self):
tool = AnalyzeTool()
assert tool.get_model_category() == ToolModelCategory.EXTENDED_REASONING
def test_precommit_category(self):
tool = PrecommitTool()
assert tool.get_model_category() == ToolModelCategory.EXTENDED_REASONING
def test_chat_category(self):
tool = ChatTool()
assert tool.get_model_category() == ToolModelCategory.FAST_RESPONSE
def test_codereview_category(self):
tool = CodeReviewTool()
assert tool.get_model_category() == ToolModelCategory.EXTENDED_REASONING
def test_base_tool_default_category(self):
# Test that BaseTool defaults to BALANCED
class TestTool(BaseTool):
def get_name(self):
return "test"
def get_description(self):
return "test"
def get_input_schema(self):
return {}
def get_system_prompt(self):
return "test"
def get_request_model(self):
return MagicMock
async def prepare_prompt(self, request):
return "test"
tool = TestTool()
assert tool.get_model_category() == ToolModelCategory.BALANCED
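

# For illustration, a hypothetical tool opts out of the BALANCED default by
# overriding get_model_category(), mirroring the stub defined in the test above.
# "FastStubTool" and its trivial bodies are illustrative only and are not part
# of the code under test.
class FastStubTool(BaseTool):
    def get_name(self):
        return "fast_stub"

    def get_description(self):
        return "stub tool that asks for low-latency models"

    def get_input_schema(self):
        return {}

    def get_system_prompt(self):
        return "stub"

    def get_request_model(self):
        return MagicMock

    async def prepare_prompt(self, request):
        return "stub"

    def get_model_category(self):
        # Requests FAST_RESPONSE instead of inheriting the BALANCED default.
        return ToolModelCategory.FAST_RESPONSE

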
class TestModelSelection:
"""Test model selection based on tool categories."""
def teardown_method(self):
"""Clean up after each test to prevent state pollution."""
ModelProviderRegistry.clear_cache()
# Unregister all providers
for provider_type in list(ProviderType):
ModelProviderRegistry.unregister_provider(provider_type)
def test_extended_reasoning_with_openai(self):
"""Test EXTENDED_REASONING with OpenAI provider."""
# Setup with only OpenAI provider
ModelProviderRegistry.clear_cache()
# First unregister all providers to ensure isolation
for provider_type in list(ProviderType):
ModelProviderRegistry.unregister_provider(provider_type)
with patch.dict(os.environ, {"OPENAI_API_KEY": "test-key"}, clear=False):
from providers.openai import OpenAIModelProvider
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
model = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.EXTENDED_REASONING)
# OpenAI prefers GPT-5-Codex for extended reasoning (coding tasks)
assert model == "gpt-5-codex"
def test_extended_reasoning_with_gemini_only(self):
"""Test EXTENDED_REASONING prefers pro when only Gemini is available."""
# Clear cache and unregister all providers first
ModelProviderRegistry.clear_cache()
for provider_type in list(ProviderType):
ModelProviderRegistry.unregister_provider(provider_type)
# Register only Gemini provider
with patch.dict(os.environ, {"GOOGLE_API_KEY": "test-key"}, clear=False):
from providers.gemini import GeminiModelProvider
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
model = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.EXTENDED_REASONING)
# Gemini should return one of its models for extended reasoning
# The default behavior may return flash when pro is not explicitly preferred
assert model in ["gemini-2.5-pro", "gemini-2.5-flash", "gemini-2.0-flash"]
def test_fast_response_with_openai(self):
"""Test FAST_RESPONSE with OpenAI provider."""
# Setup with only OpenAI provider
ModelProviderRegistry.clear_cache()
# First unregister all providers to ensure isolation
for provider_type in list(ProviderType):
ModelProviderRegistry.unregister_provider(provider_type)
with patch.dict(os.environ, {"OPENAI_API_KEY": "test-key"}, clear=False):
from providers.openai import OpenAIModelProvider
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
model = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.FAST_RESPONSE)
            # OpenAI prefers gpt-5 for fast response in the current preference order
assert model == "gpt-5"
def test_fast_response_with_gemini_only(self):
"""Test FAST_RESPONSE prefers flash when only Gemini is available."""
# Clear cache and unregister all providers first
ModelProviderRegistry.clear_cache()
for provider_type in list(ProviderType):
ModelProviderRegistry.unregister_provider(provider_type)
# Register only Gemini provider
with patch.dict(os.environ, {"GOOGLE_API_KEY": "test-key"}, clear=False):
from providers.gemini import GeminiModelProvider
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
model = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.FAST_RESPONSE)
# Gemini should return one of its models for fast response
assert model in ["gemini-2.5-flash", "gemini-2.0-flash", "gemini-2.5-pro"]
def test_balanced_category_fallback(self):
"""Test BALANCED category uses existing logic."""
# Setup with only OpenAI provider
ModelProviderRegistry.clear_cache()
# First unregister all providers to ensure isolation
for provider_type in list(ProviderType):
ModelProviderRegistry.unregister_provider(provider_type)
with patch.dict(os.environ, {"OPENAI_API_KEY": "test-key"}, clear=False):
from providers.openai import OpenAIModelProvider
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
model = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.BALANCED)
            # OpenAI prefers gpt-5 for the balanced category in the current preference order
assert model == "gpt-5"
def test_no_category_uses_balanced_logic(self):
"""Test that no category specified uses balanced logic."""
# Setup with only Gemini provider
with patch.dict(os.environ, {"GEMINI_API_KEY": "test-key"}, clear=False):
from providers.gemini import GeminiModelProvider
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
model = ModelProviderRegistry.get_preferred_fallback_model()
# Should pick flash for balanced use
assert model == "gemini-2.5-flash"
class TestFlexibleModelSelection:
"""Test that model selection handles various naming scenarios."""
def test_fallback_handles_mixed_model_names(self):
"""Test that fallback selection works with different providers."""
# Test with different provider configurations
test_cases = [
# Case 1: OpenAI provider for extended reasoning
{
"env": {"OPENAI_API_KEY": "test-key"},
"provider_type": ProviderType.OPENAI,
"category": ToolModelCategory.EXTENDED_REASONING,
"expected": "gpt-5-codex", # GPT-5-Codex prioritized for coding tasks
},
# Case 2: Gemini provider for fast response
{
"env": {"GEMINI_API_KEY": "test-key"},
"provider_type": ProviderType.GOOGLE,
"category": ToolModelCategory.FAST_RESPONSE,
"expected": "gemini-2.5-flash",
},
# Case 3: OpenAI provider for fast response
{
"env": {"OPENAI_API_KEY": "test-key"},
"provider_type": ProviderType.OPENAI,
"category": ToolModelCategory.FAST_RESPONSE,
"expected": "gpt-5", # Based on new preference order
},
]
for case in test_cases:
# Clear registry for clean test
ModelProviderRegistry.clear_cache()
# First unregister all providers to ensure isolation
for provider_type in list(ProviderType):
ModelProviderRegistry.unregister_provider(provider_type)
with patch.dict(os.environ, case["env"], clear=False):
# Register the appropriate provider
if case["provider_type"] == ProviderType.OPENAI:
from providers.openai import OpenAIModelProvider
ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
elif case["provider_type"] == ProviderType.GOOGLE:
from providers.gemini import GeminiModelProvider
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
model = ModelProviderRegistry.get_preferred_fallback_model(case["category"])
assert model == case["expected"], f"Failed for case: {case}, got {model}"
class TestCustomProviderFallback:
"""Test fallback to custom/openrouter providers."""
def test_extended_reasoning_custom_fallback(self):
"""Test EXTENDED_REASONING with custom provider."""
# Setup with custom provider
ModelProviderRegistry.clear_cache()
with patch.dict(os.environ, {"CUSTOM_API_URL": "http://localhost:11434", "CUSTOM_API_KEY": ""}, clear=False):
from providers.custom import CustomProvider
ModelProviderRegistry.register_provider(ProviderType.CUSTOM, CustomProvider)
provider = ModelProviderRegistry.get_provider(ProviderType.CUSTOM)
if provider:
model = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.EXTENDED_REASONING)
# Should get a model from custom provider
assert model is not None
def test_extended_reasoning_final_fallback(self):
"""Test EXTENDED_REASONING falls back to default when no providers."""
# Clear all providers
ModelProviderRegistry.clear_cache()
for provider_type in list(
ModelProviderRegistry._instance._providers.keys() if ModelProviderRegistry._instance else []
):
ModelProviderRegistry.unregister_provider(provider_type)
model = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.EXTENDED_REASONING)
# Should fall back to hardcoded default
assert model == "gemini-2.5-flash"
class TestAutoModeErrorMessages:
"""Test that auto mode error messages include suggested models."""
def teardown_method(self):
"""Clean up after each test to prevent state pollution."""
# Clear provider registry singleton
ModelProviderRegistry._instance = None
@pytest.mark.asyncio
async def test_chat_auto_error_message(self):
"""Test Chat tool suggests appropriate model in auto mode."""
with patch("config.IS_AUTO_MODE", True):
with patch("config.DEFAULT_MODEL", "auto"):
with patch.object(ModelProviderRegistry, "get_available_models") as mock_get_available:
# Mock OpenAI models available
mock_get_available.return_value = {
"o3": ProviderType.OPENAI,
"o3-mini": ProviderType.OPENAI,
"o4-mini": ProviderType.OPENAI,
}
# Mock the provider lookup to return None for auto model
with patch.object(ModelProviderRegistry, "get_provider_for_model") as mock_get_provider_for:
mock_get_provider_for.return_value = None
tool = ChatTool()
temp_dir = tempfile.mkdtemp()
try:
result = await tool.execute(
{"prompt": "test", "model": "auto", "working_directory": temp_dir}
)
finally:
shutil.rmtree(temp_dir, ignore_errors=True)
assert len(result) == 1
# The SimpleTool will wrap the error message
error_output = json.loads(result[0].text)
assert error_output["status"] == "error"
assert "Model 'auto' is not available" in error_output["content"]
# Note: the TestFileContentPreparation class was removed.
# The original test used MagicMock, which raised TypeErrors when compared with integers;
# it was dropped to avoid brittle mocking and to encourage real integration testing instead.
class TestProviderHelperMethods:
"""Test the helper methods for finding models from custom/openrouter."""
def test_extended_reasoning_with_custom_provider(self):
"""Test extended reasoning model selection with custom provider."""
# Setup with custom provider
with patch.dict(os.environ, {"CUSTOM_API_URL": "http://localhost:11434", "CUSTOM_API_KEY": ""}, clear=False):
from providers.custom import CustomProvider
ModelProviderRegistry.register_provider(ProviderType.CUSTOM, CustomProvider)
provider = ModelProviderRegistry.get_provider(ProviderType.CUSTOM)
if provider:
# Custom provider should return a model for extended reasoning
model = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.EXTENDED_REASONING)
assert model is not None
def test_extended_reasoning_with_openrouter(self):
"""Test extended reasoning model selection with OpenRouter."""
# Setup with OpenRouter provider
with patch.dict(os.environ, {"OPENROUTER_API_KEY": "test-key"}, clear=False):
from providers.openrouter import OpenRouterProvider
ModelProviderRegistry.register_provider(ProviderType.OPENROUTER, OpenRouterProvider)
# OpenRouter should provide a model for extended reasoning
model = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.EXTENDED_REASONING)
# Should return first available OpenRouter model
assert model is not None
def test_fallback_when_no_providers_available(self):
"""Test fallback when no providers are available."""
# Clear all providers
ModelProviderRegistry.clear_cache()
for provider_type in list(
ModelProviderRegistry._instance._providers.keys() if ModelProviderRegistry._instance else []
):
ModelProviderRegistry.unregister_provider(provider_type)
# Should return hardcoded fallback
model = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.EXTENDED_REASONING)
assert model == "gemini-2.5-flash"
class TestEffectiveAutoMode:
"""Test the is_effective_auto_mode method."""
def test_explicit_auto_mode(self):
"""Test when DEFAULT_MODEL is explicitly 'auto'."""
with patch("config.DEFAULT_MODEL", "auto"):
with patch("config.IS_AUTO_MODE", True):
tool = ChatTool()
assert tool.is_effective_auto_mode() is True
def test_unavailable_model_triggers_auto_mode(self):
"""Test when DEFAULT_MODEL is set but not available."""
with patch("config.DEFAULT_MODEL", "o3"):
with patch("config.IS_AUTO_MODE", False):
with patch.object(ModelProviderRegistry, "get_provider_for_model") as mock_get_provider:
mock_get_provider.return_value = None # Model not available
tool = ChatTool()
assert tool.is_effective_auto_mode() is True
def test_available_model_no_auto_mode(self):
"""Test when DEFAULT_MODEL is set and available."""
with patch("config.DEFAULT_MODEL", "pro"):
with patch("config.IS_AUTO_MODE", False):
with patch.object(ModelProviderRegistry, "get_provider_for_model") as mock_get_provider:
mock_get_provider.return_value = MagicMock() # Model is available
tool = ChatTool()
assert tool.is_effective_auto_mode() is False
class TestRuntimeModelSelection:
"""Test runtime model selection behavior."""
def teardown_method(self):
"""Clean up after each test to prevent state pollution."""
# Clear provider registry singleton
ModelProviderRegistry._instance = None
@pytest.mark.asyncio
async def test_explicit_auto_in_request(self):
"""Test when Claude explicitly passes model='auto'."""
with patch("config.DEFAULT_MODEL", "pro"): # DEFAULT_MODEL is a real model
with patch("config.IS_AUTO_MODE", False): # Not in auto mode
tool = ThinkDeepTool()
result = await tool.execute(
{
"step": "test",
"step_number": 1,
"total_steps": 1,
"next_step_required": False,
"findings": "test",
"model": "auto",
}
)
# Should require model selection even though DEFAULT_MODEL is valid
assert len(result) == 1
assert "Model 'auto' is not available" in result[0].text
@pytest.mark.asyncio
async def test_unavailable_model_in_request(self):
"""Test when Claude passes an unavailable model."""
with patch("config.DEFAULT_MODEL", "pro"):
with patch("config.IS_AUTO_MODE", False):
with patch.object(ModelProviderRegistry, "get_provider_for_model") as mock_get_provider:
# Model is not available
mock_get_provider.return_value = None
tool = ChatTool()
temp_dir = tempfile.mkdtemp()
try:
result = await tool.execute(
{"prompt": "test", "model": "gpt-5-turbo", "working_directory": temp_dir}
)
finally:
shutil.rmtree(temp_dir, ignore_errors=True)
# Should require model selection
assert len(result) == 1
                    # When a specific model is requested but unavailable, the error message names that model
error_output = json.loads(result[0].text)
assert error_output["status"] == "error"
assert "gpt-5-turbo" in error_output["content"]
assert "is not available" in error_output["content"]
class TestSchemaGeneration:
"""Test schema generation with different configurations."""
def test_schema_with_explicit_auto_mode(self):
"""Test schema when DEFAULT_MODEL='auto'."""
with patch("config.DEFAULT_MODEL", "auto"):
with patch("config.IS_AUTO_MODE", True):
tool = ChatTool()
schema = tool.get_input_schema()
# Model should be required
assert "model" in schema["required"]
def test_schema_with_unavailable_default_model(self):
"""Test schema when DEFAULT_MODEL is set but unavailable."""
with patch("config.DEFAULT_MODEL", "o3"):
with patch("config.IS_AUTO_MODE", False):
with patch.object(ModelProviderRegistry, "get_provider_for_model") as mock_get_provider:
mock_get_provider.return_value = None # Model not available
tool = AnalyzeTool()
schema = tool.get_input_schema()
# Model should be required due to unavailable DEFAULT_MODEL
assert "model" in schema["required"]
def test_schema_with_available_default_model(self):
"""Test schema when DEFAULT_MODEL is available."""
with patch("config.DEFAULT_MODEL", "pro"):
with patch("config.IS_AUTO_MODE", False):
with patch.object(ModelProviderRegistry, "get_provider_for_model") as mock_get_provider:
mock_get_provider.return_value = MagicMock() # Model is available
tool = ThinkDeepTool()
schema = tool.get_input_schema()
# Model should remain optional when DEFAULT_MODEL is available
assert "model" not in schema["required"]
class TestUnavailableModelFallback:
"""Test fallback behavior when DEFAULT_MODEL is not available."""
@pytest.mark.asyncio
async def test_unavailable_default_model_fallback(self):
"""Test that unavailable DEFAULT_MODEL triggers auto mode behavior."""
with patch("config.DEFAULT_MODEL", "o3"): # Set DEFAULT_MODEL to a specific model
with patch("config.IS_AUTO_MODE", False): # Not in auto mode
with patch.object(ModelProviderRegistry, "get_provider_for_model") as mock_get_provider:
# Model is not available (no provider)
mock_get_provider.return_value = None
tool = ThinkDeepTool()
result = await tool.execute(
{
"step": "test",
"step_number": 1,
"total_steps": 1,
"next_step_required": False,
"findings": "test",
}
) # No model specified
# Should get model error since fallback model is also unavailable
assert len(result) == 1
# Workflow tools try fallbacks and report when the fallback model is not available
assert "is not available" in result[0].text
# Should list available models in the error
assert "Available models:" in result[0].text
@pytest.mark.asyncio
async def test_available_default_model_no_fallback(self):
"""Test that available DEFAULT_MODEL works normally."""
with patch("config.DEFAULT_MODEL", "pro"):
with patch("config.IS_AUTO_MODE", False):
with patch.object(ModelProviderRegistry, "get_provider_for_model") as mock_get_provider:
# Model is available
mock_provider = MagicMock()
mock_provider.generate_content.return_value = MagicMock(content="Test response", metadata={})
mock_get_provider.return_value = mock_provider
# Mock the provider lookup in BaseTool.get_model_provider
with patch.object(BaseTool, "get_model_provider") as mock_get_model_provider:
mock_get_model_provider.return_value = mock_provider
tool = ChatTool()
temp_dir = tempfile.mkdtemp()
try:
result = await tool.execute({"prompt": "test", "working_directory": temp_dir})
finally:
shutil.rmtree(temp_dir, ignore_errors=True)
# Should work normally, not require model parameter
assert len(result) == 1
output = json.loads(result[0].text)
assert output["status"] in ["success", "continuation_available"]
assert "Test response" in output["content"]