"""Tests for training Pydantic models.
Covers the OpenPipe ART training data models and their to_dict/from_row conversions.
"""
from datetime import datetime
from uuid import UUID
from mcp_task_aggregator.models.training import (
TaskInput,
TrainingInputBatch,
TrainingMetrics,
TrainingRun,
TrainingStatus,
TrajectoryMessage,
TrajectoryRecord,
)
class TestTrainingStatus:
    """Checks on the members of the TrainingStatus enum."""

    def test_status_values(self):
        """Every expected member compares equal to its string value."""
        expected_pairs = (
            (TrainingStatus.PENDING, "pending"),
            (TrainingStatus.GENERATING, "generating"),
            (TrainingStatus.COLLECTING, "collecting"),
            (TrainingStatus.SCORING, "scoring"),
            (TrainingStatus.TRAINING, "training"),
            (TrainingStatus.COMPLETED, "completed"),
            (TrainingStatus.FAILED, "failed"),
            (TrainingStatus.CANCELLED, "cancelled"),
        )
        for member, raw_value in expected_pairs:
            assert member == raw_value

    def test_all_statuses_present(self):
        """The enum defines exactly eight members."""
        member_count = len(TrainingStatus)
        assert member_count == 8
class TestTaskInput:
    """Construction checks for the TaskInput model."""

    def test_minimal_creation(self):
        """Supplying only step and input_text leaves context as None."""
        minimal = TaskInput(step=0, input_text="Test input")

        assert (minimal.step, minimal.input_text) == (0, "Test input")
        assert minimal.context is None

    def test_with_context(self):
        """A context dict passed at construction is readable afterwards."""
        extra = {"status": "todo", "priority": 3}
        populated = TaskInput(step=5, input_text="Complex task", context=extra)

        assert populated.step == 5

        stored_context = populated.context
        assert stored_context is not None
        assert stored_context["status"] == "todo"
        assert stored_context["priority"] == 3
class TestTrajectoryMessage:
    """Construction checks for TrajectoryMessage, one test per role."""

    def test_user_message(self):
        """A message built with role "user" keeps its role and content."""
        message = TrajectoryMessage(content="Hello", role="user")
        assert (message.role, message.content) == ("user", "Hello")

    def test_assistant_message(self):
        """A message built with role "assistant" keeps its role and content."""
        message = TrajectoryMessage(content="Hi there!", role="assistant")
        assert (message.role, message.content) == ("assistant", "Hi there!")

    def test_system_message(self):
        """A message built with role "system" keeps its role and content."""
        message = TrajectoryMessage(content="You are helpful.", role="system")
        assert (message.role, message.content) == ("system", "You are helpful.")
class TestTrajectoryRecord:
    """Construction checks for the TrajectoryRecord model."""

    def test_minimal_creation(self):
        """With only run_id and task_input, the remaining fields take defaults."""
        bare = TrajectoryRecord(
            run_id=1,
            task_input=TaskInput(step=0, input_text="Test"),
        )

        assert bare.run_id == 1
        assert bare.task_input.input_text == "Test"
        # A uuid is populated even though none was passed in.
        assert isinstance(bare.uuid, UUID)
        assert bare.reward is None
        assert bare.messages == []

    def test_full_creation(self):
        """Messages, reward and metadata passed at construction are stored."""
        turns = (
            ("system", "Be helpful"),
            ("user", "Input"),
            ("assistant", "Output"),
        )
        conversation = [
            TrajectoryMessage(role=role, content=content) for role, content in turns
        ]

        populated = TrajectoryRecord(
            run_id=1,
            task_input=TaskInput(step=1, input_text="Input"),
            messages=conversation,
            reward=0.85,
            metadata={"scored_by": "ruler"},
        )

        assert populated.reward == 0.85
        assert len(populated.messages) == 3
        assert populated.metadata["scored_by"] == "ruler"
class TestTrainingMetrics:
    """Tests for the TrainingMetrics model."""

    def test_minimal_metrics(self):
        """Check the defaults of a TrainingMetrics built with no arguments."""
        metrics = TrainingMetrics()

        assert metrics.total_trajectories == 0
        assert metrics.avg_reward is None
        assert metrics.min_reward is None
        assert metrics.max_reward is None
        assert metrics.epoch == 0

    def test_full_metrics(self):
        """Check that every value passed to TrainingMetrics is stored unchanged."""
        metrics = TrainingMetrics(
            total_trajectories=100,
            avg_reward=0.75,
            max_reward=0.95,
            min_reward=0.35,
            epoch=2,
            step=50,
            training_loss=0.123,
        )

        assert metrics.total_trajectories == 100
        assert metrics.avg_reward == 0.75
        assert metrics.max_reward == 0.95
        # min_reward was supplied above but previously never verified.
        assert metrics.min_reward == 0.35
        assert metrics.epoch == 2
        assert metrics.step == 50
        assert metrics.training_loss == 0.123
