trainClassifier
Executes classifier training commands on a remote host via SSH, training each specified subclass against a remote dataset path, with an optional command-template override, a per-subclass timeout, and a dry-run mode.
Instructions
Run classifier training commands on a remote host via SSH
Input Schema
Table / JSON Schema
| Name | Required | Description | Default |
|---|---|---|---|
| commandTemplate | No | Command template to run on the remote host; overrides configuration if provided | `training.defaultCommandTemplate` from configuration |
| datasetPath | Yes | Remote dataset path | |
| dryRun | No | If true, render commands without executing them | `false` |
| profile | Yes | Credential profile for SSH access | |
| subclasses | Yes | List of subclasses to train | |
| timeoutMs | No | Optional timeout per subclass execution in milliseconds | `training.defaultTimeoutMs` from configuration |
Implementation Reference
- `src/tools/trainClassifier.ts:70-116` (handler): The handler function for the `trainClassifier` tool. It validates input, runs the training job using `runTrainingJob`, reports progress, generates a summary, and returns both text and structured output.

  ```ts
      async (args, extra) => {
        const input = TrainClassifierInputSchema.parse(args);
        const total = input.subclasses.length;
        const progress = createProgressReporter(extra, "trainClassifier");
        progress?.({ progress: 0, total, message: "Preparing training job" });
        const results = await runTrainingJob({
          profile: input.profile,
          subclasses: input.subclasses,
          datasetPath: input.datasetPath,
          commandTemplate: input.commandTemplate,
          timeoutMs: input.timeoutMs,
          dryRun: input.dryRun ?? false,
          signal: extra.signal,
          onProgress: (update) => {
            progress?.({ progress: update.progress, total, message: update.message });
          }
        });
        progress?.({ progress: total, total, message: "Training job finished" });
        const summaryLines = results.tasks.map((task) => {
          const status = task.status.toUpperCase();
          const duration = task.result?.durationMs !== undefined ? `${task.result.durationMs}ms` : "n/a";
          return `- ${task.subclass}: ${status} (duration: ${duration})`;
        });
        return {
          content: [
            {
              type: "text",
              text: [`Training job status: ${results.status.toUpperCase()}`, ...summaryLines].join("\n")
            }
          ],
          structuredContent: { job: results }
        };
      }
    );
  }
  ```
- `src/tools/trainClassifier.ts:7-22` (schema): Zod schema defining the input parameters for the `trainClassifier` tool, including profile, subclasses, dataset path, etc.

  ```ts
  const TrainClassifierInputSchema = z.object({
    profile: z.string().describe("Credential profile for SSH access"),
    subclasses: z.array(z.string()).nonempty().describe("List of subclasses to train"),
    datasetPath: z.string().describe("Remote dataset path"),
    commandTemplate: z
      .string()
      .optional()
      .describe("Command template to run on the remote host; overrides configuration if provided"),
    timeoutMs: z
      .number()
      .int()
      .positive()
      .optional()
      .describe("Optional timeout per subclass execution in milliseconds"),
    dryRun: z.boolean().optional().describe("If true, render commands without executing them")
  });
  ```
- `src/tools/trainClassifier.ts:62-118` (registration): Registers the `trainClassifier` tool with the MCP server, specifying the name, description, input/output schemas, and the handler function.

  ```ts
  export function registerTrainingTool(server: McpServer): void {
    server.registerTool(
      "trainClassifier",
      {
        description: "Run classifier training commands on a remote host via SSH",
        inputSchema: TrainClassifierInputSchema.shape,
        outputSchema: TrainingJobOutputShape
      },
      async (args, extra) => {
        const input = TrainClassifierInputSchema.parse(args);
        const total = input.subclasses.length;
        const progress = createProgressReporter(extra, "trainClassifier");
        progress?.({ progress: 0, total, message: "Preparing training job" });
        const results = await runTrainingJob({
          profile: input.profile,
          subclasses: input.subclasses,
          datasetPath: input.datasetPath,
          commandTemplate: input.commandTemplate,
          timeoutMs: input.timeoutMs,
          dryRun: input.dryRun ?? false,
          signal: extra.signal,
          onProgress: (update) => {
            progress?.({ progress: update.progress, total, message: update.message });
          }
        });
        progress?.({ progress: total, total, message: "Training job finished" });
        const summaryLines = results.tasks.map((task) => {
          const status = task.status.toUpperCase();
          const duration = task.result?.durationMs !== undefined ? `${task.result.durationMs}ms` : "n/a";
          return `- ${task.subclass}: ${status} (duration: ${duration})`;
        });
        return {
          content: [
            {
              type: "text",
              text: [`Training job status: ${results.status.toUpperCase()}`, ...summaryLines].join("\n")
            }
          ],
          structuredContent: { job: results }
        };
      }
    );
  }
  ```
