// trainingService.ts • 6.38 kB
import { getConfig } from "../config/index.js";
import {
  executeSshCommand,
  type SshCommandOptions,
  type SshCommandResult
} from "./sshService.js";
import { createAbortError } from "../utils/abort.js";
import type { ProgressUpdate } from "../utils/progress.js";
import { logger } from "../utils/logger.js";
/**
 * Parameters accepted by `runTrainingJob`.
 */
export interface TrainingJobInput {
  /** Profile identifier forwarded as the first argument to `executeSshCommand`. */
  profile: string;
  /** Subclasses to train; one task (and one rendered command) is created per entry, in order. */
  subclasses: string[];
  /** Value substituted for the `{{datasetPath}}` placeholder in the command template. */
  datasetPath: string;
  /**
   * Command template containing `{{subclass}}`, `{{datasetPath}}` and/or `{{subclass_slug}}`
   * placeholders. Falls back to `config.training.defaultCommandTemplate` when omitted.
   */
  commandTemplate?: string;
  /**
   * Timeout passed to each SSH command invocation. Falls back to
   * `config.training.defaultTimeoutMs` when omitted.
   */
  timeoutMs?: number;
  /** When truthy, commands are rendered and reported but never executed. */
  dryRun?: boolean;
  /** Checked before each task starts and forwarded to the SSH command for cancellation. */
  signal?: AbortSignal;
  /**
   * Optional progress listener. `progress` is expressed in tasks (0 to `total`, where
   * `total` is the number of subclasses) and may be fractional while a task is running.
   */
  onProgress?: (update: ProgressUpdate) => void;
}
/**
 * Lifecycle state of a single training task; also reused for the job-level status.
 * Tasks start as "pending", become "running" just before their command is executed,
 * and end as "succeeded", "failed" or "cancelled".
 */
export type TrainingTaskStatus = "pending" | "running" | "succeeded" | "failed" | "cancelled";
/**
 * One log line recorded against a training task.
 */
export interface TrainingTaskLogEntry {
  /** Severity; matches the `logger` method the entry is also emitted through. */
  level: "info" | "warn" | "error";
  /** Human-readable log message. */
  message: string;
  /** Time the entry was recorded, as produced by `Date.prototype.toISOString()`. */
  at: string;
  /** Optional structured details attached to the entry (e.g. the command or an error message). */
  context?: Record<string, unknown>;
}
/**
 * Outcome of the training task for a single subclass.
 */
export interface TrainingTaskReport {
  /** Subclass this task trains. */
  subclass: string;
  /** Command rendered from the template for this subclass. */
  command: string;
  /** Whether the job was run in dry-run mode (command rendered but not executed). */
  dryRun: boolean;
  /** Current/final state of the task. */
  status: TrainingTaskStatus;
  /** ISO timestamp of when the task started; left unset for tasks cancelled before starting. */
  startedAt?: string;
  /** ISO timestamp of when the task finished; left unset for tasks cancelled before starting. */
  completedAt?: string;
  /** Result returned by `executeSshCommand`; only set when the call resolved. */
  result?: SshCommandResult;
  /** Message of the error that failed or cancelled the task while it was running. */
  error?: string;
  /** Log entries recorded for this task, in chronological order. */
  logs: TrainingTaskLogEntry[];
}
/**
 * Aggregate report returned by `runTrainingJob`.
 */
