"""MCP Tools for OpenPipe ART training operations.
Provides tools for starting, monitoring, and evaluating training runs.
"""
from __future__ import annotations
from typing import Any
from loguru import logger
from pydantic import BaseModel, Field
from mcp_task_aggregator.models.training import TrainingRun, TrainingStatus
from mcp_task_aggregator.tools.server import mcp
class StartTrainingParams(BaseModel):
"""Parameters for start_training_run tool."""
model_name: str = Field(description="Name for the trained model")
project_name: str = Field(description="ART project name for tracking")
task_description: str = Field(description="Description of the task to train for")
base_model: str = Field(
default="Qwen/Qwen2.5-1.5B-Instruct",
description="HuggingFace model ID for the base model",
)
num_examples: int = Field(
default=100,
description="Number of training examples to generate",
ge=10,
le=1000,
)
rollouts_per_group: int = Field(
default=4,
description="Number of responses per input for RULER comparison",
ge=2,
le=10,
)
learning_rate: float = Field(
default=1e-5,
description="Learning rate for training",
gt=0,
le=1e-3,
)
num_epochs: int = Field(
default=1,
description="Number of training epochs",
ge=1,
le=10,
)
max_steps: int | None = Field(
default=None,
description="Maximum training steps (optional limit)",
)
class TrainingStatusResponse(BaseModel):
"""Response model for training status."""
run_id: str
status: str
model_name: str
project_name: str
task_description: str
num_examples: int
created_at: str
started_at: str | None
completed_at: str | None
metrics: dict[str, Any] | None
error: str | None
class EvaluateModelParams(BaseModel):
"""Parameters for evaluate_model tool."""
model_name: str = Field(description="Name of the trained model to evaluate")
test_inputs: list[str] = Field(description="List of test inputs to evaluate")
baseline_model: str | None = Field(
default=None,
description="Optional baseline model for comparison",
)
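# The Pydantic models above document the tool input/output schemas and can be
# used to validate data independently of the MCP layer. A minimal sketch
# (hypothetical values):
#
#     params = StartTrainingParams(
#         model_name="email-agent-v1",
#         project_name="email-agent",
#         task_description="Draft concise replies to customer emails",
#     )
#     assert 10 <= params.num_examples <= 1000  # Field constraints hold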
# Core implementation functions (testable without MCP decorator)
async def start_training_run_impl(
model_name: str,
project_name: str,
task_description: str,
base_model: str = "Qwen/Qwen2.5-1.5B-Instruct",
num_examples: int = 100,
rollouts_per_group: int = 4,
learning_rate: float = 1e-5,
num_epochs: int = 1,
max_steps: int | None = None,
) -> dict[str, Any]:
"""Start a new ART training run.
Initializes a training pipeline and begins the training process.
Args:
model_name: Name for the trained model.
project_name: ART project name.
task_description: Description of the task to train for.
base_model: HuggingFace model ID for base model.
num_examples: Number of training examples.
rollouts_per_group: Responses per input for RULER.
learning_rate: Training learning rate.
num_epochs: Number of training epochs.
max_steps: Optional maximum training steps.
Returns:
Dictionary with training run info.
"""
from mcp_task_aggregator.training.pipeline import TrainingPipeline
try:
# Create training run record
training_run = TrainingRun(
model_name=base_model,
project_name=project_name,
trained_model_name=model_name,
task_description=task_description,
num_examples=num_examples,
rollouts_per_group=rollouts_per_group,
learning_rate=learning_rate,
num_epochs=num_epochs,
max_steps=max_steps,
)
# Initialize pipeline
pipeline = TrainingPipeline(
model_name=model_name,
project_name=project_name,
)
await pipeline.initialize(base_model=base_model)
        # Run training; this is a long-running call that blocks until the
        # run completes or fails.
completed_run = await pipeline.run_training(training_run)
return {
"run_id": str(completed_run.uuid),
"status": completed_run.status.value,
"model_name": completed_run.model_name,
"project_name": completed_run.project_name,
"task_description": completed_run.task_description,
"num_examples": completed_run.num_examples,
"created_at": completed_run.created_at.isoformat(),
"started_at": completed_run.started_at.isoformat() if completed_run.started_at else None,
"completed_at": completed_run.completed_at.isoformat() if completed_run.completed_at else None,
"metrics": completed_run.metrics.model_dump() if completed_run.metrics else None,
"error": completed_run.error,
}
except Exception as e:
logger.exception(f"Error starting training run: {e}")
return {
"error": str(e),
"status": TrainingStatus.FAILED.value,
}
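# Example usage (a sketch; assumes an asyncio entry point and a configured
# ART backend for TrainingPipeline):
#
#     import asyncio
#
#     result = asyncio.run(
#         start_training_run_impl(
#             model_name="email-agent-v1",
#             project_name="email-agent",
#             task_description="Draft concise replies to customer emails",
#             num_examples=50,
#         )
#     )
#     print(result["status"], result.get("run_id"))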
async def evaluate_model_impl(
model_name: str,
test_inputs: list[str],
baseline_model: str | None = None,
) -> dict[str, Any]:
"""Evaluate a trained model against test inputs.
Runs the model on test inputs and optionally compares to a baseline.
Args:
model_name: Name of the trained model.
test_inputs: List of test inputs.
baseline_model: Optional baseline for comparison.
Returns:
Dictionary with evaluation results.
"""
from mcp_task_aggregator.models.training import TaskInput
from mcp_task_aggregator.training.pipeline import TrainingPipeline
try:
# Initialize pipeline for evaluation
pipeline = TrainingPipeline(
model_name=model_name,
project_name="evaluation",
)
# Generate a generic system prompt for evaluation
await pipeline.generate_system_prompt("Respond helpfully and accurately to user queries.")
results = []
for i, test_input in enumerate(test_inputs):
task_input = TaskInput(step=i, input_text=test_input)
trajectory = await pipeline.rollout(task_input)
            # Use the final assistant message as the model's response; the
            # last message is not guaranteed to come from the assistant.
            assistant_messages = [
                m for m in trajectory.messages() if m.get("role") == "assistant"
            ]
            response = (
                assistant_messages[-1]["content"]
                if assistant_messages
                else "No response"
            )
results.append({
"input": test_input,
"response": response,
"reward": trajectory.reward,
})
# Calculate aggregate metrics
avg_reward = sum(r["reward"] for r in results) / len(results) if results else 0.0
return {
"model_name": model_name,
"num_inputs": len(test_inputs),
"avg_reward": avg_reward,
"results": results,
"baseline_model": baseline_model,
"baseline_comparison": None, # TODO: Implement baseline comparison
}
except Exception as e:
logger.exception(f"Error evaluating model: {e}")
return {
"error": str(e),
"model_name": model_name,
"num_inputs": len(test_inputs),
}
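# Example usage (a sketch; assumes the named model has been trained and is
# loadable by TrainingPipeline):
#
#     import asyncio
#
#     report = asyncio.run(
#         evaluate_model_impl(
#             model_name="email-agent-v1",
#             test_inputs=["Summarize this thread", "Draft a polite decline"],
#         )
#     )
#     print(report.get("avg_reward"), report.get("num_inputs"))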
def list_training_runs_impl(
status: str | None = None,
limit: int = 10,
) -> dict[str, Any]:
"""List training runs with optional filtering.
Args:
status: Filter by training status.
limit: Maximum number of runs to return.
Returns:
Dictionary with list of training runs.
"""
# TODO: Implement database query for training runs
# For now, return empty list (would need TrainingRunRepository)
logger.info(f"Listing training runs (status={status}, limit={limit})")
return {
"runs": [],
"total": 0,
"limit": limit,
"status_filter": status,
}
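# A sketch of what the TODO above might look like once a repository exists.
# `TrainingRunRepository` and its `list_runs` method are hypothetical, not
# part of the current codebase:
#
#     repo = TrainingRunRepository()
#     runs = repo.list_runs(status=status, limit=limit)
#     return {
#         "runs": [run.model_dump(mode="json") for run in runs],
#         "total": len(runs),
#         "limit": limit,
#         "status_filter": status,
#     }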
# MCP Tool registrations
@mcp.tool()
async def start_training_run(
model_name: str,
project_name: str,
task_description: str,
base_model: str = "Qwen/Qwen2.5-1.5B-Instruct",
num_examples: int = 100,
rollouts_per_group: int = 4,
learning_rate: float = 1e-5,
num_epochs: int = 1,
max_steps: int | None = None,
) -> dict[str, Any]:
"""Start a new ART training run.
Initializes a training pipeline and begins the training process
using OpenPipe ART with GRPO (Group Relative Policy Optimization).
Args:
model_name: Name for the trained model.
project_name: ART project name for tracking.
task_description: Description of the task to train for.
base_model: HuggingFace model ID for base model.
num_examples: Number of training examples to generate.
rollouts_per_group: Responses per input for RULER comparison.
learning_rate: Training learning rate.
num_epochs: Number of training epochs.
max_steps: Optional maximum training steps.
Returns:
Dictionary with training run info and status.
"""
return await start_training_run_impl(
model_name=model_name,
project_name=project_name,
task_description=task_description,
base_model=base_model,
num_examples=num_examples,
rollouts_per_group=rollouts_per_group,
learning_rate=learning_rate,
num_epochs=num_epochs,
max_steps=max_steps,
)
@mcp.tool()
async def evaluate_model(
model_name: str,
test_inputs: list[str],
baseline_model: str | None = None,
) -> dict[str, Any]:
"""Evaluate a trained model against test inputs.
Runs the model on test inputs and optionally compares
performance to a baseline model.
Args:
model_name: Name of the trained model to evaluate.
test_inputs: List of test inputs to evaluate.
baseline_model: Optional baseline model for comparison.
Returns:
Dictionary with evaluation results and metrics.
"""
return await evaluate_model_impl(
model_name=model_name,
test_inputs=test_inputs,
baseline_model=baseline_model,
)
@mcp.tool()
def list_training_runs(
status: str | None = None,
limit: int = 10,
) -> dict[str, Any]:
"""List training runs with optional filtering.
Retrieves training run history from the database.
Args:
status: Filter by status (pending, generating, training, completed, failed).
limit: Maximum number of runs to return.
Returns:
Dictionary with list of training runs.
"""
return list_training_runs_impl(status=status, limit=limit)
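# Because the *_impl functions avoid the MCP decorator, they can be unit
# tested directly. A minimal pytest sketch (assumes pytest-asyncio and a
# mocked or locally runnable TrainingPipeline):
#
#     import pytest
#
#     @pytest.mark.asyncio
#     async def test_evaluate_model_impl_reports_counts():
#         report = await evaluate_model_impl("some-model", ["hello"])
#         assert report["num_inputs"] == 1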