/**
* Model Router — auto-routing via Haiku classifier.
*
* When the user selects "auto" as their model, each user message is classified
* by Haiku to determine the optimal model (haiku/sonnet/opus) based on task
* complexity. Matches Gemini CLI's auto-routing feature.
*/
import { MODELS } from "../../shared/constants.js";
import { callServerProxy } from "../../shared/api-client.js";
import type { APIRequestConfig } from "../../shared/types.js";
// ============================================================================
// TYPES
// ============================================================================
// The three routing tiers the classifier may choose between.
export type RoutedModel = "haiku" | "sonnet" | "opus";
// Outcome of one classification pass over a user message.
export interface RouteResult {
// Tier the request should be served by.
model: RoutedModel;
// Brief model-provided justification for the choice (may be empty).
reasoning: string;
confidence: number; // 0-1
}
// ============================================================================
// CONSTANTS
// ============================================================================
// Classification runs on Haiku: the cheapest/fastest tier, since the
// classifier call happens on every user message when "auto" is selected.
const CLASSIFIER_MODEL = MODELS.HAIKU;
// Prompt template; the "{message}" placeholder is substituted with the raw
// user request before sending.
const CLASSIFIER_PROMPT = `Classify this user request by complexity. Respond with ONLY a JSON object.
Categories:
- "haiku": Simple lookups, single-step tasks, quick questions, status checks, simple edits (1-2 files)
- "sonnet": Multi-step tasks, code refactoring, moderate analysis, 3-5 file changes, debugging
- "opus": Complex architecture, large refactors, planning, creative problem solving, 5+ files, ambiguous requirements
User request: {message}
Respond: {"model": "haiku"|"sonnet"|"opus", "reasoning": "brief reason", "confidence": 0.0-1.0}`;
// Maps a routed tier to the concrete model identifier sent to the API.
const MODEL_ID_MAP: Record<RoutedModel, string> = {
haiku: MODELS.HAIKU,
sonnet: MODELS.SONNET,
opus: MODELS.OPUS,
};
// Returned whenever classification fails or cannot be trusted; sonnet is the
// safe middle-ground choice between cost and capability.
const DEFAULT_ROUTE: RouteResult = {
model: "sonnet",
reasoning: "default fallback",
confidence: 0.5,
};
// ============================================================================
// CLASSIFIER
// ============================================================================
/**
 * Classify a user message to determine the best model.
 * Uses Haiku via the server proxy for fast, cheap classification.
 *
 * @param userMessage - Raw user request to classify (treated as opaque text).
 * @param serverUrl - Base URL of the proxy server (no trailing `/proxy`).
 * @param authToken - Bearer token forwarded to the proxy.
 * @returns The routing decision; never throws — any failure (network,
 *   timeout, malformed response) resolves to the sonnet fallback.
 */
export async function classifyAndRoute(
  userMessage: string,
  serverUrl: string,
  authToken: string,
): Promise<RouteResult> {
  try {
    // Use a replacer *function* so `$`-sequences in the user message
    // (e.g. "$&", "$'", "$$") are inserted literally instead of being
    // interpreted as special replacement patterns by String.replace.
    const prompt = CLASSIFIER_PROMPT.replace("{message}", () => userMessage);
    const apiConfig: APIRequestConfig = {
      betas: [],
      contextManagement: { edits: [] },
      maxTokens: 256, // classifier reply is a tiny JSON object; keep it cheap
    };
    const stream = await callServerProxy({
      proxyUrl: `${serverUrl}/proxy`,
      token: authToken,
      model: CLASSIFIER_MODEL,
      system: [{ type: "text", text: "You are a task complexity classifier. Respond with only valid JSON." }],
      messages: [{ role: "user", content: prompt }],
      tools: [],
      apiConfig,
      timeoutMs: 5000, // classification must stay fast; bail out quickly
    });
    // Collect the full text response from the SSE stream
    const text = await collectText(stream);
    return parseClassifierResponse(text);
  } catch {
    // On any error (network, timeout, parse), fall back to sonnet
    return DEFAULT_ROUTE;
  }
}
/**
 * True when the user's model selection is the special "auto" value,
 * meaning each message should be routed through the classifier.
 */