export interface TrainingJobReport {
  /** Profile the job ran against (copied from the input). */
  profile: string;
  /** Dataset path used when rendering commands (copied from the input). */
  datasetPath: string;
  /** Template actually used: the input's template, or the configured default. */
  commandTemplate: string;
  /**
   * Overall outcome: "failed" if any task failed, otherwise "cancelled" if any task
   * was cancelled, otherwise "succeeded".
   */
  status: TrainingTaskStatus;
  /** ISO timestamp taken when the job began. */
  startedAt: string;
  /** ISO timestamp taken after the last task was processed. */
  completedAt?: string;
  /** Per-subclass task reports, in the same order as the input subclasses. */
  tasks: TrainingTaskReport[];
}
/**
 * Renders a training command by substituting placeholders in `template`.
 *
 * Supported placeholders:
 * - `{{subclass}}` — the subclass name, verbatim.
 * - `{{datasetPath}}` — the dataset path, verbatim.
 * - `{{subclass_slug}}` — the subclass with every whitespace run replaced by `-`,
 *   then lower-cased.
 *
 * All placeholders are resolved in a single pass using a replacer function, so:
 * - `$` sequences inside a value (`$$`, `$&`, `` $` ``, `$'`) are inserted literally
 *   instead of being interpreted as replacement patterns;
 * - placeholder-looking text inside a value is not expanded a second time.
 *
 * NOTE(review): values are inserted without any shell quoting. If `subclass` or
 * `datasetPath` can come from untrusted input, the template (or the caller) must
 * quote/escape them before the command is executed.
 *
 * @param template - Command template containing zero or more placeholders.
 * @param subclass - Subclass name to substitute.
 * @param datasetPath - Dataset path to substitute.
 * @returns The template with every supported placeholder replaced.
 */
function buildCommand(template: string, subclass: string, datasetPath: string): string {
  const values: Record<string, string> = {
    subclass,
    datasetPath,
    subclass_slug: subclass.replace(/\s+/g, "-").toLowerCase()
  };
  return template.replace(
    /\{\{(subclass_slug|subclass|datasetPath)\}\}/g,
    (match: string, key: string): string => values[key] ?? match
  );
}
/**
 * Collapses per-task statuses into a job-level status: "failed" if any task
 * failed, otherwise "cancelled" if any task was cancelled, otherwise "succeeded".
 */
function summarizeTaskStatuses(tasks: readonly TrainingTaskReport[]): TrainingTaskStatus {
  if (tasks.some((task) => task.status === "failed")) {
    return "failed";
  }
  if (tasks.some((task) => task.status === "cancelled")) {
    return "cancelled";
  }
  return "succeeded";
}

/**
 * Runs one training command per subclass, sequentially, and reports the outcome.
 *
 * For each subclass a command is rendered from the template. In dry-run mode the
 * command is only recorded; otherwise it is executed through `executeSshCommand`.
 * A failing or cancelled task does not stop the loop: remaining tasks are still
 * attempted (or marked "cancelled" if the abort signal is already set).
 *
 * Progress is reported in units of tasks (`0..total`); while a command is running
 * its own progress (clamped to `[0, 1]`) is added to the task index.
 *
 * Errors thrown by the `onProgress` listener are logged and ignored so that a
 * faulty listener cannot change a task's recorded outcome or abort the job.
 *
 * NOTE(review): any resolved `executeSshCommand` call is recorded as "succeeded"
 * without inspecting `exitCode` — confirm that `executeSshCommand` rejects on a
 * non-zero exit.
 *
 * @param input - Job parameters; see `TrainingJobInput`.
 * @returns A report with the overall status and one entry per subclass.
 * @throws Error if neither the input nor the configuration supplies a command template.
 */
