// NOTE(review): the two lines below ("MCP Terminal Server" / "by dillip285")
// appear unrelated to this Vertex AI Mistral plugin and are not valid
// TypeScript; commented out so the file compiles — confirm provenance and remove.
// MCP Terminal Server
// by dillip285
/**
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import { MistralGoogleCloud } from '@mistralai/mistralai-gcp';
import {
AssistantMessage,
ChatCompletionChoiceFinishReason,
ChatCompletionRequest,
ChatCompletionResponse,
CompletionChunk,
FunctionCall,
Tool as MistralTool,
SystemMessage,
ToolCall,
ToolMessage,
ToolTypes,
UserMessage,
} from '@mistralai/mistralai-gcp/models/components';
import {
GENKIT_CLIENT_HEADER,
GenerateRequest,
GenerationCommonConfigSchema,
Genkit,
MessageData,
ModelReference,
ModelResponseData,
Part,
Role,
ToolRequestPart,
z,
} from 'genkit';
import { ModelAction, modelRef } from 'genkit/model';
/**
 * Request configuration accepted by the Vertex AI Mistral models.
 * Extends Genkit's common generation config with Mistral-specific options.
 */
export const MistralConfigSchema = GenerationCommonConfigSchema.extend({
  // Per-request Vertex AI region override; falls back to the plugin region.
  location: z.string().optional(),
  // Maximum tokens to generate; defaults to 1024 in toMistralRequest.
  maxOutputTokens: z.number().optional(),
  // Sampling temperature; defaults to 0.7 in toMistralRequest.
  temperature: z.number().optional(),
  // TODO: is this supported?
  // topK: z.number().optional(),
  topP: z.number().optional(),
  // Sequences that stop generation; mapped to Mistral's `stop` field.
  stopSequences: z.array(z.string()).optional(),
});
/** Model reference for Mistral Large on Vertex AI Model Garden (tool-capable). */
export const mistralLarge = modelRef({
  name: 'vertexai/mistral-large',
  info: {
    label: 'Vertex AI Model Garden - Mistral Large',
    versions: ['mistral-large-2411', 'mistral-large-2407'],
    supports: {
      multiturn: true,
      media: false,
      tools: true,
      systemRole: true,
      output: ['text'],
    },
  },
  configSchema: MistralConfigSchema,
});
/** Model reference for Mistral Nemo on Vertex AI Model Garden (no tool support). */
export const mistralNemo = modelRef({
  name: 'vertexai/mistral-nemo',
  info: {
    label: 'Vertex AI Model Garden - Mistral Nemo',
    versions: ['mistral-nemo-2407'],
    supports: {
      multiturn: true,
      media: false,
      tools: false,
      systemRole: true,
      output: ['text'],
    },
  },
  configSchema: MistralConfigSchema,
});
/** Model reference for Codestral on Vertex AI Model Garden (no tool support). */
export const codestral = modelRef({
  name: 'vertexai/codestral',
  info: {
    label: 'Vertex AI Model Garden - Codestral',
    versions: ['codestral-2405'],
    supports: {
      multiturn: true,
      media: false,
      tools: false,
      systemRole: true,
      output: ['text'],
    },
  },
  configSchema: MistralConfigSchema,
});
/**
 * Lookup table from plugin model-name keys (as passed to mistralModel)
 * to their corresponding model references.
 */
export const SUPPORTED_MISTRAL_MODELS: Record<
  string,
  ModelReference<typeof MistralConfigSchema>
> = {
  'mistral-large': mistralLarge,
  'mistral-nemo': mistralNemo,
  codestral: codestral,
};
// TODO: Do they export a type for this?
type MistralRole = 'assistant' | 'user' | 'tool' | 'system';

/**
 * Maps a Genkit message role onto the corresponding Mistral chat role.
 *
 * @param role Genkit role ('model' | 'user' | 'tool' | 'system').
 * @returns The equivalent Mistral role string.
 * @throws Error if the role is not one of the four known values.
 */