class TestTrainingRun:
    """Tests for the TrainingRun model: construction, defaults, status updates."""

    @staticmethod
    def _minimal_run():
        """Build a TrainingRun from only model_name, project_name and task_description.

        Shared by the tests that care about defaults rather than the
        supplied values. Not collected by pytest (no ``test`` prefix).
        """
        return TrainingRun(
            model_name="model",
            project_name="project",
            task_description="desc",
        )

    def test_minimal_creation(self):
        """Check a run built from three fields keeps them and fills in the rest."""
        run = TrainingRun(
            model_name="base-model",
            project_name="test-project",
            task_description="Test training task",
        )

        assert run.model_name == "base-model"
        assert run.project_name == "test-project"
        assert run.task_description == "Test training task"
        assert run.status == TrainingStatus.PENDING
        # uuid and created_at are populated without being passed in.
        assert isinstance(run.uuid, UUID)
        assert isinstance(run.created_at, datetime)

    def test_full_creation(self):
        """Check that every value passed to TrainingRun is stored unchanged."""
        run = TrainingRun(
            model_name="Qwen/Qwen2.5-1.5B-Instruct",
            project_name="task-agent",
            trained_model_name="task-agent-v1",
            task_description="Train task management agent",
            num_examples=50,
            rollouts_per_group=8,
            learning_rate=5e-6,
            num_epochs=3,
            max_steps=100,
            system_prompt="You are a task management assistant.",
        )

        assert run.model_name == "Qwen/Qwen2.5-1.5B-Instruct"
        assert run.project_name == "task-agent"
        assert run.trained_model_name == "task-agent-v1"
        assert run.task_description == "Train task management agent"
        assert run.num_examples == 50
        assert run.rollouts_per_group == 8
        assert run.learning_rate == 5e-6
        assert run.num_epochs == 3
        assert run.max_steps == 100
        assert run.system_prompt == "You are a task management assistant."

    def test_defaults(self):
        """Check the default hyperparameters and empty optional fields."""
        run = self._minimal_run()

        assert run.num_examples == 100
        assert run.rollouts_per_group == 4
        assert run.learning_rate == 1e-5
        assert run.num_epochs == 1
        assert run.max_steps is None
        assert run.metrics is None
        assert run.error is None

    def test_status_transitions(self):
        """Check that status can be reassigned after construction."""
        run = self._minimal_run()
        assert run.status == TrainingStatus.PENDING

        for next_status in (
            TrainingStatus.GENERATING,
            TrainingStatus.TRAINING,
            TrainingStatus.COMPLETED,
        ):
            run.status = next_status
            assert run.status == next_status
class TestTrainingInputBatch:
    """Construction checks for the TrainingInputBatch model."""

    def test_creation(self):
        """A batch built with two items keeps its items and position counters."""
        inputs = [
            TaskInput(step=0, input_text=text) for text in ("Input 1", "Input 2")
        ]

        populated = TrainingInputBatch(items=inputs, step=0, epoch=1, epoch_step=0)

        assert len(populated.items) == 2
        assert (populated.step, populated.epoch) == (0, 1)
        assert populated.epoch_step == 0

    def test_empty_batch(self):
        """A batch can be built from an empty item list."""
        empty = TrainingInputBatch(epoch_step=0, epoch=0, step=0, items=[])

        assert len(empty.items) == 0
        assert empty.epoch_step == 0
