// sshService.test.ts (5.06 kB)
import { describe, it, expect, beforeEach, afterEach, vi } from "vitest";
import type { EventEmitter } from "node:events";
import { refreshConfig } from "../src/config/index.js";
// Replace the "ssh2" package with an in-memory fake so the tests never open a
// network connection. The factory exports:
//   - Client:      the fake client class
//   - __instances: every fake client constructed so far, in creation order, so a
//                  test can reach the stream handed to the most recent exec()
vi.mock("ssh2", async () => {
  const { EventEmitter } = await import("node:events");

  type ExecStream = EventEmitter & { stderr: EventEmitter };
  type ExecCallback = (err: Error | null, stream: ExecStream) => void;

  const instances: MockSshClient[] = [];

  /**
   * Minimal stand-in for the ssh2 client: it reports "ready" right after
   * connect(), hands out a bare emitter pair for exec(), and records calls to
   * end()/destroy() on spies.
   */
  class MockSshClient extends EventEmitter {
    /** Arguments and stream of the most recent exec() call, if any. */
    public lastExec?: {
      command: string;
      options: Record<string, unknown>;
      stream: ExecStream;
    };
    /** Spy invoked every time end() is called. */
    public ended = vi.fn();
    /** Spy invoked every time destroy() is called. */
    public destroyed = vi.fn();

    constructor() {
      super();
      instances.push(this);
    }

    /** Emits "ready" on the next microtask; no connection is attempted. */
    connect(): void {
      queueMicrotask(() => {
        this.emit("ready");
      });
    }

    /**
     * Records the call and synchronously passes a fresh stream to the callback.
     *
     * Accepts both call shapes documented by ssh2: `exec(command, callback)`
     * and `exec(command, options, callback)`. When the options object is
     * omitted, an empty object is recorded in `lastExec.options`.
     *
     * @throws TypeError when no callback is supplied in either position.
     */
    exec(command: string, options: Record<string, unknown> | ExecCallback, callback?: ExecCallback): void {
      const optionsOmitted = typeof options === "function";
      const done = optionsOmitted ? (options as ExecCallback) : callback;
      const execOptions: Record<string, unknown> = optionsOmitted ? {} : options;
      if (done === undefined) {
        throw new TypeError("MockSshClient.exec requires a callback");
      }
      // The stream never emits on its own; tests drive "data"/"close" by hand.
      const stream = Object.assign(new EventEmitter(), { stderr: new EventEmitter() });
      this.lastExec = { command, options: execOptions, stream };
      done(null, stream);
    }

    /** Records the call on the `ended` spy; emits nothing. */
    end(): void {
      this.ended();
    }

    /** Records the call on the `destroyed` spy; emits nothing. */
    destroy(): void {
      this.destroyed();
    }
  }

  return {
    Client: MockSshClient,
    __instances: instances
  };
});
/**
 * Structural view of the mocked "ssh2" module as the tests consume it: only the
 * list of constructed clients and, per client, the stream of its latest exec().
 */
interface MockClientModule {
  __instances: {
    lastExec?: { stream: EventEmitter & { stderr: EventEmitter } };
  }[];
}
/**
 * Tests for `executeSshCommand`, run against the mocked "ssh2" client.
 *
 * Configuration reaches the code under test through the INFER_MCP_CONFIG and
 * INFER_MCP_MODE environment variables followed by `refreshConfig()`, so each
 * test sets those up and `afterEach` removes them again.
 */
