// models.ts (5.41 kB)
import { Tool, CallToolResult } from '@modelcontextprotocol/sdk/types.js';
import { HuggingFaceClient } from '../client.js';
import { ModelSearchArgs, ModelInfoArgs } from '../types.js';
/**
 * MCP tool descriptor for `hf_list_models`.
 *
 * Declares the tool name, its human-readable description, and the JSON schema
 * of the accepted filters. No input property is required.
 */
export const listModelsToolDefinition: Tool = {
  name: "hf_list_models",
  description: [
    "Get information from all models in the Hub. Supports filtering by search terms, authors, tags, and more.",
    "Returns paginated results with model metadata including downloads, likes, and tags.",
  ].join(" "),
  inputSchema: {
    type: "object",
    properties: {
      // Free-text and categorical filters.
      search: { type: "string", description: "Filter based on substrings for repos and their usernames (e.g., 'resnet', 'microsoft')" },
      author: { type: "string", description: "Filter models by author or organization (e.g., 'huggingface', 'microsoft')" },
      filter: { type: "string", description: "Filter based on tags (e.g., 'text-classification', 'spacy')" },
      // Ordering and result-size controls.
      sort: { type: "string", description: "Property to use when sorting (e.g., 'downloads', 'author')" },
      direction: { type: "string", description: "Sort direction: '-1' for descending, anything else for ascending" },
      limit: { type: "number", description: "Limit the number of models fetched" },
      // Switches that enlarge the returned payload.
      full: { type: "boolean", description: "Whether to fetch most model data including all tags and files" },
      config: { type: "boolean", description: "Whether to also fetch the repo config" },
    },
    required: [],
  },
};
/**
 * MCP tool descriptor for `hf_get_model_info`.
 *
 * The schema requires `repo_id` and additionally accepts an optional `revision`.
 */
export const getModelInfoToolDefinition: Tool = {
  name: "hf_get_model_info",
  description: "Get detailed information for a specific model including metadata, files, configuration, and more.",
  inputSchema: {
    type: "object",
    properties: {
      repo_id: { type: "string", description: "Model repository ID (e.g., 'microsoft/DialoGPT-medium')" },
      revision: { type: "string", description: "Optional git revision (branch, tag, or commit hash)" },
    },
    required: ["repo_id"],
  },
};
/**
 * MCP tool descriptor for `hf_get_model_tags`.
 *
 * The tool takes no inputs: the schema is an object with no properties.
 */
export const getModelTagsToolDefinition: Tool = {
  name: "hf_get_model_tags",
  description:
    "Gets all available model tags hosted in the Hub, organized by type (e.g., task types, libraries, languages).",
  inputSchema: { type: "object", properties: {}, required: [] },
};
/**
 * Runtime guard for the arguments of `hf_list_models`.
 *
 * Accepts any non-null, non-array object; individual fields are not inspected.
 * Arrays are rejected explicitly because `typeof [] === "object"`, yet a
 * positional list can never be a bag of named search arguments.
 *
 * @param args - Raw tool arguments as received from the caller.
 * @returns `true` when `args` is a non-null object that is not an array.
 */
function isModelSearchArgs(args: unknown): args is ModelSearchArgs {
  return typeof args === "object" && args !== null && !Array.isArray(args);
}
/**
 * Runtime guard for the arguments of `hf_get_model_info`.
 *
 * Requires a string `repo_id`. `revision` is optional, but when supplied it
 * must also be a string, so that a value narrowed by this predicate never
 * carries a non-string revision. Note that `revision: null` is rejected.
 *
 * @param args - Raw tool arguments as received from the caller.
 * @returns `true` when `args` is an object with a string `repo_id` and an
 *          absent (or `undefined`) or string `revision`.
 */
function isModelInfoArgs(args: unknown): args is ModelInfoArgs {
  if (typeof args !== "object" || args === null) {
    return false;
  }
  const candidate = args as { repo_id?: unknown; revision?: unknown };
  if (typeof candidate.repo_id !== "string") {
    return false;
  }
  return candidate.revision === undefined || typeof candidate.revision === "string";
}
/**
 * Handles an `hf_list_models` tool call.
 *
 * Every input of this tool is optional (its schema declares `required: []`),
 * so a call that omits the arguments object is treated as an empty filter set
 * instead of being rejected as invalid.
 *
 * @param client - Client used to query the Hub; the value resolved by its
 *                 `getModels` call becomes the text payload of the result.
 * @param args - Raw tool arguments; `undefined` or `null` means "no filters".
 * @returns A text result with `isError: false` on success. Any failure
 *          (invalid arguments or a client error) is reported as an
 *          `Error: …` text result with `isError: true` rather than thrown.
 */
export async function handleListModels(client: HuggingFaceClient, args: unknown): Promise<CallToolResult> {
  try {
    // A missing arguments object is valid here because no field is required.
    const searchArgs: unknown = args ?? {};
    if (!isModelSearchArgs(searchArgs)) {
      throw new Error("Invalid arguments for hf_list_models");
    }
    const results = await client.getModels(searchArgs as Record<string, any>);
    return {
      content: [{ type: "text", text: results }],
      isError: false,
    };
  } catch (error) {
    return {
      content: [
        {
          type: "text",
          text: `Error: ${error instanceof Error ? error.message : String(error)}`,
        },
      ],
      isError: true,
    };
  }
}
/**
 * Handles an `hf_get_model_info` tool call.
 *
 * @param client - Client whose `getModelInfo(repo_id, revision)` result is
 *                 returned as the text payload.
 * @param args - Raw tool arguments; must satisfy `isModelInfoArgs`.
 * @returns A text result with `isError: false` on success. Invalid arguments
 *          and client failures are both reported as an `Error: …` text result
 *          with `isError: true` rather than thrown.
 */
export async function handleGetModelInfo(client: HuggingFaceClient, args: unknown): Promise<CallToolResult> {
  try {
    if (isModelInfoArgs(args)) {
      const modelDetails = await client.getModelInfo(args.repo_id, args.revision);
      return { content: [{ type: "text", text: modelDetails }], isError: false };
    }
    // Thrown inside the try so that it is reported through the same error path.
    throw new Error("Invalid arguments for hf_get_model_info");
  } catch (failure) {
    const reason = failure instanceof Error ? failure.message : String(failure);
    return { content: [{ type: "text", text: `Error: ${reason}` }], isError: true };
  }
}
/**
 * Handles an `hf_get_model_tags` tool call.
 *
 * @param client - Client whose `getModelTags()` result is returned as the
 *                 text payload.
 * @param args - Accepted for signature parity with the other handlers; not read.
 * @returns A text result with `isError: false` on success, or an `Error: …`
 *          text result with `isError: true` when the client call fails.
 */
export async function handleGetModelTags(client: HuggingFaceClient, args: unknown): Promise<CallToolResult> {
  try {
    const tagListing = await client.getModelTags();
    return { content: [{ type: "text", text: tagListing }], isError: false };
  } catch (failure) {
    let reason: string;
    if (failure instanceof Error) {
      reason = failure.message;
    } else {
      reason = String(failure);
    }
    return { content: [{ type: "text", text: `Error: ${reason}` }], isError: true };
  }
}