"""
Enhanced Gemini MCP Client with Real LLM-Based Tool Selection.
Implements intelligent tool selection and parameter extraction using Gemini LLM.
"""
import asyncio
import json
import logging
import os
import sys
from typing import Dict, List, Optional
from google.generativeai import types
from dotenv import load_dotenv
from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client
# Make the project root importable so utils.gemini_client can be resolved
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
from utils.gemini_client import get_rate_limited_client
# Load environment variables
load_dotenv(os.path.join(os.path.dirname(__file__), "..", ".env"))
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
# Validate API key
GEMINI_API_KEY = os.getenv('GEMINI_API_KEY')
if not GEMINI_API_KEY:
raise ValueError("GEMINI_API_KEY environment variable not set. Please check your .env file.")
class EnhancedGeminiMCPClient:
"""Enhanced Gemini MCP Client with real LLM-based tool selection and parameter extraction."""
def __init__(self, server_script_path: str = "src/server/main.py", server_env: Optional[Dict[str, str]] = None):
"""Initialize the client with server configuration.
Args:
server_script_path: Path to the MCP server script
server_env: Environment variables to pass to the server
"""
self.server_params = StdioServerParameters(
command="python",
args=[server_script_path],
env=server_env or {}
)
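        # Tool metadata caches, refreshed on every call to format_tools_for_gemini()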
self.available_tools = []
self.tool_descriptions = {}
def format_tools_for_gemini(self, mcp_tools) -> List[types.Tool]:
"""Format MCP tools for Gemini API with enhanced schema handling.
Args:
mcp_tools: The result from session.list_tools()
Returns:
List of tools formatted for Gemini
"""
tools = []
self.available_tools = []
self.tool_descriptions = {}
for tool in mcp_tools.tools:
try:
# Clean up the input schema by removing unsupported properties
parameters = {
k: v for k, v in tool.inputSchema.items()
if k not in ["additionalProperties", "$schema"]
} if hasattr(tool, 'inputSchema') and tool.inputSchema else {"type": "object"}
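                # Illustrative shape of the resulting declaration (not taken from
                # an actual server response):
                #   {"name": "calculate",
                #    "description": "Evaluate a math expression",
                #    "parameters": {"type": "object",
                #                   "properties": {"expression": {"type": "string"}}}}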
# Create tool in Gemini format
gemini_tool = types.Tool(
function_declarations=[{
"name": tool.name,
"description": tool.description or "",
"parameters": parameters
}]
)
tools.append(gemini_tool)
self.available_tools.append(tool.name)
self.tool_descriptions[tool.name] = tool.description or ""
logger.info(f"Formatted tool: {tool.name} - {tool.description}")
except Exception as e:
logger.warning(f"Failed to format tool {tool.name}: {e}")
continue
return tools
async def process_query_with_llm_selection(self, session: ClientSession, query: str, model: str = "gemini-2.0-flash-lite") -> str:
"""Process a query using real LLM-based tool selection and parameter extraction.
Args:
session: Active MCP session
query: The user query
            model: Gemini model to use (accepted for API compatibility but not
                currently passed through to the rate-limited client)
Returns:
The response from tool execution or LLM reasoning
"""
try:
# Get available tools
mcp_tools = await session.list_tools()
tools = self.format_tools_for_gemini(mcp_tools)
if not tools:
return "Error: No tools available for processing queries."
logger.info(f"Processing query with LLM-based tool selection: {query}")
logger.info(f"Available tools: {self.available_tools}")
# Create comprehensive prompt for LLM tool selection
selection_prompt = self._create_tool_selection_prompt(query)
            # Get the rate-limited client (the API key was validated at import time)
            rate_limited_client = get_rate_limited_client(GEMINI_API_KEY)
# Generate response from Gemini with rate limiting
response = await rate_limited_client.generate_content(
contents=selection_prompt,
generation_config=types.GenerationConfig(
temperature=0.1, # Low temperature for consistent tool selection
),
tools=tools,
)
# Process the response
if not response.candidates:
return "Error: No response candidates from Gemini"
candidate = response.candidates[0]
# Check for function call (tool selection)
logger.info(f"Checking for function call in response...")
logger.info(f"Response parts: {len(candidate.content.parts) if candidate.content.parts else 0}")
if candidate.content.parts:
for i, part in enumerate(candidate.content.parts):
logger.info(f"Part {i}: {type(part)}")
if hasattr(part, 'function_call'):
logger.info(f"Part {i} has function_call: {part.function_call}")
if hasattr(part, 'text'):
logger.info(f"Part {i} has text: {part.text}")
# Look for a function call in any part
function_call = None
for part in candidate.content.parts or []:
if hasattr(part, 'function_call') and part.function_call:
function_call = part.function_call
break
if function_call:
logger.info(f"Found function call: {function_call.name}")
logger.info(f"Function call args: {function_call.args}")
if function_call.name and function_call.name in self.available_tools:
logger.info(f"LLM selected tool: {function_call.name}")
logger.info(f"LLM extracted parameters: {dict(function_call.args)}")
try:
# Call the MCP tool with LLM-extracted parameters
result = await session.call_tool(
function_call.name,
arguments=dict(function_call.args)
)
# Format and return the result
return self._format_tool_result(result)
except Exception as e:
logger.error(f"Error executing tool {function_call.name}: {e}")
return f"Error executing tool {function_call.name}: {str(e)}"
else:
logger.warning(f"LLM selected unknown tool: {function_call.name}")
return f"Error: Unknown tool '{function_call.name}' selected by LLM"
# If no function call, return LLM's reasoning (join only non-None text parts)
if candidate.content.parts:
text_parts = [part.text for part in candidate.content.parts if hasattr(part, 'text') and part.text is not None]
if text_parts:
return '\n'.join(text_parts)
# Fallback to response.text if available
if hasattr(response, 'text') and response.text:
return response.text
return "No meaningful response generated"
except Exception as e:
logger.error(f"Error processing query with LLM selection: {e}")
return f"Error processing query: {str(e)}"
def _create_tool_selection_prompt(self, query: str) -> str:
"""Create a comprehensive prompt for LLM-based tool selection."""
tool_info = "\n".join([
f"- {name}: {desc}"
for name, desc in self.tool_descriptions.items()
])
prompt = f"""You are an intelligent assistant that can use various tools to help users.
Available tools:
{tool_info}
User Query: "{query}"
Instructions:
1. Analyze the user's query to understand what they need
2. Select the most appropriate tool from the available options
3. Extract the relevant parameters from the user's query
4. Call the selected tool with the extracted parameters
Guidelines for tool selection:
- Use 'get_knowledge_base' for questions about company policies, benefits, procedures, HR information, employee guidelines, vacation, sick leave, dress code, working hours, etc.
- Use 'calculate' for mathematical calculations, arithmetic, formulas, equations, or any numerical computations
- Use 'get_weather' for weather-related queries, temperature, climate, or location-specific weather information
Examples:
- "What is the vacation policy?" → get_knowledge_base with query="vacation policy"
- "What is 15 + 27?" → calculate with expression="15 + 27"
- "What's the weather in New York?" → get_weather with location="New York"
Please select the appropriate tool and extract the necessary parameters from the user's query."""
return prompt
def _format_tool_result(self, result) -> str:
"""Format tool execution result for display.
Args:
result: The result from MCP tool execution
Returns:
Formatted result string
"""
try:
if hasattr(result, 'content') and result.content:
content_text = result.content[0].text if hasattr(result.content[0], 'text') else str(result.content[0])
# Try to parse and format as JSON
try:
parsed_json = json.loads(content_text)
return json.dumps(parsed_json, indent=2)
except json.JSONDecodeError:
# Return as plain text if not valid JSON
return content_text
return str(result)
except Exception as e:
logger.error(f"Error formatting result: {e}")
return f"Error formatting result: {str(e)}"
async def run_interactive_session(self):
"""Run an interactive session with real LLM-based tool selection."""
print("=== Enhanced Gemini MCP Interactive Client ===")
print("Type 'quit' or 'exit' to end the session")
print("Type 'status' to check rate limiting status")
print("Type 'tools' to see available tools")
print("Connecting to MCP server...")
try:
async with stdio_client(self.server_params) as (read, write):
async with ClientSession(read, write) as session:
await session.initialize()
                    # List available tools (this call also populates the tool metadata caches)
                    mcp_tools = await session.list_tools()
                    self.format_tools_for_gemini(mcp_tools)
print(f"\nConnected! Available tools ({len(mcp_tools.tools)}):")
for tool in mcp_tools.tools:
print(f" - {tool.name}: {tool.description}")
print("\nReady for queries! The LLM will intelligently select tools and extract parameters.")
while True:
try:
query = input("\n> ").strip()
if query.lower() in ['quit', 'exit', 'q']:
break
if query.lower() == 'status':
rate_limited_client = get_rate_limited_client(GEMINI_API_KEY)
status = rate_limited_client.get_rate_limit_status()
print("\n" + "="*50)
print("Rate Limiting Status:")
print(f"Current RPM: {status['current_rpm']}/{status['safe_rpm']}")
print(f"Current RPD: {status['current_rpd']}/{status['safe_rpd']}")
print(f"Current TPM: {status['current_tpm']}/{status['safe_tpm']}")
print(f"Available RPM: {status['rpm_available']}")
print(f"Available RPD: {status['rpd_available']}")
print(f"Available TPM: {status['tpm_available']}")
print("="*50)
continue
if query.lower() == 'tools':
print("\n" + "="*50)
print("Available Tools:")
for name, desc in self.tool_descriptions.items():
print(f" - {name}: {desc}")
print("="*50)
continue
if not query:
continue
print("\n" + "="*50)
print(f"Processing: {query}")
print("-" * 50)
response = await self.process_query_with_llm_selection(session, query)
print("Response:")
print(response)
print("="*50)
except KeyboardInterrupt:
print("\nUse 'quit' to exit gracefully.")
continue
except Exception as e:
print(f"Error: {e}")
continue
except Exception as e:
print(f"Failed to connect to MCP server: {e}")
print("Please ensure the server script exists and dependencies are installed.")
async def run_batch_queries(self, queries: List[str]):
"""Run a batch of predefined queries with LLM-based tool selection.
Args:
queries: List of queries to process
"""
print("=== Enhanced Gemini MCP Batch Client ===")
print("Connecting to MCP server...")
try:
async with stdio_client(self.server_params) as (read, write):
async with ClientSession(read, write) as session:
await session.initialize()
                    # List available tools (this call also populates the tool metadata caches)
                    mcp_tools = await session.list_tools()
                    self.format_tools_for_gemini(mcp_tools)
print(f"\nConnected! Available tools: {[tool.name for tool in mcp_tools.tools]}")
print("Processing queries with LLM-based tool selection...")
# Process each query
for i, query in enumerate(queries, 1):
print(f"\n{'='*60}")
print(f"Query {i}/{len(queries)}: {query}")
print("-" * 60)
try:
response = await self.process_query_with_llm_selection(session, query)
print("Response:")
print(response)
# Small delay between queries
if i < len(queries):
await asyncio.sleep(1)
except Exception as e:
print(f"Error processing query {i}: {e}")
continue
except Exception as e:
print(f"Failed to connect to MCP server: {e}")
async def main():
    """Main entry point with multiple operation modes."""
# Initialize client
client_instance = EnhancedGeminiMCPClient("src/server/main.py")
# Test queries that demonstrate different tool selection scenarios
test_queries = [
"What is the company's vacation policy?",
"Calculate 15 + 27 * 3",
"What's the weather in New York?",
# "How many sick days do employees get?",
# "What is the absolute value of -42?",
# "What benefits does the company offer?",
# "What's the temperature in Tokyo?",
# "What is 100 divided by 5?"
]
try:
if len(sys.argv) > 1 and sys.argv[1] == "--interactive":
# Interactive mode
await client_instance.run_interactive_session()
else:
# Batch mode with test queries
await client_instance.run_batch_queries(test_queries)
except KeyboardInterrupt:
print("\n\nShutting down gracefully...")
except Exception as e:
print(f"Unexpected error: {e}")
if __name__ == "__main__":
asyncio.run(main())