describe("executeSshCommand", () => {
  // One profile with deliberately tiny limits so the tests can exceed them with
  // a few bytes of output and a ~100 ms fake-clock advance.
  const allowlistedConfig = {
    sshProfiles: {
      training: {
        host: "cluster.example.com",
        username: "trainer",
        password: { value: "secret" },
        policy: {
          allowedCommands: ["^safe"],
          maxExecutionMs: 100,
          maxOutputBytes: 4,
          maxConcurrent: 1
        }
      }
    },
    databaseProfiles: {},
    training: {}
  };

  beforeEach(() => {
    process.env.INFER_MCP_CONFIG = JSON.stringify(allowlistedConfig);
    // afterEach only clears the mode once a test has finished; clear it here as
    // well so the first test does not inherit a value from the environment.
    delete process.env.INFER_MCP_MODE;
    refreshConfig();
  });

  afterEach(() => {
    delete process.env.INFER_MCP_CONFIG;
    delete process.env.INFER_MCP_MODE;
    refreshConfig();
    // Drop any pending fake timers before handing the real clock back.
    vi.clearAllTimers();
    vi.useRealTimers();
  });

  /** Yields to the microtask queue twice so queued callbacks can run. */
  const flushAsync = async () => {
    await Promise.resolve();
    await Promise.resolve();
  };

  /**
   * Waits for the pending command to reach the mock's exec() and returns the
   * stream of the most recently constructed client.
   *
   * @throws Error when no client has executed a command yet.
   */
  const latestExecStream = async (sshModule: MockClientModule) => {
    await flushAsync();
    await flushAsync();
    const stream = sshModule.__instances.at(-1)?.lastExec?.stream;
    if (stream === undefined) {
      throw new Error("Expected the mocked SSH client to have received an exec() call");
    }
    return stream;
  };

  it("rejects commands outside the allowlist", async () => {
    const { executeSshCommand } = await import("../src/services/sshService.js");
    // "rm -rf /" does not match the only allowed pattern, "^safe".
    await expect(executeSshCommand("training", "rm -rf /"))
      .rejects.toThrow(/not permitted/i);
  });

  it("truncates stdout and stderr when output exceeds limits", async () => {
    const { executeSshCommand } = await import("../src/services/sshService.js");
    const sshModule = (await import("ssh2")) as unknown as MockClientModule;
    const promise = executeSshCommand("training", "safe command");
    const stream = await latestExecStream(sshModule);
    // Both payloads are longer than the profile's maxOutputBytes of 4.
    stream.emit("data", Buffer.from("abcdef"));
    stream.stderr.emit("data", Buffer.from("ghijk"));
    stream.emit("close", 0);
    const result = await promise;
    expect(result.stdout.length).toBeLessThanOrEqual(4);
    expect(result.stderr.length).toBeLessThanOrEqual(4);
    expect(result.truncated.stdout).toBe(true);
    expect(result.truncated.stderr).toBe(true);
  });

  it("rejects when the command exceeds the timeout", async () => {
    vi.useFakeTimers();
    const { executeSshCommand } = await import("../src/services/sshService.js");
    const promise = executeSshCommand("training", "safe command");
    // Attach both handlers before the clock moves so the rejection is never
    // unhandled, and keep fulfilment distinguishable from rejection.
    const outcome = promise.then(
      () => ({ rejected: false as const }),
      (error: unknown) => ({ rejected: true as const, error })
    );
    await flushAsync();
    // The mock stream never closes; step just past maxExecutionMs (100).
    await vi.advanceTimersByTimeAsync(101);
    const settled = await outcome;
    if (!settled.rejected) {
      throw new Error("Expected SSH timeout to reject");
    }
    expect(settled.error).toBeInstanceOf(Error);
    expect((settled.error as Error).message).toMatch(/timed out/i);
  });

  it("skips policy checks in local test mode", async () => {
    // Empty configuration: no profiles and no policy are declared, and
    // INFER_MCP_MODE is unset (cleared in beforeEach).
    process.env.INFER_MCP_CONFIG = JSON.stringify({});
    refreshConfig();
    const { executeSshCommand } = await import("../src/services/sshService.js");
    const sshModule = (await import("ssh2")) as unknown as MockClientModule;
    const promise = executeSshCommand("local-test", "cat /etc/passwd");
    const stream = await latestExecStream(sshModule);
    stream.emit("data", Buffer.from("ok"));
    stream.stderr.emit("data", Buffer.from(""));
    stream.emit("close", 0);
    const result = await promise;
    expect(result.exitCode).toBe(0);
  });

  it("enforces policy when INFER_MCP_MODE=production", async () => {
    // Same call as the local-mode test, but with the mode forced to production
    // and the allowlisted configuration from beforeEach still in place.
    process.env.INFER_MCP_MODE = "production";
    refreshConfig();
    const { executeSshCommand } = await import("../src/services/sshService.js");
    await expect(executeSshCommand("local-test", "cat /etc/passwd"))
      .rejects.toThrow(/not permitted/i);
  });
});