Skip to main content
Glama
ImDPS
by ImDPS
main.py (17.8 kB)
"""
Enhanced Gemini MCP Client with Real LLM-Based Tool Selection.

Implements intelligent tool selection and parameter extraction using Gemini LLM.
"""

import asyncio
import json
import logging
import os
import sys
from collections.abc import Mapping, Sequence
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple

import google.generativeai as genai
from dotenv import load_dotenv
from google.generativeai import types
from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client

# Make the sibling ``utils`` package importable: the parent of this file's
# directory is appended to the module search path before importing from it.
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
from utils.gemini_client import get_rate_limited_client  # noqa: E402

# Load environment variables.
# NOTE(review): "../.env" is resolved against the current working directory,
# not against this file's location -- confirm the intended launch directory.
load_dotenv("../.env")

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Validate the API key at import time so misconfiguration fails fast.
GEMINI_API_KEY = os.getenv('GEMINI_API_KEY')
if not GEMINI_API_KEY:
    raise ValueError("GEMINI_API_KEY environment variable not set. Please check your .env file.")

# JSON-Schema keywords removed from MCP tool schemas before they are sent to Gemini.
_UNSUPPORTED_SCHEMA_KEYS = frozenset({"additionalProperties", "$schema"})

# Schema keywords whose values map *names* (not keywords) to sub-schemas.
_NAMED_SUBSCHEMA_KEYS = frozenset({"properties", "$defs", "definitions"})


def _strip_unsupported_schema_keys(schema: Any) -> Any:
    """Return a copy of a JSON schema with unsupported keywords removed.

    Keywords listed in ``_UNSUPPORTED_SCHEMA_KEYS`` are dropped at every
    nesting level, not just the top one. Entries under ``properties``,
    ``$defs`` and ``definitions`` are keyed by property/definition *names*,
    so those names are always kept and only their sub-schemas are cleaned.

    Args:
        schema: A JSON-schema fragment (mapping, list, or scalar).

    Returns:
        A cleaned copy; scalars are returned unchanged.
    """
    if isinstance(schema, list):
        return [_strip_unsupported_schema_keys(item) for item in schema]
    if not isinstance(schema, Mapping):
        return schema

    cleaned: Dict[str, Any] = {}
    for key, value in schema.items():
        if key in _UNSUPPORTED_SCHEMA_KEYS:
            continue
        if key in _NAMED_SUBSCHEMA_KEYS and isinstance(value, Mapping):
            # Keys here are user-chosen names; only clean the schemas beneath them.
            cleaned[key] = {
                name: _strip_unsupported_schema_keys(sub_schema)
                for name, sub_schema in value.items()
            }
        else:
            cleaned[key] = _strip_unsupported_schema_keys(value)
    return cleaned


def _to_plain_python(value: Any) -> Any:
    """Recursively convert mapping/sequence containers into plain dicts and lists.

    A shallow ``dict(...)`` only converts the outermost container of the
    LLM's function-call arguments; any nested mapping or sequence would be
    passed through in its original container type. This normalizes the whole
    tree so it can be forwarded to ``session.call_tool``.

    Args:
        value: Any value, typically the ``args`` of a Gemini function call.

    Returns:
        The same data with every Mapping turned into ``dict`` and every
        non-string Sequence turned into ``list``; other values unchanged.
    """
    if isinstance(value, Mapping):
        return {key: _to_plain_python(item) for key, item in value.items()}
    if isinstance(value, Sequence) and not isinstance(value, (str, bytes, bytearray)):
        return [_to_plain_python(item) for item in value]
    return value