export function isAutoModel(model: string): boolean {
  const AUTO_SENTINEL = "auto";
  return AUTO_SENTINEL === model;
}
/**
 * Resolve the "auto" selection to a concrete model ID by classifying the
 * user's message and mapping the chosen tier through MODEL_ID_MAP.
 */
export async function resolveAutoModel(
  userMessage: string,
  serverUrl: string,
  authToken: string,
): Promise<string> {
  const { model } = await classifyAndRoute(userMessage, serverUrl, authToken);
  return MODEL_ID_MAP[model];
}
// ============================================================================
// INTERNAL HELPERS
// ============================================================================
/**
 * Collect all text from an SSE stream body.
 * Lightweight extraction — only reads text_delta events.
 *
 * @param body - Raw byte stream of `data: {...}` SSE lines.
 * @returns Concatenated text of every content_block_delta/text_delta event,
 *   stopping early at the `[DONE]` sentinel.
 */
async function collectText(body: ReadableStream<Uint8Array>): Promise<string> {
  const reader = body.getReader();
  const decoder = new TextDecoder();
  let buffer = "";
  let text = "";

  // Process one SSE line; returns true when the [DONE] sentinel ends the stream.
  const consumeLine = (line: string): boolean => {
    const trimmed = line.trim();
    if (!trimmed || !trimmed.startsWith("data: ")) return false;
    const payload = trimmed.slice(6);
    if (payload === "[DONE]") return true;
    try {
      const event = JSON.parse(payload);
      if (event.type === "content_block_delta" && event.delta?.type === "text_delta") {
        text += event.delta.text;
      }
    } catch { /* skip malformed JSON */ }
    return false;
  };

  try {
    while (true) {
      const { done, value } = await reader.read();
      if (done) break;
      buffer += decoder.decode(value, { stream: true });
      const lines = buffer.split("\n");
      // Keep the trailing partial line in the buffer until its newline arrives.
      buffer = lines.pop() || "";
      for (const line of lines) {
        if (consumeLine(line)) return text;
      }
    }
    // Stream ended: flush any bytes the decoder buffered (e.g. a split
    // multi-byte character) and process the final line if it had no newline.
    buffer += decoder.decode();
    if (buffer) consumeLine(buffer);
  } finally {
    reader.releaseLock();
  }
  return text;
}
/**
 * Parse the classifier's JSON response into a RouteResult.
 * Falls back to sonnet on any parse failure or low confidence.
 *
 * @param text - Raw model output; may wrap the JSON in markdown fences or prose.
 * @returns A RouteResult whose confidence is always clamped to [0, 1].
 */
function parseClassifierResponse(text: string): RouteResult {
  try {
    // Extract JSON from response (handle markdown fences or extra text)
    const jsonMatch = text.match(/\{[\s\S]*\}/);
    if (!jsonMatch) return DEFAULT_ROUTE;
    const parsed = JSON.parse(jsonMatch[0]);
    const model = parsed.model as string;
    if (model !== "haiku" && model !== "sonnet" && model !== "opus") {
      return DEFAULT_ROUTE;
    }
    // Contract says 0-1, but the model may emit anything (including NaN,
    // which would slip past the `< 0.3` check below): require a finite
    // number and clamp into range.
    const raw = typeof parsed.confidence === "number" && Number.isFinite(parsed.confidence)
      ? parsed.confidence
      : 0;
    const confidence = Math.min(1, Math.max(0, raw));
    const reasoning = typeof parsed.reasoning === "string" ? parsed.reasoning : "";
    // Low confidence defaults to sonnet (safe middle ground)
    if (confidence < 0.3) {
      return { model: "sonnet", reasoning: reasoning || "low confidence fallback", confidence };
    }
    return { model: model as RoutedModel, reasoning, confidence };
  } catch {
    return DEFAULT_ROUTE;
  }
}