// cascading-search.ts (3.15 kB)
import {McpServer} from '@modelcontextprotocol/sdk/server/mcp.js';
import {Pinecone} from '@pinecone-database/pinecone';
import {z} from 'zod';
import {RERANK_MODEL_SCHEMA} from './common/rerank-model.js';
import {SEARCH_QUERY_SCHEMA} from './common/search-query.js';
/** Description passed to `server.tool` for the "cascading-search" tool. */
const INSTRUCTIONS =
  'Search across multiple indexes for records that are\n' +
  'similar to the query text, deduplicate and rerank the results.';
/**
 * A single search target: the name of an index and the namespace inside it
 * that the cascading search should query.
 */
const INDEX_SCHEMA = z
  .object({
    name: z
      .string()
      .describe('An index to search.'),
    namespace: z
      .string()
      .describe('A namespace to search.'),
  });
/**
 * Rerank settings for the merged search results: the rerank model, how many
 * results to keep (`topN`), which record fields to score (`rankFields`), and an
 * optional query that overrides the search text during reranking.
 *
 * The description strings deliberately keep their embedded line breaks and
 * leading spaces exactly as the tool has always advertised them.
 */
const RERANK_SCHEMA = z
  .object({
    model: RERANK_MODEL_SCHEMA,
    topN: z
      .number()
      .optional()
      .describe(
        'The number of results to return after reranking. Must be less than or\n' +
          '        equal to the value of "query.topK".',
      ),
    rankFields: z
      .array(z.string())
      .describe(
        'The fields to rerank on. This should include the field name specified\n' +
          '      in the index\'s "fieldMap". The "bge-reranker-v2-m3" and\n' +
          '      "pinecone-rerank-v0" models support only a single rerank field.\n' +
          '      "cohere-rerank-3.5" supports multiple rerank fields.',
      ),
    query: z
      .string()
      .optional()
      .describe(
        'An optional query to rerank documents against. If not specified, the\n' +
          '        same query will be used for both the initial search and the reranking.',
      ),
  })
  .describe(
    'Specifies how the results should be reranked. Use a "query" with a "topK"\n' +
      '    that returns more results than you need; then use "rerank" to select the\n' +
      '    most relevant "topN" results.',
  );
/**
 * Input schema of the "cascading-search" tool: the index/namespace targets,
 * the search query run against each of them, and the rerank settings applied
 * to the merged hits.
 */
export const SCHEMA = {
  indexes: z.array(INDEX_SCHEMA).describe(
    'The indexes to search across. Records in each index should share a common schema.',
  ),
  query: SEARCH_QUERY_SCHEMA,
  rerank: RERANK_SCHEMA,
};
/**
 * Registers the "cascading-search" tool on an MCP server.
 *
 * When invoked, the tool:
 * 1. runs `query` against every requested index/namespace concurrently,
 * 2. merges the hits, keeping only the first hit seen for each record `_id`,
 * 3. reranks the merged records with `pc.inference.rerank`, and
 * 4. returns the rerank response as pretty-printed JSON text content.
 *
 * If no index returns any hits, the rerank call is skipped and the returned
 * JSON text is an empty array.
 *
 * @param server - MCP server the tool is registered on.
 * @param pc - Pinecone client used for both searching and reranking.
 */
export function addCascadingSearchTool(server: McpServer, pc: Pinecone) {
  server.tool('cascading-search', INSTRUCTIONS, SCHEMA, async ({indexes, query, rerank}) => {
    // Search all targets in parallel; if any single search rejects, the tool call rejects.
    const initialResults = await Promise.all(
      indexes.map(async (index) => {
        const ns = pc.index(index.name).namespace(index.namespace || '');
        return ns.searchRecords({query});
      }),
    );

    // Deduplicate by record ID (first occurrence wins). A Map is used instead of
    // a plain object: with `{}`, IDs such as "constructor", "toString" or
    // "__proto__" resolve to inherited prototype members, look "already seen",
    // and their records would be silently dropped.
    const documentsById = new Map<string, Record<string, string>>();
    for (const results of initialResults) {
      for (const hit of results.result.hits) {
        if (!documentsById.has(hit._id)) {
          // NOTE(review): the rerank call takes string-valued documents; hit
          // field values are assumed to be strings here — confirm for your schema.
          documentsById.set(hit._id, hit.fields as Record<string, string>);
        }
      }
    }
    const documents = Array.from(documentsById.values());

    // Default to the search topK when rerank.topN is unset (or 0), but never
    // request more results than there are deduplicated documents: each index
    // may return fewer than topK hits, and overlap between indexes removes more.
    const topN = Math.min(rerank.topN || query.topK, documents.length);

    const rerankedResults =
      documents.length > 0
        ? await pc.inference.rerank(
            rerank.model,
            // An unset or empty rerank query falls back to the original search text.
            rerank.query || query.inputs.text,
            documents,
            {
              topN,
              rankFields: rerank.rankFields,
            },
          )
        : [];

    return {
      content: [
        {
          type: 'text',
          text: JSON.stringify(rerankedResults, null, 2),
        },
      ],
    };
  });
}