export async function runTrainingJob(input: TrainingJobInput): Promise<TrainingJobReport> {
  const config = getConfig();
  const template = input.commandTemplate ?? config.training.defaultCommandTemplate;
  if (!template) {
    throw new Error("No command template supplied for training job");
  }
  const total = input.subclasses.length;
  const startedAt = new Date();

  // Progress reporting is advisory: never let a throwing listener leak into the
  // task try/catch below (where it would be misrecorded as a task failure).
  const reportProgress = (update: ProgressUpdate): void => {
    if (!input.onProgress) {
      return;
    }
    try {
      input.onProgress(update);
    } catch (listenerError) {
      logger.warn("Training progress listener threw; ignoring", {
        profile: input.profile,
        datasetPath: input.datasetPath,
        error: listenerError instanceof Error ? listenerError.message : String(listenerError)
      });
    }
  };

  const tasks: TrainingTaskReport[] = input.subclasses.map((subclass) => ({
    subclass,
    command: buildCommand(template, subclass, input.datasetPath),
    dryRun: Boolean(input.dryRun),
    status: "pending",
    logs: []
  }));
  reportProgress({ progress: 0, total, message: "Starting training job" });

  // Records the entry on the task report and mirrors it to the shared logger.
  const appendLog = (
    task: TrainingTaskReport,
    level: "info" | "warn" | "error",
    message: string,
    context?: Record<string, unknown>
  ): void => {
    const at = new Date().toISOString();
    task.logs.push({ level, message, at, context });
    logger[level](message, {
      profile: input.profile,
      subclass: task.subclass,
      datasetPath: input.datasetPath,
      ...context
    });
  };

  for (const [index, task] of tasks.entries()) {
    // Once aborted, every remaining task is marked cancelled without being started.
    if (input.signal?.aborted) {
      appendLog(task, "warn", "Training task cancelled before start");
      task.status = "cancelled";
      continue;
    }
    if (input.dryRun) {
      task.status = "succeeded";
      task.startedAt = new Date().toISOString();
      task.completedAt = new Date().toISOString();
      appendLog(task, "info", "Dry run: command rendered but not executed", { command: task.command });
      reportProgress({ progress: index + 1, total, message: `Dry run prepared for ${task.subclass}` });
      continue;
    }
    task.status = "running";
    task.startedAt = new Date().toISOString();
    appendLog(task, "info", "Starting training command", { command: task.command });
    reportProgress({ progress: index, total, message: `Launching training for ${task.subclass}` });
    const options: SshCommandOptions = {
      timeoutMs: input.timeoutMs ?? config.training.defaultTimeoutMs,
      signal: input.signal,
      onProgress: (update) => {
        // Map the command's own progress onto this task's slot of the overall range.
        const normalized = Math.min(1, Math.max(0, update.progress));
        reportProgress({
          progress: index + normalized,
          total,
          message: update.message ? `${task.subclass}: ${update.message}` : `Running training for ${task.subclass}`
        });
      }
    };
    try {
      const execution = await executeSshCommand(input.profile, task.command, options, {
        tool: "trainClassifier"
      });
      task.result = execution;
      task.status = "succeeded";
      task.completedAt = new Date().toISOString();
      appendLog(task, "info", "Training command completed", {
        exitCode: execution.exitCode,
        signal: execution.signal,
        durationMs: execution.durationMs,
        stdoutBytes: execution.stdout.length,
        stderrBytes: execution.stderr.length,
        stdoutTruncated: execution.truncated.stdout,
        stderrTruncated: execution.truncated.stderr
      });
      reportProgress({ progress: index + 1, total, message: `Completed training for ${task.subclass}` });
    } catch (error) {
      const now = new Date().toISOString();
      task.completedAt = now;
      // Errors named "AbortError" are treated as cancellation rather than failure.
      if (error instanceof Error && error.name === "AbortError") {
        task.status = "cancelled";
        task.error = error.message;
        appendLog(task, "warn", "Training command cancelled", { reason: error.message });
        reportProgress({ progress: index + 1, total, message: `Cancelled training for ${task.subclass}` });
      } else {
        task.status = "failed";
        const message = error instanceof Error ? error.message : String(error);
        task.error = message;
        appendLog(task, "error", "Training command failed", { error: message });
        reportProgress({ progress: index + 1, total, message: `Failed training for ${task.subclass}` });
      }
    }
  }

  const completedAt = new Date().toISOString();
  const overallStatus = summarizeTaskStatuses(tasks);
  reportProgress({ progress: total, total, message: "Training job completed" });
  return {
    profile: input.profile,
    datasetPath: input.datasetPath,
    commandTemplate: template,
    status: overallStatus,
    startedAt: startedAt.toISOString(),
    completedAt,
    tasks
  };
}