"""Pydantic models for OpenPipe ART training pipeline.
Provides typed models for training runs, trajectories, and task inputs
with validation and serialization support for RL agent training.
"""
from __future__ import annotations
from datetime import datetime
from enum import Enum
from typing import Any
from uuid import UUID, uuid4
from pydantic import BaseModel, Field
class TrainingStatus(str, Enum):
    """Lifecycle states a training run can be in.

    Members are also plain ``str`` instances, so they compare equal to
    their lowercase string value.
    """

    # Not started yet.
    PENDING = "pending"
    # Training inputs are being generated.
    GENERATING = "generating"
    # Trajectories are being collected.
    COLLECTING = "collecting"
    # Trajectories are being scored by RULER.
    SCORING = "scoring"
    # Model training is in progress.
    TRAINING = "training"
    # Terminal states.
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"
class TaskInput(BaseModel):
    """One input handed to a training rollout.

    Example, mirroring the pattern used in user code:
        test_task_input = TaskInput(step=999, input_text=test_input)
    """

    model_config = dict(from_attributes=True)

    step: int
    input_text: str
    # Optional task context/metadata; absent by default.
    context: dict[str, Any] | None = Field(default=None)
class TrajectoryMessage(BaseModel):
    """One conversational turn inside a trajectory."""

    model_config = dict(from_attributes=True)

    # Speaker of the turn: "system", "user", or "assistant".
    role: str
    # Text of the turn.
    content: str
class TrajectoryRecord(BaseModel):
    """Record of a single trajectory execution.

    Stores the messages, reward, and metadata from a rollout, and converts
    to and from a flat row representation for persistence and analysis.
    """

    id: int | None = None
    uuid: UUID = Field(default_factory=uuid4)
    run_id: int  # Foreign key to TrainingRun
    task_input: TaskInput
    messages: list[TrajectoryMessage] = Field(default_factory=list)
    reward: float | None = None
    metadata: dict[str, Any] = Field(default_factory=dict)
    created_at: datetime = Field(default_factory=datetime.now)

    model_config = {"from_attributes": True}

    @classmethod
    def from_row(cls, row: Any) -> TrajectoryRecord:
        """Create TrajectoryRecord from database row.

        String values for ``uuid``, ``created_at``, ``task_input``,
        ``messages`` and ``metadata`` are parsed into their typed form.
        A ``None`` in ``messages`` or ``metadata`` is treated as empty.

        Args:
            row: SQLite Row object or dictionary.

        Returns:
            TrajectoryRecord instance.
        """
        import json

        # Anything exposing ``keys`` is copied into a plain dict so the
        # caller's row is not modified by the conversions below.
        data = dict(row) if hasattr(row, "keys") else row

        # to_dict() writes empty metadata as None, and both fields are
        # declared as non-optional containers, so passing None through
        # would fail validation. Map None to the empty container so a
        # to_dict() -> from_row() round trip works.
        for key, empty in (("messages", []), ("metadata", {})):
            if key in data and data[key] is None:
                data[key] = empty

        if data.get("uuid") and isinstance(data["uuid"], str):
            data["uuid"] = UUID(data["uuid"])
        if data.get("created_at") and isinstance(data["created_at"], str):
            data["created_at"] = datetime.fromisoformat(data["created_at"])

        # Parse JSON fields
        task_input = data.get("task_input")
        if task_input:
            if isinstance(task_input, str):
                data["task_input"] = TaskInput.model_validate_json(task_input)
            elif isinstance(task_input, dict):
                data["task_input"] = TaskInput.model_validate(task_input)

        messages = data.get("messages")
        if messages:
            if isinstance(messages, str):
                data["messages"] = [
                    TrajectoryMessage.model_validate(m)
                    for m in json.loads(messages)
                ]
            elif isinstance(messages, list):
                # Items that are already model instances are kept as-is.
                data["messages"] = [
                    TrajectoryMessage.model_validate(m) if isinstance(m, dict) else m
                    for m in messages
                ]

        if data.get("metadata") and isinstance(data["metadata"], str):
            data["metadata"] = json.loads(data["metadata"])

        return cls(**data)

    def to_dict(self) -> dict[str, Any]:
        """Convert to dictionary for database storage.

        ``task_input`` and ``messages`` are stored as JSON strings.
        ``metadata`` is a JSON string, or ``None`` when it is empty.

        Returns:
            Dictionary with string-serialized values suitable for SQLite.
        """
        import json

        # Nested fields are excluded here and serialized to JSON below.
        data = self.model_dump(exclude={"task_input", "messages", "metadata"})
        data["uuid"] = str(data["uuid"])
        if data.get("created_at"):
            data["created_at"] = data["created_at"].isoformat()
        data["task_input"] = self.task_input.model_dump_json()
        data["messages"] = json.dumps([m.model_dump() for m in self.messages])
        data["metadata"] = json.dumps(self.metadata) if self.metadata else None
        return data
