import { CohereClient } from "cohere-ai";
import { afterEach, beforeEach, describe, expect, it, vi } from "vitest";
import { CohereEmbeddings } from "./cohere.js";
// Shared fake Cohere client. The mocked "cohere-ai" constructor hands this
// object back, so tests program `mockClient.embed` to control API responses.
const embedStub = vi.fn().mockResolvedValue({ embeddings: [[]] });
const mockClient = { embed: embedStub };
// Replace the real SDK module with a spy whose instances are all `mockClient`.
vi.mock("cohere-ai", () => {
  const CohereClientSpy = vi.fn();
  // Kept as a `function` expression rather than an arrow — presumably because
  // the code under test constructs the client with `new` (confirm in ./cohere.js).
  CohereClientSpy.mockImplementation(function () {
    return mockClient;
  });
  return { CohereClient: CohereClientSpy };
});
// Mock Bottleneck with a pass-through limiter. This avoids the library's
// internal promise chains, which cause unhandled rejections when combined
// with vi.useFakeTimers.
vi.mock("bottleneck", () => ({
  default: class MockBottleneck {
    /** Accepts and ignores whatever limiter options the caller supplies. */
    constructor(_options?: unknown) {}

    /** Invokes the job immediately and returns its promise; no queueing or throttling. */
    async schedule<T>(fn: () => Promise<T>): Promise<T> {
      return fn();
    }

    /** No-op event registration; returns the instance so calls can be chained. */
    on() {
      return this;
    }
  },
}));
describe("CohereEmbeddings", () => {
let embeddings: CohereEmbeddings;
// Start every test from a clean slate: no recorded calls on either mock, the
// default empty-embedding response restored, and a fresh instance under test.
beforeEach(() => {
  vi.mocked(CohereClient).mockClear();

  const { embed } = mockClient;
  embed.mockReset();
  embed.mockResolvedValue({ embeddings: [[]] });

  embeddings = new CohereEmbeddings("test-api-key");
});
// Constructor behaviour: default and explicit model / dimension resolution.
describe("constructor", () => {
  it("should use default model and dimensions", () => {
    const model = embeddings.getModel();
    const dimensions = embeddings.getDimensions();

    expect(model).toBe("embed-english-v3.0");
    expect(dimensions).toBe(1024);
  });

  it("should use custom model", () => {
    const multilingual = new CohereEmbeddings("test-api-key", "embed-multilingual-v3.0");

    expect(multilingual.getModel()).toBe("embed-multilingual-v3.0");
    expect(multilingual.getDimensions()).toBe(1024);
  });

  it("should use custom dimensions", () => {
    const dimensions = new CohereEmbeddings("test-api-key", "embed-english-v3.0", 512).getDimensions();

    expect(dimensions).toBe(512);
  });

  it("should use default dimensions for light models", () => {
    const dimensions = new CohereEmbeddings("test-api-key", "embed-english-light-v3.0").getDimensions();

    expect(dimensions).toBe(384);
  });

  it("should default to 1024 for unknown models", () => {
    const dimensions = new CohereEmbeddings("test-api-key", "custom-model").getDimensions();

    expect(dimensions).toBe(1024);
  });

  it("should accept custom input type", () => {
    // Only checks that construction succeeds with the fifth argument set.
    const withQueryInput = new CohereEmbeddings(
      "test-api-key",
      "embed-english-v3.0",
      undefined,
      undefined,
      "search_query",
    );

    expect(withQueryInput).toBeInstanceOf(CohereEmbeddings);
  });
});
// Single-text embedding: request payload, result shape and error paths.
describe("embed", () => {
  // Builds a 1024-element vector holding the same value in every slot.
  const constantVector = (value: number): number[] => new Array<number>(1024).fill(value);

  // The payload these tests expect the client to receive for the given texts.
  const expectedRequest = (texts: string[], model = "embed-english-v3.0") => ({
    texts,
    model,
    inputType: "search_document",
    embeddingTypes: ["float"],
  });

  it("should generate embedding for single text", async () => {
    const vector = Array.from({ length: 1024 }, (_, i) => i * 0.001);
    mockClient.embed.mockResolvedValue({ embeddings: [vector] });

    const result = await embeddings.embed("test text");

    expect(result).toEqual({ embedding: vector, dimensions: 1024 });
    expect(mockClient.embed).toHaveBeenCalledWith(expectedRequest(["test text"]));
  });

  it("should handle long text", async () => {
    const longText = "word ".repeat(1000);
    const vector = constantVector(0.5);
    mockClient.embed.mockResolvedValue({ embeddings: [vector] });

    const { embedding } = await embeddings.embed(longText);

    expect(embedding).toEqual(vector);
    expect(mockClient.embed).toHaveBeenCalledWith(expectedRequest([longText]));
  });

  it("should use custom model configuration", async () => {
    const multilingual = new CohereEmbeddings("test-api-key", "embed-multilingual-v3.0", 1024);
    mockClient.embed.mockResolvedValue({ embeddings: [constantVector(0.1)] });

    await multilingual.embed("test");

    expect(mockClient.embed).toHaveBeenCalledWith(expectedRequest(["test"], "embed-multilingual-v3.0"));
  });

  it("should throw error if no embedding returned", async () => {
    mockClient.embed.mockResolvedValue({ embeddings: [] });

    const pending = embeddings.embed("test");

    await expect(pending).rejects.toThrow("No embedding returned from Cohere API");
  });

  it("should propagate errors", async () => {
    mockClient.embed.mockRejectedValue(new Error("API Error"));

    const pending = embeddings.embed("test");

    await expect(pending).rejects.toThrow("API Error");
  });
});
// Batch embedding: result mapping, boundary sizes and error paths.
describe("embedBatch", () => {
  it("should generate embeddings for multiple texts", async () => {
    const mockEmbeddings = [Array(1024).fill(0.1), Array(1024).fill(0.2), Array(1024).fill(0.3)];
    mockClient.embed.mockResolvedValue({
      embeddings: mockEmbeddings,
    });
    const texts = ["text1", "text2", "text3"];

    const results = await embeddings.embedBatch(texts);

    expect(results).toEqual([
      { embedding: mockEmbeddings[0], dimensions: 1024 },
      { embedding: mockEmbeddings[1], dimensions: 1024 },
      { embedding: mockEmbeddings[2], dimensions: 1024 },
    ]);
    expect(mockClient.embed).toHaveBeenCalledWith({
      texts,
      model: "embed-english-v3.0",
      inputType: "search_document",
      embeddingTypes: ["float"],
    });
  });

  it("should handle empty batch", async () => {
    mockClient.embed.mockResolvedValue({
      embeddings: [],
    });

    const results = await embeddings.embedBatch([]);

    expect(results).toEqual([]);
  });

  it("should handle single item in batch", async () => {
    const mockEmbedding = Array(1024).fill(0.5);
    mockClient.embed.mockResolvedValue({
      embeddings: [mockEmbedding],
    });

    const results = await embeddings.embedBatch(["single text"]);

    expect(results).toHaveLength(1);
    expect(results[0].embedding).toEqual(mockEmbedding);
  });

  it("should handle large batches", async () => {
    const batchSize = 100;
    // Deterministic fixture: vector i is filled with i / batchSize, so every
    // run sees the same data (the previous Math.random() values changed per run).
    const mockEmbeddings = Array.from({ length: batchSize }, (_, i) =>
      new Array<number>(1024).fill(i / batchSize),
    );
    mockClient.embed.mockResolvedValue({
      embeddings: mockEmbeddings,
    });
    const texts = Array.from({ length: batchSize }, (_, i) => `text ${i}`);

    const results = await embeddings.embedBatch(texts);

    expect(results).toHaveLength(batchSize);
  });

  it("should throw error if no embeddings returned", async () => {
    // Response object without an `embeddings` property at all.
    mockClient.embed.mockResolvedValue({});

    await expect(embeddings.embedBatch(["text1"])).rejects.toThrow("No embeddings returned from Cohere API");
  });

  it("should propagate errors in batch", async () => {
    mockClient.embed.mockRejectedValue(new Error("Batch API Error"));

    await expect(embeddings.embedBatch(["text1", "text2"])).rejects.toThrow("Batch API Error");
  });
});
// getDimensions(): reports the default or the explicitly configured size.
describe("getDimensions", () => {
  it("should return configured dimensions", () => {
    const dimensions = embeddings.getDimensions();

    expect(dimensions).toBe(1024);
  });

  it("should return custom dimensions", () => {
    const dimensions = new CohereEmbeddings("test-api-key", "embed-english-v3.0", 512).getDimensions();

    expect(dimensions).toBe(512);
  });
});
// getModel(): reports the default or the explicitly configured model name.
describe("getModel", () => {
  it("should return configured model", () => {
    const model = embeddings.getModel();

    expect(model).toBe("embed-english-v3.0");
  });

  it("should return custom model", () => {
    const model = new CohereEmbeddings("test-api-key", "embed-multilingual-v3.0").getModel();

    expect(model).toBe("embed-multilingual-v3.0");
  });
});
// Retry behaviour on rate-limit responses. Timers and Date are faked so the
// waits between retries can be fast-forwarded instead of slept through.
describe("rate limiting", () => {
  beforeEach(() => {
    vi.useFakeTimers({ toFake: ["setTimeout", "clearTimeout", "Date"] });
  });

  afterEach(async () => {
    // Run any timers a test left pending before handing back the real clock.
    await vi.advanceTimersByTimeAsync(30_000);
    vi.useRealTimers();
  });

  it("should retry on rate limit error (status 429)", async () => {
    const mockEmbedding = Array(1024).fill(0.5);
    mockClient.embed
      .mockRejectedValueOnce({ status: 429, message: "Rate limit exceeded" })
      .mockRejectedValueOnce({ status: 429, message: "Rate limit exceeded" })
      .mockResolvedValue({ embeddings: [mockEmbedding] });

    const promise = embeddings.embed("test text");
    await vi.advanceTimersByTimeAsync(10_000);
    const result = await promise;

    expect(result.embedding).toEqual(mockEmbedding);
    expect(mockClient.embed).toHaveBeenCalledTimes(3);
  });

  it("should retry on rate limit error (statusCode 429)", async () => {
    const mockEmbedding = Array(1024).fill(0.5);
    mockClient.embed
      .mockRejectedValueOnce({
        statusCode: 429,
        message: "Rate limit exceeded",
      })
      .mockResolvedValue({ embeddings: [mockEmbedding] });

    const promise = embeddings.embed("test text");
    await vi.advanceTimersByTimeAsync(10_000);
    const result = await promise;

    expect(result.embedding).toEqual(mockEmbedding);
    expect(mockClient.embed).toHaveBeenCalledTimes(2);
  });

  it("should retry on rate limit message", async () => {
    const mockEmbedding = Array(1024).fill(0.5);
    // No status field: only the message text identifies this as a rate limit.
    mockClient.embed
      .mockRejectedValueOnce({
        message: "You have exceeded the rate limit",
      })
      .mockResolvedValue({ embeddings: [mockEmbedding] });

    const promise = embeddings.embed("test text");
    await vi.advanceTimersByTimeAsync(10_000);
    const result = await promise;

    expect(result.embedding).toEqual(mockEmbedding);
    expect(mockClient.embed).toHaveBeenCalledTimes(2);
  });

  it("should use exponential backoff", async () => {
    const rateLimitEmbeddings = new CohereEmbeddings("test-api-key", "embed-english-v3.0", undefined, {
      retryAttempts: 3,
      retryDelayMs: 100,
    });
    const mockEmbedding = Array(1024).fill(0.5);
    const rateLimitError = {
      status: 429,
      message: "Rate limit exceeded",
    };
    mockClient.embed
      .mockRejectedValueOnce(rateLimitError)
      .mockRejectedValueOnce(rateLimitError)
      .mockResolvedValue({ embeddings: [mockEmbedding] });

    const startTime = Date.now();
    // Read the fake clock at the moment the call settles. Reading it after
    // advanceTimersByTimeAsync(10_000) returns would always show 10 000 ms,
    // which makes the lower-bound check pass even with no backoff at all.
    const elapsed = (async () => {
      await rateLimitEmbeddings.embed("test text");
      return Date.now() - startTime;
    })();
    await vi.advanceTimersByTimeAsync(10_000);
    const duration = await elapsed;

    // Should wait: 100ms (first retry) + 200ms (second retry) = 300ms
    expect(duration).toBeGreaterThanOrEqual(250);
    // Sanity check that the measurement is the backoff, not the full advance.
    expect(duration).toBeLessThan(10_000);
    expect(mockClient.embed).toHaveBeenCalledTimes(3);
  });

  it("should throw error after max retries exceeded", async () => {
    const rateLimitEmbeddings = new CohereEmbeddings("test-api-key", "embed-english-v3.0", undefined, {
      retryAttempts: 2,
      retryDelayMs: 100,
    });
    const rateLimitError = {
      status: 429,
      message: "Rate limit exceeded",
    };
    mockClient.embed.mockRejectedValue(rateLimitError);

    const promise = rateLimitEmbeddings.embed("test text");
    promise.catch(() => {}); // prevent unhandled rejection detection
    await vi.advanceTimersByTimeAsync(10_000);

    await expect(promise).rejects.toThrow("Cohere API rate limit exceeded after 2 retry attempts");
    // One initial call plus two retries.
    expect(mockClient.embed).toHaveBeenCalledTimes(3);
  });

  it("should handle rate limit errors in batch operations", async () => {
    const mockEmbeddings = [Array(1024).fill(0.1), Array(1024).fill(0.2)];
    mockClient.embed.mockRejectedValueOnce({ status: 429, message: "Rate limit exceeded" }).mockResolvedValue({
      embeddings: mockEmbeddings,
    });

    const promise = embeddings.embedBatch(["text1", "text2"]);
    await vi.advanceTimersByTimeAsync(10_000);
    const results = await promise;

    expect(results).toHaveLength(2);
    expect(mockClient.embed).toHaveBeenCalledTimes(2);
  });

  it("should not retry on non-rate-limit errors", async () => {
    const apiError = new Error("Invalid API key");
    mockClient.embed.mockRejectedValue(apiError);

    // Timers are never advanced here: the rejection must surface without a wait.
    await expect(embeddings.embed("test text")).rejects.toThrow("Invalid API key");
    expect(mockClient.embed).toHaveBeenCalledTimes(1);
  });

  it("should accept custom rate limit configuration", () => {
    // Only checks that construction succeeds with a full custom configuration.
    const customEmbeddings = new CohereEmbeddings("test-api-key", "embed-english-v3.0", undefined, {
      maxRequestsPerMinute: 200,
      retryAttempts: 5,
      retryDelayMs: 2000,
    });

    expect(customEmbeddings).toBeDefined();
  });
});
});