function toMistralRole(role: Role): MistralRole {
  switch (role) {
    case 'model':
      // Genkit calls the LLM turn 'model'; Mistral calls it 'assistant'.
      return 'assistant';
    case 'user':
      return 'user';
    case 'tool':
      return 'tool';
    case 'system':
      return 'system';
    default:
      // Fixed typo in the error message ('Unknwon' -> 'Unknown').
      throw new Error(`Unknown role ${role}`);
  }
}
// Converts a Genkit tool request into the Mistral FunctionCall shape.
function toMistralToolRequest(toolRequest: Record<string, any>): FunctionCall {
  if (!toolRequest.name) {
    throw new Error('Tool name is required');
  }
  // Mistral accepts arguments either as a raw string or as an object;
  // serialize non-string inputs to JSON.
  const args =
    typeof toolRequest.input === 'string'
      ? toolRequest.input
      : JSON.stringify(toolRequest.input);
  return {
    name: toolRequest.name,
    arguments: args,
  };
}
/**
 * Converts a Genkit GenerateRequest into a Mistral ChatCompletionRequest.
 *
 * Text-only messages are flattened to a single content string; assistant
 * tool calls and tool responses are mapped to Mistral's toolCalls /
 * toolCallId shapes. Tool call/response pairing is validated before the
 * request is built.
 *
 * @param model Versioned Mistral model name to call.
 * @param input Genkit generate request (messages, config, tools).
 * @returns A request object ready for the Mistral GCP client.
 */