class TrainingMetrics(BaseModel):
    """Aggregate numbers gathered while a training run progresses."""

    model_config = dict(from_attributes=True)

    # Counters start at zero.
    total_trajectories: int = Field(default=0)
    epoch: int = Field(default=0)
    step: int = Field(default=0)

    # Reward statistics and loss; None until a value is recorded.
    avg_reward: float | None = Field(default=None)
    min_reward: float | None = Field(default=None)
    max_reward: float | None = Field(default=None)
    training_loss: float | None = Field(default=None)
class TrainingRun(BaseModel):
    """A training run configuration and status.

    Tracks the lifecycle of a training session from configuration
    through completion, and converts to and from a flat row
    representation for storage.
    """

    id: int | None = None
    uuid: UUID = Field(default_factory=uuid4)
    status: TrainingStatus = TrainingStatus.PENDING
    # NOTE(review): some pydantic v2 releases warn about field names that
    # start with "model_" (protected namespace) - confirm whether the
    # installed version needs `protected_namespaces` configured.
    model_name: str  # Base model (e.g., "Qwen/Qwen2.5-1.5B-Instruct")
    project_name: str  # ART project name
    trained_model_name: str | None = None  # Name for the trained model
    task_description: str  # Description of the task for training
    num_examples: int = 100
    rollouts_per_group: int = 4  # Number of responses per input for RULER
    learning_rate: float = 1e-5
    num_epochs: int = 1
    max_steps: int | None = None
    system_prompt: str | None = None  # Generated system prompt for the task
    created_at: datetime = Field(default_factory=datetime.now)
    started_at: datetime | None = None
    completed_at: datetime | None = None
    metrics: TrainingMetrics | None = None
    error: str | None = None
    config: dict[str, Any] = Field(default_factory=dict)  # Additional config

    model_config = {"from_attributes": True}

    @classmethod
    def from_row(cls, row: Any) -> TrainingRun:
        """Create TrainingRun from database row.

        String values for ``uuid``, the datetime fields, ``status``,
        ``metrics`` and ``config`` are parsed into their typed form.
        A ``None`` in ``config`` is treated as an empty dict.

        Args:
            row: SQLite Row object or dictionary.

        Returns:
            TrainingRun instance.
        """
        import json

        # Anything exposing ``keys`` is copied into a plain dict so the
        # caller's row is not modified by the conversions below.
        data = dict(row) if hasattr(row, "keys") else row

        # to_dict() writes an empty config as None, and ``config`` is a
        # non-optional dict, so passing None through would fail
        # validation. Map it to {} so a to_dict() -> from_row() round
        # trip works.
        if "config" in data and data["config"] is None:
            data["config"] = {}

        if data.get("uuid") and isinstance(data["uuid"], str):
            data["uuid"] = UUID(data["uuid"])

        for key in ("created_at", "started_at", "completed_at"):
            if data.get(key) and isinstance(data[key], str):
                data[key] = datetime.fromisoformat(data[key])

        if data.get("status") and isinstance(data["status"], str):
            data["status"] = TrainingStatus(data["status"])

        metrics = data.get("metrics")
        if metrics:
            if isinstance(metrics, str):
                data["metrics"] = TrainingMetrics.model_validate_json(metrics)
            elif isinstance(metrics, dict):
                data["metrics"] = TrainingMetrics.model_validate(metrics)

        if data.get("config") and isinstance(data["config"], str):
            data["config"] = json.loads(data["config"])

        return cls(**data)

    def to_dict(self) -> dict[str, Any]:
        """Convert to dictionary for database storage.

        ``metrics`` is a JSON string or ``None``. ``config`` is a JSON
        string, or ``None`` when it is empty.

        Returns:
            Dictionary with string-serialized values suitable for SQLite.
        """
        import json

        # Nested fields are excluded here and serialized to JSON below.
        data = self.model_dump(exclude={"metrics", "config"})
        data["uuid"] = str(data["uuid"])

        for key in ("created_at", "started_at", "completed_at"):
            if data.get(key):
                data[key] = data[key].isoformat()

        data["status"] = data["status"].value if data["status"] else None
        data["metrics"] = self.metrics.model_dump_json() if self.metrics else None
        data["config"] = json.dumps(self.config) if self.config else None
        return data
class TrainingInputBatch(BaseModel):
    """Group of task inputs, with step and epoch counters.

    Used during the training loop iteration.
    """

    model_config = dict(from_attributes=True)

    # Step and epoch counters attached to the batch.
    step: int
    epoch: int
    epoch_step: int

    # The task inputs that make up the batch.
    items: list[TaskInput]