"""Tests for MCP training tools.
Tests for OpenPipe ART training tool implementations.
"""
from datetime import datetime
from unittest.mock import AsyncMock, MagicMock, patch
from uuid import uuid4
import pytest
from mcp_task_aggregator.models.training import TrainingMetrics, TrainingRun, TrainingStatus
from mcp_task_aggregator.tools.training_tools import (
evaluate_model_impl,
list_training_runs_impl,
start_training_run_impl,
)
class TestListTrainingRunsImpl:
"""Tests for list_training_runs_impl function."""
def test_returns_empty_list(self):
"""Test that function returns empty list (stub implementation)."""
result = list_training_runs_impl()
assert result["runs"] == []
assert result["total"] == 0
assert result["limit"] == 10
assert result["status_filter"] is None
def test_respects_limit(self):
"""Test that limit parameter is passed through."""
result = list_training_runs_impl(limit=50)
assert result["limit"] == 50
def test_respects_status_filter(self):
"""Test that status filter is passed through."""
result = list_training_runs_impl(status="completed")
assert result["status_filter"] == "completed"
def test_combined_parameters(self):
"""Test multiple parameters together."""
result = list_training_runs_impl(status="failed", limit=25)
assert result["limit"] == 25
assert result["status_filter"] == "failed"


class TestStartTrainingRunImpl:
    """Tests for start_training_run_impl async function."""

    @pytest.mark.asyncio
    async def test_start_training_run_success(self):
        """Test successful training run start."""
        run_uuid = uuid4()
        completed_run = TrainingRun(
            uuid=run_uuid,
            model_name="Qwen/Qwen2.5-1.5B-Instruct",
            project_name="test-project",
            trained_model_name="test-model",
            task_description="Test task",
            num_examples=10,
            status=TrainingStatus.COMPLETED,
            created_at=datetime.now(),
            started_at=datetime.now(),
            completed_at=datetime.now(),
            metrics=TrainingMetrics(
                total_examples=10,
                total_steps=5,
                avg_reward=0.85,
            ),
        )

        with patch("mcp_task_aggregator.training.pipeline.TrainingPipeline") as mock_pipeline_cls:
            mock_pipeline = AsyncMock()
            mock_pipeline_cls.return_value = mock_pipeline
            mock_pipeline.initialize = AsyncMock()
            mock_pipeline.run_training = AsyncMock(return_value=completed_run)

            result = await start_training_run_impl(
                model_name="test-model",
                project_name="test-project",
                task_description="Test task",
                num_examples=10,
            )

        assert result["status"] == "completed"
        assert result["model_name"] == "Qwen/Qwen2.5-1.5B-Instruct"
        assert result["project_name"] == "test-project"
        assert result["run_id"] == str(run_uuid)
        assert result.get("error") is None

    @pytest.mark.asyncio
    async def test_start_training_run_failure(self):
        """Test training run handles errors gracefully."""
        with patch("mcp_task_aggregator.training.pipeline.TrainingPipeline") as mock_pipeline_cls:
            mock_pipeline_cls.side_effect = RuntimeError("Pipeline initialization failed")

            result = await start_training_run_impl(
                model_name="test-model",
                project_name="test-project",
                task_description="Test task",
            )

        assert result["status"] == "failed"
        assert "error" in result
        assert "Pipeline initialization failed" in result["error"]

    @pytest.mark.asyncio
    async def test_start_training_run_with_custom_params(self):
        """Test training run with custom parameters."""
        run_uuid = uuid4()
        completed_run = TrainingRun(
            uuid=run_uuid,
            model_name="custom-base-model",
            project_name="custom-project",
            trained_model_name="custom-trained",
            task_description="Custom task",
            num_examples=50,
            rollouts_per_group=6,
            learning_rate=0.0001,
            num_epochs=3,
            max_steps=100,
            status=TrainingStatus.COMPLETED,
            created_at=datetime.now(),
        )

        with patch("mcp_task_aggregator.training.pipeline.TrainingPipeline") as mock_pipeline_cls:
            mock_pipeline = AsyncMock()
            mock_pipeline_cls.return_value = mock_pipeline
            mock_pipeline.initialize = AsyncMock()
            mock_pipeline.run_training = AsyncMock(return_value=completed_run)

            result = await start_training_run_impl(
                model_name="custom-trained",
                project_name="custom-project",
                task_description="Custom task",
                base_model="custom-base-model",
                num_examples=50,
                rollouts_per_group=6,
                learning_rate=0.0001,
                num_epochs=3,
                max_steps=100,
            )

        assert result["status"] == "completed"
        mock_pipeline.initialize.assert_called_once_with(base_model="custom-base-model")
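
    @pytest.mark.asyncio
    async def test_run_training_awaited_once(self):
        """Sketch: a successful start should drive exactly one training pass.

        Assumes run_training is awaited once per call, which the success tests
        above imply but do not assert directly.
        """
        completed_run = TrainingRun(
            uuid=uuid4(),
            model_name="test-model",
            project_name="test-project",
            trained_model_name="test-model",
            task_description="Test task",
            num_examples=10,
            status=TrainingStatus.COMPLETED,
            created_at=datetime.now(),
        )

        with patch("mcp_task_aggregator.training.pipeline.TrainingPipeline") as mock_pipeline_cls:
            mock_pipeline = AsyncMock()
            mock_pipeline_cls.return_value = mock_pipeline
            mock_pipeline.initialize = AsyncMock()
            mock_pipeline.run_training = AsyncMock(return_value=completed_run)

            await start_training_run_impl(
                model_name="test-model",
                project_name="test-project",
                task_description="Test task",
            )

        mock_pipeline.run_training.assert_awaited_once()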


class TestEvaluateModelImpl:
    """Tests for evaluate_model_impl async function."""

    @pytest.mark.asyncio
    async def test_evaluate_model_success(self):
        """Test successful model evaluation."""
        mock_trajectory = MagicMock()
        mock_trajectory.messages.return_value = [
            {"role": "user", "content": "Hello"},
            {"role": "assistant", "content": "Hi there!"},
        ]
        mock_trajectory.reward = 0.9

        with patch("mcp_task_aggregator.training.pipeline.TrainingPipeline") as mock_pipeline_cls:
            mock_pipeline = AsyncMock()
            mock_pipeline_cls.return_value = mock_pipeline
            mock_pipeline.generate_system_prompt = AsyncMock()
            mock_pipeline.rollout = AsyncMock(return_value=mock_trajectory)

            result = await evaluate_model_impl(
                model_name="test-model",
                test_inputs=["Hello", "How are you?"],
            )

        assert result["model_name"] == "test-model"
        assert result["num_inputs"] == 2
        assert result["avg_reward"] == 0.9
        assert len(result["results"]) == 2
        assert result["results"][0]["response"] == "Hi there!"

    @pytest.mark.asyncio
    async def test_evaluate_model_with_baseline(self):
        """Test evaluation with baseline model specified."""
        mock_trajectory = MagicMock()
        mock_trajectory.messages.return_value = [{"role": "assistant", "content": "Response"}]
        mock_trajectory.reward = 0.8

        with patch("mcp_task_aggregator.training.pipeline.TrainingPipeline") as mock_pipeline_cls:
            mock_pipeline = AsyncMock()
            mock_pipeline_cls.return_value = mock_pipeline
            mock_pipeline.generate_system_prompt = AsyncMock()
            mock_pipeline.rollout = AsyncMock(return_value=mock_trajectory)

            result = await evaluate_model_impl(
                model_name="test-model",
                test_inputs=["Test input"],
                baseline_model="baseline-model",
            )

        assert result["baseline_model"] == "baseline-model"
        assert result["baseline_comparison"] is None  # Not implemented yet

    @pytest.mark.asyncio
    async def test_evaluate_model_failure(self):
        """Test evaluation handles errors gracefully."""
        with patch("mcp_task_aggregator.training.pipeline.TrainingPipeline") as mock_pipeline_cls:
            mock_pipeline_cls.side_effect = RuntimeError("Model not found")

            result = await evaluate_model_impl(
                model_name="missing-model",
                test_inputs=["Test"],
            )

        assert "error" in result
        assert "Model not found" in result["error"]
        assert result["model_name"] == "missing-model"
        assert result["num_inputs"] == 1

    @pytest.mark.asyncio
    async def test_evaluate_model_empty_response(self):
        """Test evaluation handles empty responses."""
        mock_trajectory = MagicMock()
        mock_trajectory.messages.return_value = []
        mock_trajectory.reward = 0.0

        with patch("mcp_task_aggregator.training.pipeline.TrainingPipeline") as mock_pipeline_cls:
            mock_pipeline = AsyncMock()
            mock_pipeline_cls.return_value = mock_pipeline
            mock_pipeline.generate_system_prompt = AsyncMock()
            mock_pipeline.rollout = AsyncMock(return_value=mock_trajectory)

            result = await evaluate_model_impl(
                model_name="test-model",
                test_inputs=["Test"],
            )

        assert result["results"][0]["response"] == "No response"