// databaseService.test.ts • 4.13 kB
import { describe, it, expect, beforeEach, afterEach, vi } from "vitest";
import { refreshConfig } from "../src/config/index.js";
/**
 * A query as recorded by the mocked `pg` client in this file.
 *
 * The mock's `query` accepts either a plain SQL string or a config object;
 * plain strings are wrapped as `{ text }` before being recorded, so every
 * recorded entry has this shape.
 */
interface QueryCall {
  /** SQL text of the statement; optional because a config object may omit it. */
  text?: string;
}
/**
 * Test-only control surface exposed by the mocked `pg` module.
 *
 * The double-underscore helpers do not exist on the real driver; the mock
 * factory in this file adds them so tests can script query results and
 * inspect which statements were issued.
 */
interface MockPgModule {
  /** Sets the rows that subsequent mocked queries resolve with. */
  __setQueryRows(rows: Record<string, unknown>[]): void;

  /** Sets the error that subsequent mocked queries throw; `null` clears it. */
  __setQueryError(error: Error | null): void;

  /** Returns the queries recorded by the mock so far. */
  __getQueryCalls(): QueryCall[];

  /** Restores the mock to its initial state. */
  __reset(): void;
}
/**
 * Replaces the `pg` module with an in-memory stand-in.
 *
 * The mocked module exposes a `Client` class whose `query` records every call
 * and answers from shared, scriptable state, plus `__`-prefixed helpers that
 * let tests script that state and read back the recorded calls.
 */
vi.mock("pg", async () => {
  // State shared by every MockClient instance and by the control helpers below.
  const recordedCalls: QueryCall[] = [];
  let cannedRows: Array<Record<string, unknown>> = [];
  let pendingError: Error | null = null;

  // Accepts both call styles of `query` and yields the object form.
  const toQueryCall = (input: QueryCall | string): QueryCall => {
    if (typeof input === "string") {
      return { text: input };
    }
    return input;
  };

  // True for the statement-timeout setup statement, which is never scripted.
  const isTimeoutSetup = (call: QueryCall): boolean =>
    typeof call.text === "string" && call.text.startsWith("SET statement_timeout");

  class MockClient {
    // Stubbed socket handle; NOTE(review): presumably the service destroys the
    // stream to force-close a connection — confirm against the service code.
    public connection = { stream: { destroy: vi.fn() } };

    connect = vi.fn(async () => {});

    end = vi.fn(async () => {});

    query = vi.fn(async (config: QueryCall | string) => {
      const call = toQueryCall(config);
      // Record first so that failing queries are still visible to assertions.
      recordedCalls.push(call);

      // The timeout statement always succeeds with an empty result, even when
      // an error has been configured for regular queries.
      if (isTimeoutSetup(call)) {
        return { rows: [], rowCount: 0 };
      }

      if (pendingError) {
        throw pendingError;
      }

      const resultRows = cannedRows;
      return { rows: resultRows, rowCount: resultRows.length };
    });
  }

  const mockedModule = {
    Client: MockClient,
    // Scripting rows also clears any configured error.
    __setQueryRows: (newRows: Array<Record<string, unknown>>) => {
      cannedRows = newRows;
      pendingError = null;
    },
    __setQueryError: (newError: Error | null) => {
      pendingError = newError;
    },
    // Returns the live array, not a copy.
    __getQueryCalls: () => recordedCalls,
    // Empties the call log in place so previously returned references stay valid.
    __reset: () => {
      cannedRows = [];
      pendingError = null;
      recordedCalls.length = 0;
    }
  };

  return mockedModule as unknown as Record<string, unknown>;
});
// Exercises executeDatabaseQuery against the mocked `pg` module: statement
// validation (allowlist, comments, multi-statement) and row-limit handling.
describe("executeDatabaseQuery", () => {
  // Configuration injected through INFER_MCP_CONFIG for every test. The single
  // "analytics" profile only allows SELECT and caps results at two rows.
  const testConfig = {
    sshProfiles: {},
    databaseProfiles: {
      analytics: {
        connectionString: { value: "postgres://user:pass@localhost/db" },
        allowedStatements: ["^\\s*SELECT"],
        maxRows: 2,
        maxExecutionMs: 1000,
        maxConcurrent: 1
      }
    },
    training: {}
  };

  // Handle on the mocked `pg` module's control helpers; assigned in beforeEach.
  let pgMock: MockPgModule;

  // Imports the service under test; called inside each test, after beforeEach
  // has installed the configuration.
  const loadService = () => import("../src/services/databaseService.js");

  beforeEach(async () => {
    process.env.INFER_MCP_CONFIG = JSON.stringify(testConfig);
    refreshConfig();
    pgMock = (await import("pg")) as unknown as MockPgModule;
    pgMock.__reset();
  });

  afterEach(() => {
    delete process.env.INFER_MCP_CONFIG;
    refreshConfig();
    // Optional call: beforeEach may have failed before the mock was assigned.
    pgMock?.__reset();
  });

  it("rejects statements outside the allowlist", async () => {
    const { executeDatabaseQuery } = await loadService();
    const attempt = executeDatabaseQuery("analytics", "DELETE FROM data", undefined, {});
    await expect(attempt).rejects.toThrow(/not permitted/i);
  });

  it("rejects queries with inline comments", async () => {
    const { executeDatabaseQuery } = await loadService();
    const attempt = executeDatabaseQuery("analytics", "SELECT * FROM table -- comment", undefined, {});
    await expect(attempt).rejects.toThrow(/comments are not allowed/i);
  });

  it("rejects multi-statement queries", async () => {
    const { executeDatabaseQuery } = await loadService();
    const attempt = executeDatabaseQuery("analytics", "SELECT 1; DROP TABLE users", undefined, {});
    await expect(attempt).rejects.toThrow(/multiple statements/i);
  });

  it("enforces row limits and reports truncation", async () => {
    // Three rows are scripted against a profile whose maxRows is 2.
    pgMock.__setQueryRows([{ id: 1 }, { id: 2 }, { id: 3 }]);

    const { executeDatabaseQuery } = await loadService();
    const outcome = await executeDatabaseQuery("analytics", "SELECT * FROM data", undefined, {});

    expect(outcome.rows).toHaveLength(2);
    expect(outcome.rowCount).toBe(3);
    expect(outcome.truncated).toBe(true);

    // The timeout statement is expected first, then the user's query.
    const [timeoutCall, selectCall] = pgMock.__getQueryCalls();
    expect(timeoutCall?.text).toContain("SET statement_timeout");
    expect(selectCall?.text).toContain("SELECT");
  });

  it("allows queries ending with a semicolon", async () => {
    pgMock.__setQueryRows([{ id: 1 }]);

    const { executeDatabaseQuery } = await loadService();
    const outcome = await executeDatabaseQuery("analytics", "SELECT * FROM data;", undefined, {});

    expect(outcome.rows).toHaveLength(1);
  });
});