"""
MCP Server for Vertex AI Search
by ubie-oss
- src
- research_agent
"""
import argparse
import asyncio
import json
import textwrap
from typing import List
from google import genai
from google.genai import chats, types
from loguru import logger
from pydantic import BaseModel, Field
from research_agent.mcp_client import MCPClient
from research_agent.utils import to_gemini_tool
class Reference(BaseModel):
    """A reference to a document.

    Represents one source document used to ground an answer; both fields
    are populated from the search tool's JSON payload.
    """

    # Human-readable document title, shown in the rendered reference list.
    title: str = Field(..., description="The title of the document.")
    # Full raw text of the document as returned by the search backend.
    raw_text: str = Field(..., description="The raw text of the document.")
class SearchResponse(BaseModel):
    """The response from the search tool."""

    answer: str = Field(..., description="The answer to the user's question.")
    references: List[Reference] = Field(
        ...,
        description="The references to the documents that are used to answer the user's question.",
    )

    @classmethod
    def from_json_string(cls, json_string: str) -> "SearchResponse":
        """Deserialize the search response from a JSON string.

        Args:
            json_string: JSON object with ``answer`` and ``references`` keys.

        Raises:
            json.JSONDecodeError: if ``json_string`` is not valid JSON.
            pydantic.ValidationError: if the payload does not match the schema.
        """
        return cls(**json.loads(json_string))

    def __str__(self) -> str:
        # BUG FIX: the original interpolated '"\n".join(...)' directly inside an
        # f-string expression — a syntax error before Python 3.12 (PEP 701) —
        # and relied on textwrap.dedent, which cannot strip the template indent
        # once interpolated multi-line text with different leading whitespace is
        # spliced in. Build the reference list first, then format.
        references = "\n".join(
            f" - {ref.title}: {ref.raw_text}" for ref in self.references
        )
        return f"\nAnswer: {self.answer}\nReferences:\n{references}\n"
async def process_query(
    chat_client: chats.Chat,
    mcp_client: MCPClient,
    query: str,
) -> str:
    """Process the user query using Gemini and MCP tools.

    Sends the query to the Gemini chat session; text parts are collected
    as-is, while function-call parts are executed against the MCP server
    and their output spliced into the reply.

    Args:
        chat_client: Active Gemini chat session configured with MCP tools.
        mcp_client: Connected MCP client used to execute tool calls.
        query: The user's natural-language question.

    Returns:
        All collected text parts and tool results joined with newlines.

    Raises:
        RuntimeError: if Gemini returns no candidates, a tool call yields
            no content, or a response part is neither text nor a function
            call.
    """
    response = chat_client.send_message(message=[query])
    if not response.candidates:
        raise RuntimeError("No response from Gemini")
    response_text = []
    for candidate in response.candidates:
        # Candidates without content (e.g. safety-filtered) are skipped.
        if not candidate.content:
            logger.debug(f"No content in candidate {candidate}")
            continue
        for part in candidate.content.parts:
            if part.text:
                response_text.append(part.text)
            elif part.function_call:
                # Gemini requested an MCP tool invocation; run it and
                # append the result to the reply.
                tool_name = part.function_call.name
                tool_args = part.function_call.args
                logger.debug(f"Tool name: {tool_name}, tool args: {tool_args}")
                tool_call = await mcp_client.call_tool(tool_name, tool_args)
                if tool_call and tool_call.content:
                    for content in tool_call.content:
                        text = content.text
                        if not text:
                            logger.info(f"No text in tool call content {content}")
                            continue
                        try:
                            # Tool output is expected to be a JSON-encoded
                            # SearchResponse; fall back to the raw text on
                            # any parse/validation failure.
                            parsed_content = SearchResponse.from_json_string(text)
                            response_text.append(str(parsed_content))
                        except Exception as e:  # pylint: disable=broad-except
                            logger.error(
                                f"Failed to deserialize tool call content {content}: {e}"
                            )
                            response_text.append(text)
                else:
                    raise RuntimeError(f"No tool call content {tool_call}")
            else:
                # NOTE(review): parts carrying other payloads (e.g. function
                # responses) are treated as errors here — confirm intended.
                raise RuntimeError(f"Unknown part type {part}")
    return "\n".join(response_text)
async def chat(server_url: str):
    """Run the interactive chat loop against an MCP document-search server.

    Connects to the MCP server, exposes its tools to a Gemini chat session,
    then reads user queries from stdin until the user types ``bye``.

    Args:
        server_url: SSE endpoint URL of the MCP server.

    Raises:
        RuntimeError: wrapping any error raised while setting up or
            processing queries.
    """
    # Why do we use google-genai, not vertexai?
    # Because it is easier to convert MCP tools to GenAI tools in google-genai.
    genai_client = genai.Client(vertexai=True, location="us-central1")
    mcp_client = MCPClient(name="document-search")
    await mcp_client.connect_to_server(server_url=server_url)
    try:
        # Collect tools from MCP server
        mcp_tools = await mcp_client.list_tools()
        # Convert MCP tools to GenAI tools
        genai_tools = [to_gemini_tool(tool) for tool in mcp_tools.tools]
        # Create chat client
        chat_client = genai_client.chats.create(
            model="gemini-2.0-flash",
            config=types.GenerateContentConfig(
                tools=genai_tools,
                system_instruction="""
                You are a helpful assistant to search documents.
                You have to pass the query to the tool to search the documents as much natural as possible.
                """,
            ),
        )
        print("If you want to quit, please enter 'bye'")
        while True:
            # Get user query
            query = input("Enter your query: ")
            if query == "bye":
                break
            # Get response from GenAI
            response = await process_query(chat_client, mcp_client, query)
            print(response)
    # pylint: disable=broad-except
    except Exception as e:
        raise RuntimeError from e
    finally:
        # BUG FIX: the original called cleanup() only on the exception path,
        # leaking the MCP connection on a normal 'bye' exit. `finally`
        # guarantees cleanup on both success and failure.
        await mcp_client.cleanup()
if __name__ == "__main__":
    # CLI entry point: resolve the MCP server address and start the chat loop.
    arg_parser = argparse.ArgumentParser()
    # trunk-ignore(bandit/B104)
    arg_parser.add_argument("--host", type=str, default="0.0.0.0")
    arg_parser.add_argument("--port", type=int, default=8080)
    cli_args = arg_parser.parse_args()
    # Build the SSE endpoint URL and run the async chat session to completion.
    sse_url = f"http://{cli_args.host}:{cli_args.port}/sse"
    asyncio.run(chat(sse_url))