class TestTrajectoryRecordSerialization:
    """Tests for TrajectoryRecord.to_dict and TrajectoryRecord.from_row.

    Why: Records are converted to and from a row-style dict for storage.
    Can fail if: A field is left unserialized or is restored as the wrong type.
    """

    def test_to_dict_basic(self):
        """
        Verify to_dict keeps scalar fields and emits the other fields as strings.

        Why: The stored form carries uuid, task_input, messages and metadata as text.
        Can fail if: One of those fields is missing or is not a string.
        """
        record = TrajectoryRecord(
            run_id=1,
            task_input=TaskInput(step=1, input_text="Test input"),
            messages=[
                TrajectoryMessage(role="system", content="System"),
                TrajectoryMessage(role="user", content="User"),
            ],
            reward=0.75,
            metadata={"key": "value"},
        )

        as_dict = record.to_dict()

        assert as_dict["run_id"] == 1
        assert isinstance(as_dict["uuid"], str)
        assert as_dict["reward"] == 0.75
        # These three are expected as JSON text.
        for json_field in ("task_input", "messages", "metadata"):
            assert isinstance(as_dict[json_field], str)

    def test_from_row_with_string_fields(self):
        """
        Verify from_row rebuilds typed fields from a row whose values are strings.

        Why: Rows carry task_input, messages and metadata as JSON text.
        Can fail if: Parsing fails or a field comes back with the wrong type.
        """
        import json
        from uuid import uuid4

        expected_uuid = uuid4()
        stamp = datetime.now()
        task_json = TaskInput(step=2, input_text="Test").model_dump_json()
        messages_json = json.dumps(
            [
                {"role": "system", "content": "Sys"},
                {"role": "user", "content": "User"},
            ]
        )
        metadata_json = json.dumps({"step": 2})

        db_row = {
            "id": 1,
            "uuid": str(expected_uuid),
            "run_id": 5,
            "task_input": task_json,
            "messages": messages_json,
            "reward": 0.88,
            "metadata": metadata_json,
            "created_at": stamp.isoformat(),
        }

        restored = TrajectoryRecord.from_row(db_row)

        assert restored.id == 1
        assert restored.uuid == expected_uuid
        assert restored.run_id == 5
        assert isinstance(restored.task_input, TaskInput)
        assert restored.task_input.step == 2
        assert len(restored.messages) == 2
        assert isinstance(restored.messages[0], TrajectoryMessage)
        assert restored.reward == 0.88
        assert restored.metadata["step"] == 2
        assert isinstance(restored.created_at, datetime)

    def test_from_row_with_dict_task_input(self):
        """
        Verify from_row accepts task_input given as a dict.

        Why: A row may carry task_input already parsed rather than as JSON text.
        Can fail if: from_row only handles the string form.
        """
        parsed_task = {"step": 3, "input_text": "Dict input", "context": None}
        db_row = {
            "id": 1,
            "uuid": "550e8400-e29b-41d4-a716-446655440000",
            "run_id": 1,
            "task_input": parsed_task,
            "messages": "[]",
            "reward": None,
            "metadata": "{}",
            "created_at": datetime.now().isoformat(),
        }

        restored_task = TrajectoryRecord.from_row(db_row).task_input

        assert isinstance(restored_task, TaskInput)
        assert restored_task.step == 3
        assert restored_task.input_text == "Dict input"

    def test_from_row_with_list_messages(self):
        """
        Verify from_row accepts messages given as a list of dicts.

        Why: A row may carry messages already parsed rather than as JSON text.
        Can fail if: from_row only handles the string form.
        """
        parsed_messages = [
            {"role": "system", "content": "System"},
            {"role": "user", "content": "User"},
        ]
        db_row = {
            "id": 1,
            "uuid": "550e8400-e29b-41d4-a716-446655440000",
            "run_id": 1,
            "task_input": TaskInput(step=1, input_text="Test").model_dump_json(),
            "messages": parsed_messages,
            "reward": None,
            "metadata": "{}",
            "created_at": datetime.now().isoformat(),
        }

        restored_messages = TrajectoryRecord.from_row(db_row).messages

        assert len(restored_messages) == 2
        first_message = restored_messages[0]
        assert isinstance(first_message, TrajectoryMessage)
        assert first_message.role == "system"