class EnhancedGeminiMCPClient:
    """Enhanced Gemini MCP Client with real LLM-based tool selection and parameter extraction."""

    def __init__(self, server_script_path: str = "src/server/main.py",
                 server_env: Optional[Dict[str, str]] = None):
        """Initialize the client with server configuration.

        Args:
            server_script_path: Path to the MCP server script (launched with
                the ``python`` command found on PATH).
            server_env: Environment variables to pass to the server. ``None``
                is passed through unchanged so ``stdio_client`` applies its
                own default environment; converting it to ``{}`` could, in
                some MCP SDK versions, start the server with an empty
                environment.
        """
        self.server_params = StdioServerParameters(
            command="python",
            args=[server_script_path],
            env=server_env,
        )
        # Names of tools successfully formatted for Gemini (refreshed per listing).
        self.available_tools: List[str] = []
        # Tool name -> description, used to build the selection prompt.
        self.tool_descriptions: Dict[str, str] = {}

    def format_tools_for_gemini(self, mcp_tools) -> List[types.Tool]:
        """Format MCP tools for the Gemini API, refreshing the tool registry.

        Each tool becomes its own ``types.Tool`` with a single function
        declaration. Unsupported schema keywords are stripped recursively.
        Tools that fail to convert are logged and skipped.

        Args:
            mcp_tools: The result from ``session.list_tools()``.

        Returns:
            List of tools formatted for Gemini. Side effect: resets and
            repopulates ``self.available_tools`` and ``self.tool_descriptions``.
        """
        tools: List[types.Tool] = []
        self.available_tools = []
        self.tool_descriptions = {}

        for tool in mcp_tools.tools:
            try:
                input_schema = getattr(tool, 'inputSchema', None)
                parameters = (
                    _strip_unsupported_schema_keys(input_schema)
                    if input_schema
                    else {"type": "object"}
                )
                description = tool.description or ""

                gemini_tool = types.Tool(
                    function_declarations=[{
                        "name": tool.name,
                        "description": description,
                        "parameters": parameters,
                    }]
                )
                tools.append(gemini_tool)
                self.available_tools.append(tool.name)
                self.tool_descriptions[tool.name] = description
                logger.info("Formatted tool: %s - %s", tool.name, tool.description)
            except Exception as e:
                # Best effort: one malformed tool schema must not hide the others.
                logger.warning("Failed to format tool %s: %s", tool.name, e)
                continue

        return tools

    async def process_query_with_llm_selection(self, session: ClientSession, query: str,
                                               model: str = "gemini-2.0-flash-lite") -> str:
        """Process a query using LLM-based tool selection and parameter extraction.

        Args:
            session: Active MCP session.
            query: The user query.
            model: Gemini model to use.
                NOTE(review): this value is currently not forwarded to the
                rate-limited client -- confirm how that client picks its model.

        Returns:
            The formatted tool result if Gemini chose a known tool, otherwise
            Gemini's text reply, or an ``"Error..."`` string on failure (this
            method does not raise).
        """
        try:
            mcp_tools = await session.list_tools()
            tools = self.format_tools_for_gemini(mcp_tools)
            if not tools:
                return "Error: No tools available for processing queries."

            logger.info("Processing query with LLM-based tool selection: %s", query)
            logger.info("Available tools: %s", self.available_tools)

            selection_prompt = self._create_tool_selection_prompt(query)

            # Defensive only: the module-level check already raises when unset.
            if not GEMINI_API_KEY:
                return "Error: GEMINI_API_KEY not set"
            rate_limited_client = get_rate_limited_client(GEMINI_API_KEY)

            response = await rate_limited_client.generate_content(
                contents=selection_prompt,
                generation_config=types.GenerationConfig(
                    temperature=0.1,  # Low temperature for consistent tool selection
                ),
                tools=tools,
            )

            if not response.candidates:
                return "Error: No response candidates from Gemini"

            candidate = response.candidates[0]
            parts = candidate.content.parts or []
            self._log_response_parts(parts)

            # The first part carrying a non-empty function call wins.
            function_call = next(
                (part.function_call for part in parts
                 if hasattr(part, 'function_call') and part.function_call),
                None,
            )
            if function_call:
                return await self._execute_function_call(session, function_call)

            # No function call: return the LLM's reasoning (non-None text parts only).
            text_parts = [
                part.text for part in parts
                if hasattr(part, 'text') and part.text is not None
            ]
            if text_parts:
                return '\n'.join(text_parts)

            if hasattr(response, 'text') and response.text:
                return response.text

            return "No meaningful response generated"

        except Exception as e:
            logger.error("Error processing query with LLM selection: %s", e)
            return f"Error processing query: {str(e)}"

    @staticmethod
    def _log_response_parts(parts) -> None:
        """Log the structure of a Gemini candidate's content parts for debugging."""
        logger.info("Checking for function call in response...")
        logger.info("Response parts: %d", len(parts))
        for i, part in enumerate(parts):
            logger.info("Part %d: %s", i, type(part))
            if hasattr(part, 'function_call'):
                logger.info("Part %d has function_call: %s", i, part.function_call)
            if hasattr(part, 'text'):
                logger.info("Part %d has text: %s", i, part.text)

    async def _execute_function_call(self, session: ClientSession, function_call) -> str:
        """Run the MCP tool named by an LLM function call and format its result.

        Args:
            session: Active MCP session.
            function_call: Gemini function call exposing ``name`` and ``args``.

        Returns:
            The formatted tool result, or an ``"Error..."`` string when the
            tool is unknown or its execution raises.
        """
        name = function_call.name
        logger.info("Found function call: %s", name)
        logger.info("Function call args: %s", function_call.args)

        if not (name and name in self.available_tools):
            logger.warning("LLM selected unknown tool: %s", name)
            return f"Error: Unknown tool '{name}' selected by LLM"

        # Deep-convert so nested argument containers become plain dicts/lists.
        arguments = _to_plain_python(function_call.args)
        logger.info("LLM selected tool: %s", name)
        logger.info("LLM extracted parameters: %s", arguments)

        try:
            result = await session.call_tool(name, arguments=arguments)
        except Exception as e:
            logger.error("Error executing tool %s: %s", name, e)
            return f"Error executing tool {name}: {str(e)}"

        return self._format_tool_result(result)

    def _create_tool_selection_prompt(self, query: str) -> str:
        """Create the prompt that asks Gemini to pick a tool and its parameters.

        Args:
            query: The raw user query, embedded verbatim in the prompt.

        Returns:
            The prompt text listing ``self.tool_descriptions`` plus fixed
            guidelines/examples for the knowledge-base, calculator and weather
            tools.
        """
        tool_info = "\n".join(
            f"- {name}: {desc}" for name, desc in self.tool_descriptions.items()
        )

        prompt = f"""You are an intelligent assistant that can use various tools to help users.

Available tools:
{tool_info}

User Query: "{query}"

Instructions:
1. Analyze the user's query to understand what they need
2. Select the most appropriate tool from the available options
3. Extract the relevant parameters from the user's query
4. Call the selected tool with the extracted parameters

Guidelines for tool selection:
- Use 'get_knowledge_base' for questions about company policies, benefits, procedures, HR information, employee guidelines, vacation, sick leave, dress code, working hours, etc.
- Use 'calculate' for mathematical calculations, arithmetic, formulas, equations, or any numerical computations
- Use 'get_weather' for weather-related queries, temperature, climate, or location-specific weather information

Examples:
- "What is the vacation policy?" → get_knowledge_base with query="vacation policy"
- "What is 15 + 27?" → calculate with expression="15 + 27"
- "What's the weather in New York?" → get_weather with location="New York"

Please select the appropriate tool and extract the necessary parameters from the user's query."""

        return prompt

    def _format_tool_result(self, result) -> str:
        """Format a tool execution result for display.

        Every content item is rendered, not just the first one. Items whose
        text parses as JSON are pretty-printed with a 2-space indent, and the
        rendered items are joined with newlines. A single-item result renders
        exactly as that one item.

        Args:
            result: The result from MCP tool execution.

        Returns:
            Formatted result string; ``str(result)`` when there is no
            content, or an ``"Error formatting result: ..."`` string when
            formatting raises.
        """
        try:
            content = getattr(result, 'content', None)
            if not content:
                return str(result)
            return "\n".join(self._format_content_item(item) for item in content)
        except Exception as e:
            logger.error("Error formatting result: %s", e)
            return f"Error formatting result: {str(e)}"

    @staticmethod
    def _format_content_item(item) -> str:
        """Render one content item, pretty-printing it when its text is JSON."""
        content_text = item.text if hasattr(item, 'text') else str(item)
        try:
            return json.dumps(json.loads(content_text), indent=2)
        except json.JSONDecodeError:
            # Return as plain text if not valid JSON
            return content_text

    async def run_interactive_session(self):
        """Run an interactive REPL that routes each query through LLM tool selection.

        Commands: ``quit``/``exit``/``q`` end the session (as does EOF on
        stdin), ``status`` prints rate-limit counters, ``tools`` lists the
        available tools. Any other non-empty input is processed as a query.
        """
        print("=== Enhanced Gemini MCP Interactive Client ===")
        print("Type 'quit' or 'exit' to end the session")
        print("Type 'status' to check rate limiting status")
        print("Type 'tools' to see available tools")
        print("Connecting to MCP server...")

        try:
            async with stdio_client(self.server_params) as (read, write):
                async with ClientSession(read, write) as session:
                    await session.initialize()

                    # Populates self.available_tools / self.tool_descriptions.
                    mcp_tools = await session.list_tools()
                    self.format_tools_for_gemini(mcp_tools)

                    print(f"\nConnected! Available tools ({len(mcp_tools.tools)}):")
                    for tool in mcp_tools.tools:
                        print(f"  - {tool.name}: {tool.description}")

                    print("\nReady for queries! The LLM will intelligently select tools and extract parameters.")

                    while True:
                        try:
                            # NOTE(review): input() blocks the event loop while waiting.
                            query = input("\n> ").strip()
                            command = query.lower()

                            if command in ['quit', 'exit', 'q']:
                                break

                            if command == 'status':
                                self._print_rate_limit_status()
                                continue

                            if command == 'tools':
                                print("\n" + "=" * 50)
                                print("Available Tools:")
                                for name, desc in self.tool_descriptions.items():
                                    print(f"  - {name}: {desc}")
                                print("=" * 50)
                                continue

                            if not query:
                                continue

                            print("\n" + "=" * 50)
                            print(f"Processing: {query}")
                            print("-" * 50)

                            response = await self.process_query_with_llm_selection(session, query)
                            print("Response:")
                            print(response)
                            print("=" * 50)

                        except KeyboardInterrupt:
                            print("\nUse 'quit' to exit gracefully.")
                            continue
                        except EOFError:
                            # stdin closed (Ctrl-D / piped input exhausted): input()
                            # would raise again forever, so end the session.
                            print()
                            break
                        except Exception as e:
                            print(f"Error: {e}")
                            continue

        except Exception as e:
            print(f"Failed to connect to MCP server: {e}")
            print("Please ensure the server script exists and dependencies are installed.")

    @staticmethod
    def _print_rate_limit_status() -> None:
        """Print the rate-limiting counters reported by the rate-limited client."""
        rate_limited_client = get_rate_limited_client(GEMINI_API_KEY)
        status = rate_limited_client.get_rate_limit_status()
        print("\n" + "=" * 50)
        print("Rate Limiting Status:")
        print(f"Current RPM: {status['current_rpm']}/{status['safe_rpm']}")
        print(f"Current RPD: {status['current_rpd']}/{status['safe_rpd']}")
        print(f"Current TPM: {status['current_tpm']}/{status['safe_tpm']}")
        print(f"Available RPM: {status['rpm_available']}")
        print(f"Available RPD: {status['rpd_available']}")
        print(f"Available TPM: {status['tpm_available']}")
        print("=" * 50)

    async def run_batch_queries(self, queries: List[str]):
        """Run a batch of predefined queries with LLM-based tool selection.

        Args:
            queries: List of queries to process, in order. A failing query is
                reported and skipped; a 1-second pause separates queries.
        """
        print("=== Enhanced Gemini MCP Batch Client ===")
        print("Connecting to MCP server...")

        try:
            async with stdio_client(self.server_params) as (read, write):
                async with ClientSession(read, write) as session:
                    await session.initialize()

                    # Populates self.available_tools / self.tool_descriptions.
                    mcp_tools = await session.list_tools()
                    self.format_tools_for_gemini(mcp_tools)

                    print(f"\nConnected! Available tools: {[tool.name for tool in mcp_tools.tools]}")
                    print("Processing queries with LLM-based tool selection...")

                    total = len(queries)
                    for i, query in enumerate(queries, 1):
                        print(f"\n{'=' * 60}")
                        print(f"Query {i}/{total}: {query}")
                        print("-" * 60)

                        try:
                            response = await self.process_query_with_llm_selection(session, query)
                            print("Response:")
                            print(response)

                            # Small delay between queries
                            if i < total:
                                await asyncio.sleep(1)

                        except Exception as e:
                            print(f"Error processing query {i}: {e}")
                            continue

        except Exception as e:
            print(f"Failed to connect to MCP server: {e}")


async def main():
    """Entry point: ``--interactive`` starts the REPL, otherwise run the demo batch."""
    client_instance = EnhancedGeminiMCPClient("src/server/main.py")

    # Test queries that demonstrate different tool selection scenarios
    test_queries = [
        "What is the company's vacation policy?",
        "Calculate 15 + 27 * 3",
        "What's the weather in New York?",
        # "How many sick days do employees get?",
        # "What is the absolute value of -42?",
        # "What benefits does the company offer?",
        # "What's the temperature in Tokyo?",
        # "What is 100 divided by 5?"
    ]

    try:
        if len(sys.argv) > 1 and sys.argv[1] == "--interactive":
            await client_instance.run_interactive_session()
        else:
            await client_instance.run_batch_queries(test_queries)
    except KeyboardInterrupt:
        print("\n\nShutting down gracefully...")
    except Exception as e:
        print(f"Unexpected error: {e}")


if __name__ == "__main__":
    asyncio.run(main())

Latest Blog Posts

MCP directory API

We provide all the information about MCP servers via our MCP API.

curl -X GET 'https://glama.ai/api/mcp/v1/servers/ImDPS/MCP'

If you have feedback or need assistance with the MCP directory API, please join our Discord server.