- `src/tools/index.ts:6-10` (registration): Top-level tool registration function that calls `registerTrainingTool(server)` to register the `trainClassifier` tool among others.

  ```ts
  export function registerTools(server: McpServer): void {
    registerSshTool(server);
    registerDbTool(server);
    registerTrainingTool(server);
  }
  ```
- `runTrainingJob` (helper): Core helper function called by the tool handler to execute training commands on a remote host via SSH. It runs the subclasses one after another and handles dry runs, per-subclass timeouts, cancellation, and progress reporting. It returns a detailed per-subclass report.

  ```ts
  export async function runTrainingJob(input: TrainingJobInput): Promise<TrainingJobReport> {
    const config = getConfig();
    const template = input.commandTemplate ?? config.training.defaultCommandTemplate;
    if (!template) {
      throw new Error("No command template supplied for training job");
    }
    const total = input.subclasses.length;
    const abortError = createAbortError("Training job cancelled");
    const reportProgress = (update: ProgressUpdate) => {
      if (input.onProgress) {
        input.onProgress(update);
      }
    };
    const startedAt = new Date();
    const tasks: TrainingTaskReport[] = input.subclasses.map((subclass) => ({
      subclass,
      command: buildCommand(template, subclass, input.datasetPath),
      dryRun: Boolean(input.dryRun),
      status: "pending",
      logs: []
    }));
    reportProgress({ progress: 0, total, message: "Starting training job" });
    const appendLog = (
      task: TrainingTaskReport,
      level: "info" | "warn" | "error",
      message: string,
      context?: Record<string, unknown>
    ): void => {
      const at = new Date().toISOString();
      task.logs.push({ level, message, at, context });
      logger[level](message, {
        profile: input.profile,
        subclass: task.subclass,
        datasetPath: input.datasetPath,
        ...context
      });
    };
    for (let index = 0; index < tasks.length; index += 1) {
      const task = tasks[index];
      if (input.signal?.aborted) {
        appendLog(task, "warn", "Training task cancelled before start");
        task.status = "cancelled";
        continue;
      }
      if (input.dryRun) {
        task.status = "succeeded";
        task.startedAt = new Date().toISOString();
        task.completedAt = new Date().toISOString();
        appendLog(task, "info", "Dry run: command rendered but not executed", { command: task.command });
        reportProgress({ progress: index + 1, total, message: `Dry run prepared for ${task.subclass}` });
        continue;
      }
      task.status = "running";
      task.startedAt = new Date().toISOString();
      appendLog(task, "info", "Starting training command", { command: task.command });
      reportProgress({ progress: index, total, message: `Launching training for ${task.subclass}` });
      const options: SshCommandOptions = {
        timeoutMs: input.timeoutMs ?? config.training.defaultTimeoutMs,
        signal: input.signal,
        onProgress: (update) => {
          const normalized = Math.min(1, Math.max(0, update.progress));
          reportProgress({
            progress: index + normalized,
            total,
            message: update.message ? `${task.subclass}: ${update.message}` : `Running training for ${task.subclass}`
          });
        }
      };
      try {
        const execution = await executeSshCommand(input.profile, task.command, options, { tool: "trainClassifier" });
        task.result = execution;
        task.status = "succeeded";
        task.completedAt = new Date().toISOString();
        appendLog(task, "info", "Training command completed", {
          exitCode: execution.exitCode,
          signal: execution.signal,
          durationMs: execution.durationMs,
          stdoutBytes: execution.stdout.length,
          stderrBytes: execution.stderr.length,
          stdoutTruncated: execution.truncated.stdout,
          stderrTruncated: execution.truncated.stderr
        });
        reportProgress({ progress: index + 1, total, message: `Completed training for ${task.subclass}` });
      } catch (error) {
        const now = new Date().toISOString();
        task.completedAt = now;
        if (error instanceof Error && error.name === "AbortError") {
          task.status = "cancelled";
          task.error = error.message;
          appendLog(task, "warn", "Training command cancelled", { reason: error.message });
          reportProgress({ progress: index + 1, total, message: `Cancelled training for ${task.subclass}` });
        } else {
          task.status = "failed";
          const message = error instanceof Error ? error.message : String(error);
          task.error = message;
          appendLog(task, "error", "Training command failed", { error: message });
          reportProgress({ progress: index + 1, total, message: `Failed training for ${task.subclass}` });
        }
      }
    }
    const completedAt = new Date().toISOString();
    const overallStatus: TrainingTaskStatus = tasks.every((task) => task.status === "succeeded")
      ? "succeeded"
      : tasks.some((task) => task.status === "failed")
        ? "failed"
        : tasks.some((task) => task.status === "cancelled")
          ? "cancelled"
          : "succeeded";
    reportProgress({ progress: total, total, message: "Training job completed" });
    return {
      profile: input.profile,
      datasetPath: input.datasetPath,
      commandTemplate: template,
      status: overallStatus,
      startedAt: startedAt.toISOString(),
      completedAt,
      tasks
    };
  }
  ```