"""Training data adapter for extracting training signals from task management data.
Converts todo interactions and status transitions into training data for OpenPipe ART.
"""
from __future__ import annotations
from dataclasses import dataclass, field
from datetime import datetime
from typing import TYPE_CHECKING, Any
from loguru import logger
from mcp_task_aggregator.models.training import TaskInput, TrainingInputBatch
if TYPE_CHECKING:
from mcp_task_aggregator.storage.database import Database
@dataclass
class TrainingDataConfig:
    """Configuration for training data extraction.

    Controls which todos qualify as training examples: bounds on
    content length, an optional completed-only restriction, and a
    list of source systems to skip.
    """

    # Lower bound (inclusive) on todo content length for inclusion.
    min_content_length: int = 10
    # Upper bound (inclusive) on todo content length for inclusion.
    max_content_length: int = 500
    # When True, restrict extraction to todos with status 'done'.
    include_completed_only: bool = False
    # Source systems whose todos are excluded from extraction.
    exclude_sources: list[str] = field(default_factory=list)
class TrainingDataAdapter:
    """Extract training signals from task management data.

    Converts existing todos and their status transitions into training
    inputs suitable for OpenPipe ART training.
    """

    def __init__(
        self,
        db: Database,
        config: TrainingDataConfig | None = None,
    ) -> None:
        """Initialize the training data adapter.

        Args:
            db: Database connection for accessing todos.
            config: Optional configuration for data extraction; a default
                TrainingDataConfig is used when omitted.
        """
        self.db = db
        self.config = config or TrainingDataConfig()

    def extract_task_inputs(
        self,
        limit: int = 100,
        offset: int = 0,
    ) -> list[TaskInput]:
        """Extract task inputs from existing todos.

        Converts todo content into training inputs that represent
        realistic task management scenarios. Todos are filtered by the
        configured content-length window, optionally restricted to
        completed todos, and may exclude specific source systems.

        Args:
            limit: Maximum number of inputs to extract.
            offset: Offset for pagination.

        Returns:
            List of TaskInput objects for training, newest first.
        """
        query = """
        SELECT id, content, status, priority, due_date, source_system
        FROM todos
        WHERE LENGTH(content) >= ?
        AND LENGTH(content) <= ?
        """
        params: list[Any] = [
            self.config.min_content_length,
            self.config.max_content_length,
        ]
        if self.config.include_completed_only:
            query += " AND status = 'done'"
        if self.config.exclude_sources:
            # One '?' per excluded source keeps the query fully parameterized.
            placeholders = ",".join("?" * len(self.config.exclude_sources))
            query += f" AND source_system NOT IN ({placeholders})"
            params.extend(self.config.exclude_sources)
        query += " ORDER BY created_at DESC LIMIT ? OFFSET ?"
        params.extend([limit, offset])
        rows = self.db.fetchall(query, tuple(params))
        logger.info(f"Extracted {len(rows)} task inputs from database")
        inputs = []
        for i, row in enumerate(rows):
            context = {
                "status": row["status"],
                "priority": row["priority"],
                "source": row["source_system"],
            }
            # Only carry a due date into the context when one is set.
            if row["due_date"]:
                context["due_date"] = row["due_date"]
            inputs.append(
                TaskInput(
                    step=i,
                    input_text=row["content"],
                    context=context,
                )
            )
        return inputs

    def extract_status_transitions(
        self,
        todo_id: int,
    ) -> list[dict[str, Any]]:
        """Extract status transition history for a todo.

        Useful for computing task completion rewards based on
        how status progressed.

        Args:
            todo_id: ID of the todo to analyze.

        Returns:
            List of status transition records, oldest first.
        """
        # Check agent_metadata table for status change suggestions
        query = """
        SELECT agent_type, suggestions, created_at
        FROM agent_metadata
        WHERE todo_id = ?
        ORDER BY created_at ASC
        """
        rows = self.db.fetchall(query, (todo_id,))
        return [
            {
                "agent_type": row["agent_type"],
                "suggestions": row["suggestions"],
                "timestamp": row["created_at"],
            }
            for row in rows
        ]

    def extract_sync_history(
        self,
        source_system: str | None = None,
        limit: int = 50,
    ) -> list[dict[str, Any]]:
        """Extract sync history as training context.

        Sync logs can provide context about task management patterns
        across different external systems.

        Args:
            source_system: Optional filter by source system.
            limit: Maximum records to return.

        Returns:
            List of sync history records, newest first.
        """
        query = """
        SELECT source_system, sync_type, started_at, completed_at,
        tasks_synced, tasks_created, tasks_updated, status
        FROM sync_log
        """
        params: list[Any] = []
        if source_system:
            query += " WHERE source_system = ?"
            params.append(source_system)
        query += " ORDER BY started_at DESC LIMIT ?"
        params.append(limit)
        rows = self.db.fetchall(query, tuple(params))
        logger.debug(f"Retrieved {len(rows)} sync history records")
        return [dict(row) for row in rows]

    def create_training_batch(
        self,
        num_examples: int = 50,
        step: int = 0,
        epoch: int = 0,
        epoch_step: int = 0,
    ) -> TrainingInputBatch:
        """Create a training batch from existing todo data.

        Combines task inputs with contextual information to create
        a complete training batch.

        Args:
            num_examples: Number of examples to include.
            step: Current training step.
            epoch: Current epoch.
            epoch_step: Step within current epoch.

        Returns:
            TrainingInputBatch ready for pipeline consumption.
        """
        inputs = self.extract_task_inputs(limit=num_examples)
        return TrainingInputBatch(
            items=inputs,
            step=step,
            epoch=epoch,
            epoch_step=epoch_step,
        )

    def compute_completion_signals(
        self,
        todo_before: dict[str, Any],
        todo_after: dict[str, Any],
    ) -> dict[str, Any]:
        """Compute completion signals from todo state changes.

        Analyzes before/after state to generate reward signals
        for task completion scoring.

        Args:
            todo_before: Todo state before agent action.
            todo_after: Todo state after agent action.

        Returns:
            Dictionary of completion signals for reward calculation.
            All boolean-valued signals are genuine bools.
        """
        signals: dict[str, Any] = {}
        # Task completed check
        signals["task_completed"] = todo_after.get("status") == "done"
        # Status progression (before -> after)
        status_before = todo_before.get("status", "todo")
        status_after = todo_after.get("status", "todo")
        signals["status_progression"] = [status_before, status_after]
        # Proper flow detection: status should only move forward through
        # todo -> in_progress -> done (staying put also counts as proper).
        proper_flow = ["todo", "in_progress", "done"]
        if status_before in proper_flow and status_after in proper_flow:
            before_idx = proper_flow.index(status_before)
            after_idx = proper_flow.index(status_after)
            signals["proper_progression"] = after_idx >= before_idx
        # Priority change detection
        priority_before = todo_before.get("priority", 0)
        priority_after = todo_after.get("priority", 0)
        signals["priority_changed"] = priority_before != priority_after
        # Due date handling
        due_before = todo_before.get("due_date")
        due_after = todo_after.get("due_date")
        signals["due_date_set"] = due_after is not None and due_before is None
        # BUG FIX: coerce to bool. The previous bare `and` chain evaluated
        # to None (not False) whenever no due date was set, leaking a
        # non-bool into the reward signals.
        signals["due_date_met"] = bool(
            signals["task_completed"]
            and due_after
            and self._is_on_time(due_after)
        )
        # Error detection
        signals["error_count"] = 1 if todo_after.get("error") else 0
        return signals

    def _is_on_time(self, due_date_str: str) -> bool:
        """Check if task was completed on time.

        Args:
            due_date_str: Due date as ISO format string ('Z' suffix allowed).

        Returns:
            True if current time is before or equal to due date; also True
            when the date cannot be parsed (benefit of the doubt).
        """
        try:
            # Normalize a trailing 'Z' so fromisoformat accepts it on
            # Python versions before 3.11.
            due_date = datetime.fromisoformat(due_date_str.replace("Z", "+00:00"))
            # Compare in the due date's own timezone; a naive due date
            # yields a naive "now", keeping the comparison valid.
            return datetime.now(due_date.tzinfo) <= due_date
        except (ValueError, AttributeError):
            return True  # Assume on time if date is invalid

    def get_training_statistics(self) -> dict[str, Any]:
        """Get statistics about available training data.

        Returns:
            Dictionary with training data statistics: total todo count,
            counts by status and by source, how many todos satisfy the
            configured length constraints, and the number of recorded
            agent interactions.
        """
        stats: dict[str, Any] = {}
        # Total todos
        result = self.db.fetchone("SELECT COUNT(*) as count FROM todos")
        stats["total_todos"] = result["count"] if result else 0
        # By status
        rows = self.db.fetchall(
            "SELECT status, COUNT(*) as count FROM todos GROUP BY status"
        )
        stats["by_status"] = {row["status"]: row["count"] for row in rows}
        # By source
        rows = self.db.fetchall(
            "SELECT source_system, COUNT(*) as count FROM todos GROUP BY source_system"
        )
        stats["by_source"] = {row["source_system"]: row["count"] for row in rows}
        # Valid for training (within length constraints)
        result = self.db.fetchone(
            """
            SELECT COUNT(*) as count FROM todos
            WHERE LENGTH(content) >= ? AND LENGTH(content) <= ?
            """,
            (self.config.min_content_length, self.config.max_content_length),
        )
        stats["valid_for_training"] = result["count"] if result else 0
        # Agent metadata count
        result = self.db.fetchone("SELECT COUNT(*) as count FROM agent_metadata")
        stats["agent_interactions"] = result["count"] if result else 0
        return stats