export function toMistralRequest(
  model: string,
  input: GenerateRequest<typeof MistralConfigSchema>
): ChatCompletionRequest {
  const messages = input.messages.map((msg) => {
    // Handle regular text messages
    if (msg.content.every((part) => part.text)) {
      const content = msg.content.map((part) => part.text || '').join('');
      return {
        role: toMistralRole(msg.role),
        content,
      };
    }
    // Handle assistant's tool/function calls
    const toolRequest = msg.content.find((part) => part.toolRequest);
    if (toolRequest?.toolRequest) {
      const functionCall = toMistralToolRequest(toolRequest.toolRequest);
      return {
        role: 'assistant' as const,
        content: null,
        toolCalls: [
          {
            id: toolRequest.toolRequest.ref,
            type: ToolTypes.Function,
            function: {
              name: functionCall.name,
              arguments: functionCall.arguments,
            },
          },
        ],
      };
    }
    // Handle tool responses
    const toolResponse = msg.content.find((part) => part.toolResponse);
    if (toolResponse?.toolResponse) {
      return {
        role: 'tool' as const,
        name: toolResponse.toolResponse.name,
        content: JSON.stringify(toolResponse.toolResponse.output),
        // Must match the id of the originating assistant toolCalls entry.
        toolCallId: toolResponse.toolResponse.ref,
      };
    }
    // Fallback: treat any remaining content as plain text.
    return {
      role: toMistralRole(msg.role),
      content: msg.content.map((part) => part.text || '').join(''),
    };
  });
  // Fail fast if tool calls and tool responses don't pair up by id.
  validateToolSequence(messages);
  const request: ChatCompletionRequest = {
    model,
    messages,
    maxTokens: input.config?.maxOutputTokens ?? 1024,
    temperature: input.config?.temperature ?? 0.7,
    // Explicit undefined checks so valid falsy values (e.g. topP of 0)
    // are not silently dropped by a truthiness test.
    ...(input.config?.topP !== undefined && { topP: input.config.topP }),
    ...(input.config?.stopSequences !== undefined && {
      stop: input.config.stopSequences,
    }),
    ...(input.tools && {
      tools: input.tools.map((tool) => ({
        type: 'function',
        function: {
          name: tool.name,
          description: tool.description,
          parameters: tool.inputSchema || {},
        },
      })) as MistralTool[],
    }),
  };
  return request;
}
// Wraps a raw Mistral content string in a Genkit text Part.
function fromMistralTextPart(content: string): Part {
  return { text: content };
}
// Converts a Mistral ToolCall into a Genkit ToolRequestPart.
function fromMistralToolCall(toolCall: ToolCall): ToolRequestPart {
  const fn = toolCall.function;
  if (!fn) {
    throw new Error('Tool call must include a function definition');
  }
  // Mistral may deliver arguments as a JSON string or an already-parsed object.
  const input =
    typeof fn.arguments === 'string' ? JSON.parse(fn.arguments) : fn.arguments;
  return {
    toolRequest: {
      ref: toolCall.id,
      name: fn.name,
      input,
    },
  };
}
// Converts Mistral AssistantMessage content into Genkit parts.
function fromMistralMessage(message: AssistantMessage): Part[] {
  const parts: Part[] = [];
  const content = message.content;
  if (typeof content === 'string') {
    // Plain string content becomes a single text part.
    parts.push(fromMistralTextPart(content));
  } else if (Array.isArray(content)) {
    // Content arrived as ContentChunk[]; only text chunks are handled.
    for (const chunk of content) {
      if (chunk.type === 'text') {
        parts.push(fromMistralTextPart(chunk.text));
      }
      // Support for other ContentChunk types could be added here if needed.
    }
  }
  // Append any tool calls as Genkit tool-request parts.
  for (const toolCall of message.toolCalls ?? []) {
    parts.push(fromMistralToolCall(toolCall));
  }
  return parts;
}
// Maps Mistral finish reasons to Genkit finish reasons.
export function fromMistralFinishReason(
  reason: ChatCompletionChoiceFinishReason | undefined
): 'length' | 'unknown' | 'stop' | 'blocked' | 'other' {
  if (
    reason === ChatCompletionChoiceFinishReason.Stop ||
    // A tool-call finish signifies a normal "stop" in processing.
    reason === ChatCompletionChoiceFinishReason.ToolCalls
  ) {
    return 'stop';
  }
  if (
    reason === ChatCompletionChoiceFinishReason.Length ||
    reason === ChatCompletionChoiceFinishReason.ModelLength
  ) {
    return 'length';
  }
  // Errors, undefined, and any unmapped reasons all map to "other".
  return 'other';
}
// Converts a Mistral response to a Genkit response.
export function fromMistralResponse(
  _input: GenerateRequest<typeof MistralConfigSchema>,
  response: ChatCompletionResponse
): ModelResponseData {
  const choice = response.choices?.[0];
  // Translate the first choice's message (if any) into Genkit parts.
  const content: Part[] = choice?.message
    ? fromMistralMessage(choice.message)
    : [];
  const message: MessageData = { role: 'model', content };
  return {
    message,
    finishReason: fromMistralFinishReason(choice?.finishReason),
    usage: {
      inputTokens: response.usage.promptTokens,
      outputTokens: response.usage.completionTokens,
    },
    custom: {
      id: response.id,
      model: response.model,
      created: response.created,
    },
    // Keep the raw response for debugging or additional context.
    raw: response,
  };
}
/**
 * Registers a Vertex AI Model Garden Mistral model as a Genkit ModelAction.
 *
 * @param ai Genkit instance to register the model on.
 * @param modelName Key into SUPPORTED_MISTRAL_MODELS (e.g. 'mistral-large').
 * @param projectId GCP project id used to construct the Mistral client.
 * @param region Default region; a request may override via config.location.
 * @returns The registered ModelAction.
 * @throws Error when modelName is not a supported Mistral model.
 */