class TestTrainingRunSerialization:
    """Tests for TrainingRun.to_dict and TrainingRun.from_row.

    Why: Training runs are converted to and from a row-style dict for storage.
    Can fail if: Serialization breaks or a field is restored as the wrong type.
    """

    @staticmethod
    def _row(**overrides):
        """Return a full TrainingRun row dict with ``overrides`` applied on top.

        The baseline is the row every from_row test in this class shares, so
        each test only spells out the field(s) it is actually about. Not
        collected by pytest (no ``test`` prefix).

        Raises:
            KeyError: If an override names a column that is not in the
                baseline row, so a typo cannot silently add a stray key.
        """
        row = {
            "id": 1,
            "uuid": "550e8400-e29b-41d4-a716-446655440000",
            "model_name": "test",
            "project_name": "test",
            "task_description": "desc",
            "status": "completed",
            "num_examples": 100,
            "rollouts_per_group": 4,
            "learning_rate": 1e-5,
            "num_epochs": 1,
            "max_steps": None,
            "system_prompt": None,
            "trained_model_name": None,
            "created_at": datetime.now().isoformat(),
            "started_at": None,
            "completed_at": None,
            "metrics": None,
            "error": None,
            "config": "{}",
        }
        unknown = sorted(overrides.keys() - row.keys())
        if unknown:
            raise KeyError(f"unknown row column(s): {', '.join(unknown)}")
        row.update(overrides)
        return row

    def test_to_dict_complete(self):
        """
        Verify TrainingRun.to_dict emits the expected value types.

        Why: The stored form carries the enum, uuid, timestamp, metrics and
            config as strings.
        Can fail if: Complex fields like metrics or enums are not serialized.
        """
        metrics = TrainingMetrics(
            total_trajectories=100,
            avg_reward=0.75,
            epoch=2,
        )
        run = TrainingRun(
            model_name="test-model",
            project_name="test-project",
            task_description="Test task",
            status=TrainingStatus.COMPLETED,
            metrics=metrics,
            config={"key": "value"},
        )

        result = run.to_dict()

        assert result["model_name"] == "test-model"
        assert result["status"] == "completed"  # Enum value as string
        assert isinstance(result["uuid"], str)
        assert isinstance(result["created_at"], str)  # ISO format
        assert isinstance(result["metrics"], str)  # JSON string
        assert isinstance(result["config"], str)  # JSON string

    def test_from_row_with_string_status(self):
        """
        Verify TrainingRun.from_row rebuilds the status enum from a string.

        Why: Rows carry the status as its string value.
        Can fail if: Enum reconstruction fails.
        """
        run = TrainingRun.from_row(self._row(status="training"))

        assert run.status == TrainingStatus.TRAINING
        assert isinstance(run.status, TrainingStatus)

    def test_from_row_with_datetime_strings(self):
        """
        Verify TrainingRun.from_row turns ISO datetime strings into datetimes.

        Why: Rows carry timestamps as ISO-format strings.
        Can fail if: Datetime parsing fails or returns the wrong type or value.
        """
        created = datetime(2024, 1, 1, 12, 0, 0)
        started = datetime(2024, 1, 1, 12, 5, 0)
        completed = datetime(2024, 1, 1, 13, 0, 0)

        run = TrainingRun.from_row(
            self._row(
                created_at=created.isoformat(),
                started_at=started.isoformat(),
                completed_at=completed.isoformat(),
            )
        )

        assert isinstance(run.created_at, datetime)
        assert isinstance(run.started_at, datetime)
        assert isinstance(run.completed_at, datetime)
        assert run.created_at == created
        assert run.started_at == started
        assert run.completed_at == completed

    def test_from_row_with_metrics_json(self):
        """
        Verify TrainingRun.from_row rebuilds metrics from a JSON string.

        Why: Rows may carry metrics as JSON text.
        Can fail if: JSON parsing or model validation fails.
        """
        metrics_json = TrainingMetrics(
            total_trajectories=50,
            avg_reward=0.65,
            epoch=1,
        ).model_dump_json()

        run = TrainingRun.from_row(self._row(metrics=metrics_json))

        assert isinstance(run.metrics, TrainingMetrics)
        assert run.metrics.total_trajectories == 50
        assert run.metrics.avg_reward == 0.65

    def test_from_row_with_metrics_dict(self):
        """
        Verify TrainingRun.from_row accepts metrics given as a dict.

        Why: A row may carry metrics already parsed rather than as JSON text.
        Can fail if: from_row only handles the string form.
        """
        metrics_dict = {
            "total_trajectories": 75,
            "avg_reward": 0.80,
            "min_reward": None,
            "max_reward": None,
            "training_loss": None,
            "epoch": 2,
            "step": 0,
        }

        run = TrainingRun.from_row(self._row(metrics=metrics_dict))

        assert isinstance(run.metrics, TrainingMetrics)
        assert run.metrics.total_trajectories == 75
        assert run.metrics.avg_reward == 0.80
        assert run.metrics.epoch == 2