// Mistral models on Vertex AI Model Garden (Genkit plugin)

/** * Copyright 2024 Google LLC * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ import { MistralGoogleCloud } from '@mistralai/mistralai-gcp'; import { AssistantMessage, ChatCompletionChoiceFinishReason, ChatCompletionRequest, ChatCompletionResponse, CompletionChunk, FunctionCall, Tool as MistralTool, SystemMessage, ToolCall, ToolMessage, ToolTypes, UserMessage, } from '@mistralai/mistralai-gcp/models/components'; import { GENKIT_CLIENT_HEADER, GenerateRequest, GenerationCommonConfigSchema, Genkit, MessageData, ModelReference, ModelResponseData, Part, Role, ToolRequestPart, z, } from 'genkit'; import { ModelAction, modelRef } from 'genkit/model'; export const MistralConfigSchema = GenerationCommonConfigSchema.extend({ location: z.string().optional(), maxOutputTokens: z.number().optional(), temperature: z.number().optional(), // TODO: is this supported? 
// topK: z.number().optional(), topP: z.number().optional(), stopSequences: z.array(z.string()).optional(), }); export const mistralLarge = modelRef({ name: 'vertexai/mistral-large', info: { label: 'Vertex AI Model Garden - Mistral Large', versions: ['mistral-large-2411', 'mistral-large-2407'], supports: { multiturn: true, media: false, tools: true, systemRole: true, output: ['text'], }, }, configSchema: MistralConfigSchema, }); export const mistralNemo = modelRef({ name: 'vertexai/mistral-nemo', info: { label: 'Vertex AI Model Garden - Mistral Nemo', versions: ['mistral-nemo-2407'], supports: { multiturn: true, media: false, tools: false, systemRole: true, output: ['text'], }, }, configSchema: MistralConfigSchema, }); export const codestral = modelRef({ name: 'vertexai/codestral', info: { label: 'Vertex AI Model Garden - Codestral', versions: ['codestral-2405'], supports: { multiturn: true, media: false, tools: false, systemRole: true, output: ['text'], }, }, configSchema: MistralConfigSchema, }); export const SUPPORTED_MISTRAL_MODELS: Record< string, ModelReference<typeof MistralConfigSchema> > = { 'mistral-large': mistralLarge, 'mistral-nemo': mistralNemo, codestral: codestral, }; // TODO: Do they export a type for this? type MistralRole = 'assistant' | 'user' | 'tool' | 'system'; function toMistralRole(role: Role): MistralRole { switch (role) { case 'model': return 'assistant'; case 'user': return 'user'; case 'tool': return 'tool'; case 'system': return 'system'; default: throw new Error(`Unknwon role ${role}`); } } function toMistralToolRequest(toolRequest: Record<string, any>): FunctionCall { if (!toolRequest.name) { throw new Error('Tool name is required'); } return { name: toolRequest.name, // Mistral expects arguments as either a string or object arguments: typeof toolRequest.input === 'string' ? 
toolRequest.input : JSON.stringify(toolRequest.input), }; } export function toMistralRequest( model: string, input: GenerateRequest<typeof MistralConfigSchema> ): ChatCompletionRequest { const messages = input.messages.map((msg) => { // Handle regular text messages if (msg.content.every((part) => part.text)) { const content = msg.content.map((part) => part.text || '').join(''); return { role: toMistralRole(msg.role), content, }; } // Handle assistant's tool/function calls const toolRequest = msg.content.find((part) => part.toolRequest); if (toolRequest?.toolRequest) { const functionCall = toMistralToolRequest(toolRequest.toolRequest); return { role: 'assistant' as const, content: null, toolCalls: [ { id: toolRequest.toolRequest.ref, type: ToolTypes.Function, function: { name: functionCall.name, arguments: functionCall.arguments, }, }, ], }; } // Handle tool responses const toolResponse = msg.content.find((part) => part.toolResponse); if (toolResponse?.toolResponse) { return { role: 'tool' as const, name: toolResponse.toolResponse.name, content: JSON.stringify(toolResponse.toolResponse.output), toolCallId: toolResponse.toolResponse.ref, // This must match the id from tool_calls }; } return { role: toMistralRole(msg.role), content: msg.content.map((part) => part.text || '').join(''), }; }); validateToolSequence(messages); // This line exists but might not be running? const request: ChatCompletionRequest = { model, messages, maxTokens: input.config?.maxOutputTokens ?? 1024, temperature: input.config?.temperature ?? 
0.7, ...(input.config?.topP && { topP: input.config.topP }), ...(input.config?.stopSequences && { stop: input.config.stopSequences }), ...(input.tools && { tools: input.tools.map((tool) => ({ type: 'function', function: { name: tool.name, description: tool.description, parameters: tool.inputSchema || {}, }, })) as MistralTool[], }), }; return request; } // Helper to convert Mistral AssistantMessage content into Genkit parts function fromMistralTextPart(content: string): Part { return { text: content, }; } // Helper to convert Mistral ToolCall into Genkit parts function fromMistralToolCall(toolCall: ToolCall): ToolRequestPart { if (!toolCall.function) { throw new Error('Tool call must include a function definition'); } return { toolRequest: { ref: toolCall.id, name: toolCall.function.name, input: typeof toolCall.function.arguments === 'string' ? JSON.parse(toolCall.function.arguments) : toolCall.function.arguments, }, }; } // Converts Mistral AssistantMessage content into Genkit parts function fromMistralMessage(message: AssistantMessage): Part[] { const parts: Part[] = []; // Handle textual content if (typeof message.content === 'string') { parts.push(fromMistralTextPart(message.content)); } else if (Array.isArray(message.content)) { // If content is an array of ContentChunk, handle each chunk message.content.forEach((chunk) => { if (chunk.type === 'text') { parts.push(fromMistralTextPart(chunk.text)); } // Add support for other ContentChunk types here if needed }); } // Handle tool calls if present if (message.toolCalls) { message.toolCalls.forEach((toolCall) => { parts.push(fromMistralToolCall(toolCall)); }); } return parts; } // Maps Mistral finish reasons to Genkit finish reasons export function fromMistralFinishReason( reason: ChatCompletionChoiceFinishReason | undefined ): 'length' | 'unknown' | 'stop' | 'blocked' | 'other' { switch (reason) { case ChatCompletionChoiceFinishReason.Stop: return 'stop'; case ChatCompletionChoiceFinishReason.Length: case 
ChatCompletionChoiceFinishReason.ModelLength: return 'length'; case ChatCompletionChoiceFinishReason.Error: return 'other'; // Map generic errors to "other" case ChatCompletionChoiceFinishReason.ToolCalls: return 'stop'; // Assuming tool calls signify a "stop" in processing default: return 'other'; // For undefined or unmapped reasons } } // Converts a Mistral response to a Genkit response export function fromMistralResponse( _input: GenerateRequest<typeof MistralConfigSchema>, response: ChatCompletionResponse ): ModelResponseData { const firstChoice = response.choices?.[0]; // Convert content from Mistral response to Genkit parts const contentParts: Part[] = firstChoice?.message ? fromMistralMessage(firstChoice.message) : []; const message: MessageData = { role: 'model', content: contentParts, }; return { message, finishReason: fromMistralFinishReason(firstChoice?.finishReason), usage: { inputTokens: response.usage.promptTokens, outputTokens: response.usage.completionTokens, }, custom: { id: response.id, model: response.model, created: response.created, }, raw: response, // Include the raw response for debugging or additional context }; } export function mistralModel( ai: Genkit, modelName: string, projectId: string, region: string ): ModelAction { const getClient = createClientFactory(projectId); const model = SUPPORTED_MISTRAL_MODELS[modelName]; if (!model) { throw new Error(`Unsupported Mistral model name ${modelName}`); } return ai.defineModel( { name: model.name, label: model.info?.label, configSchema: MistralConfigSchema, supports: model.info?.supports, versions: model.info?.versions, }, async (input, sendChunk) => { const client = getClient(input.config?.location || region); const versionedModel = input.config?.version ?? model.info?.versions?.[0] ?? 
model.name; if (!sendChunk) { const mistralRequest = toMistralRequest(versionedModel, input); const response = await client.chat.complete(mistralRequest, { fetchOptions: { headers: { 'X-Goog-Api-Client': GENKIT_CLIENT_HEADER, }, }, }); return fromMistralResponse(input, response); } else { const mistralRequest = toMistralRequest(versionedModel, input); const stream = await client.chat.stream(mistralRequest, { fetchOptions: { headers: { 'X-Goog-Api-Client': GENKIT_CLIENT_HEADER, }, }, }); for await (const event of stream) { const parts = fromMistralCompletionChunk(event.data); if (parts.length > 0) { sendChunk({ content: parts, }); } } // Get the complete response after streaming const completeResponse = await client.chat.complete(mistralRequest, { fetchOptions: { headers: { 'X-Goog-Api-Client': GENKIT_CLIENT_HEADER, }, }, }); return fromMistralResponse(input, completeResponse); } } ); } function createClientFactory(projectId: string) { const clients: Record<string, MistralGoogleCloud> = {}; return (region: string): MistralGoogleCloud => { if (!region) { throw new Error('Region is required to create Mistral client'); } try { if (!clients[region]) { clients[region] = new MistralGoogleCloud({ region, projectId, }); } return clients[region]; } catch (error) { throw new Error( `Failed to create/retrieve Mistral client for region ${region}: ${error}` ); } }; } type MistralMessage = | AssistantMessage | ToolMessage | SystemMessage | UserMessage; // Helper function to validate tool calls and responses match function validateToolSequence(messages: MistralMessage[]) { const toolCalls = ( messages.filter((m) => { return m.role === 'assistant' && m.toolCalls; }) as AssistantMessage[] ).reduce((acc: ToolCall[], m) => { if (m.toolCalls) { return [...acc, ...m.toolCalls]; } return acc; }, []); const toolResponses = messages.filter( (m) => m.role === 'tool' ) as ToolMessage[]; if (toolCalls.length !== toolResponses.length) { throw new Error( `Mismatch between tool calls 
(${toolCalls.length}) and responses (${toolResponses.length})` ); } toolResponses.forEach((response) => { const matchingCall = toolCalls.find( (call) => call.id === response.toolCallId ); if (!matchingCall) { throw new Error( `Tool response with ID ${response.toolCallId} has no matching call` ); } }); } export function fromMistralCompletionChunk(chunk: CompletionChunk): Part[] { if (!chunk.choices?.[0]?.delta) return []; const delta = chunk.choices[0].delta; const parts: Part[] = []; if (typeof delta.content === 'string') { parts.push({ text: delta.content }); } if (delta.toolCalls) { delta.toolCalls.forEach((toolCall) => { if (!toolCall.function) return; parts.push({ toolRequest: { ref: toolCall.id, name: toolCall.function.name, input: typeof toolCall.function.arguments === 'string' ? JSON.parse(toolCall.function.arguments) : toolCall.function.arguments, }, }); }); } return parts; }