export function mistralModel(
  ai: Genkit,
  modelName: string,
  projectId: string,
  region: string
): ModelAction {
  // One cached client per region, created on demand.
  const getClient = createClientFactory(projectId);
  const model = SUPPORTED_MISTRAL_MODELS[modelName];
  if (!model) {
    throw new Error(`Unsupported Mistral model name ${modelName}`);
  }
  return ai.defineModel(
    {
      name: model.name,
      label: model.info?.label,
      configSchema: MistralConfigSchema,
      supports: model.info?.supports,
      versions: model.info?.versions,
    },
    async (input, sendChunk) => {
      // Per-request region override via config.location.
      const client = getClient(input.config?.location || region);
      // Resolve concrete model version: explicit config > first known > name.
      const versionedModel =
        input.config?.version ?? model.info?.versions?.[0] ?? model.name;
      if (!sendChunk) {
        // Non-streaming path: a single complete() call.
        const mistralRequest = toMistralRequest(versionedModel, input);
        const response = await client.chat.complete(mistralRequest, {
          fetchOptions: {
            headers: {
              'X-Goog-Api-Client': GENKIT_CLIENT_HEADER,
            },
          },
        });
        return fromMistralResponse(input, response);
      } else {
        // Streaming path: forward each chunk's parts to sendChunk.
        const mistralRequest = toMistralRequest(versionedModel, input);
        const stream = await client.chat.stream(mistralRequest, {
          fetchOptions: {
            headers: {
              'X-Goog-Api-Client': GENKIT_CLIENT_HEADER,
            },
          },
        });
        for await (const event of stream) {
          const parts = fromMistralCompletionChunk(event.data);
          if (parts.length > 0) {
            sendChunk({
              content: parts,
            });
          }
        }
        // Get the complete response after streaming.
        // NOTE(review): this issues the same request a second time
        // (non-streaming) to build the final response, doubling model
        // invocations per streamed request — confirm whether the streamed
        // chunks could be accumulated into the final response instead.
        const completeResponse = await client.chat.complete(mistralRequest, {
          fetchOptions: {
            headers: {
              'X-Goog-Api-Client': GENKIT_CLIENT_HEADER,
            },
          },
        });
        return fromMistralResponse(input, completeResponse);
      }
    }
  );
}
// Returns a factory that lazily creates and caches one MistralGoogleCloud
// client per region for the given project.
function createClientFactory(projectId: string) {
  const cache = new Map<string, MistralGoogleCloud>();
  return (region: string): MistralGoogleCloud => {
    if (!region) {
      throw new Error('Region is required to create Mistral client');
    }
    try {
      let client = cache.get(region);
      if (!client) {
        client = new MistralGoogleCloud({
          region,
          projectId,
        });
        cache.set(region, client);
      }
      return client;
    } catch (error) {
      // Surface client construction failures with region context.
      throw new Error(
        `Failed to create/retrieve Mistral client for region ${region}: ${error}`
      );
    }
  };
}
type MistralMessage =
  | AssistantMessage
  | ToolMessage
  | SystemMessage
  | UserMessage;

// Validates that assistant tool calls and tool responses pair up
// one-to-one by id; throws an Error on any mismatch.
function validateToolSequence(messages: MistralMessage[]) {
  // Collect every tool call emitted by assistant messages.
  const toolCalls: ToolCall[] = [];
  for (const m of messages) {
    if (m.role === 'assistant') {
      const calls = (m as AssistantMessage).toolCalls;
      if (calls) {
        toolCalls.push(...calls);
      }
    }
  }
  const toolResponses = messages.filter(
    (m) => m.role === 'tool'
  ) as ToolMessage[];
  if (toolCalls.length !== toolResponses.length) {
    throw new Error(
      `Mismatch between tool calls (${toolCalls.length}) and responses (${toolResponses.length})`
    );
  }
  // Every response must reference an existing call id.
  for (const response of toolResponses) {
    const hasMatch = toolCalls.some((call) => call.id === response.toolCallId);
    if (!hasMatch) {
      throw new Error(
        `Tool response with ID ${response.toolCallId} has no matching call`
      );
    }
  }
}
// Extracts Genkit parts (text and tool requests) from a streaming chunk.
export function fromMistralCompletionChunk(chunk: CompletionChunk): Part[] {
  const delta = chunk.choices?.[0]?.delta;
  if (!delta) return [];
  const parts: Part[] = [];
  if (typeof delta.content === 'string') {
    parts.push({ text: delta.content });
  }
  for (const toolCall of delta.toolCalls ?? []) {
    if (!toolCall.function) continue;
    // Arguments may stream as a JSON string or an already-parsed object.
    const args = toolCall.function.arguments;
    parts.push({
      toolRequest: {
        ref: toolCall.id,
        name: toolCall.function.name,
        input: typeof args === 'string' ? JSON.parse(args) : args,
      },
    });
  }
  return parts;
}