"""
Enhanced MCP Stock Client - AI-Powered Stock Query Interface
This module implements an enhanced MCP client that provides an intelligent interface
for stock market queries. It uses Google's Gemini AI to interpret natural language
queries and automatically selects appropriate tools from the MCP stock server.
Key Features:
- AI-powered query understanding via Google Gemini
- Robust error handling and retry mechanisms
- Configuration management via environment variables and config files
- Comprehensive logging for debugging and monitoring
- Input validation and sanitization
- Graceful fallback handling
Dependencies:
- mcp: Model Context Protocol client library
- google.genai: Google Gemini AI client
- python-dotenv: Environment variable management
"""
import asyncio
import os
import json
import logging
from typing import Dict, Any, Optional, List
from dataclasses import dataclass
from pathlib import Path
from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client
from google import genai
from dotenv import load_dotenv
# Load environment variables
load_dotenv()
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
@dataclass
class ClientConfig:
"""Configuration class for MCP client settings."""
gemini_api_key: str
server_command: str = "python"
    server_args: Optional[List[str]] = None
    server_cwd: Optional[str] = None
max_retries: int = 3
timeout_seconds: int = 30
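    # List and path defaults are assigned in __post_init__ because dataclasses
    # forbid mutable default values such as lists.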
def __post_init__(self):
if self.server_args is None:
self.server_args = ["mcp_server.py"]
if self.server_cwd is None:
self.server_cwd = os.getcwd()
class ConfigurationError(Exception):
"""Custom exception for configuration-related errors."""
pass
class AIQueryError(Exception):
"""Custom exception for AI query processing errors."""
pass
class MCPConnectionError(Exception):
"""Custom exception for MCP connection errors."""
pass
def load_configuration() -> ClientConfig:
"""
    Load configuration from environment variables (optionally sourced from a
    .env file via python-dotenv).
Returns:
ClientConfig object with loaded settings
Raises:
ConfigurationError: If required configuration is missing or invalid
"""
try:
# Check for required API key
api_key = os.getenv("GEMINI_API_KEY")
if not api_key:
raise ConfigurationError(
"GEMINI_API_KEY environment variable is required. "
"Please set it in your .env file or environment."
)
# Load optional configuration
server_cwd = os.getenv("MCP_SERVER_CWD", os.getcwd())
max_retries = int(os.getenv("MCP_MAX_RETRIES", "3"))
timeout_seconds = int(os.getenv("MCP_TIMEOUT_SECONDS", "30"))
# Validate server directory exists
if not os.path.exists(server_cwd):
raise ConfigurationError(f"Server directory does not exist: {server_cwd}")
server_script = os.path.join(server_cwd, "mcp_server.py")
if not os.path.exists(server_script):
raise ConfigurationError(f"Server script not found: {server_script}")
config = ClientConfig(
gemini_api_key=api_key,
server_cwd=server_cwd,
max_retries=max_retries,
timeout_seconds=timeout_seconds
)
logger.info(f"Configuration loaded successfully. Server CWD: {server_cwd}")
return config
    except ConfigurationError:
        # Re-raise errors raised deliberately above without re-wrapping them
        raise
    except ValueError as e:
        raise ConfigurationError(f"Invalid configuration value: {e}")
    except Exception as e:
        raise ConfigurationError(f"Failed to load configuration: {e}")
def get_tool_identifier_prompt() -> str:
"""
Get the prompt template for tool identification.
Returns:
Formatted prompt template for AI tool identification
"""
return """
You have been given access to the below MCP Server Tools:
{tools_description}
You must identify the appropriate tool from the above tools required to resolve the user query along with the arguments.
User Query: {user_query}
Your output should be in JSON format like below:
{{
"user_query": "User Query",
"tool_identified": "Tool Name",
"arguments": {{"arg1": "value1", "arg2": "value2"}}
}}
Important Guidelines:
1. Only use tools from the provided list
2. Ensure argument names match exactly what the tool expects
3. For stock symbols, use uppercase format (e.g., "AAPL", not "apple")
4. If the query is unclear, choose the most appropriate tool based on context
5. For comparison queries, use the compare_stocks tool with symbol1 and symbol2 parameters
Examples:
User Query: "What is the price of Apple stock?"
Response:
{{
"user_query": "What is the price of Apple stock?",
"tool_identified": "get_stock_price",
"arguments": {{"symbol": "AAPL"}}
}}
User Query: "Compare Microsoft and Google stocks"
Response:
{{
"user_query": "Compare Microsoft and Google stocks",
"tool_identified": "compare_stocks",
"arguments": {{"symbol1": "MSFT", "symbol2": "GOOGL"}}
}}
"""
async def generate_ai_response(user_query: str, tools_description: str, config: ClientConfig) -> Dict[str, Any]:
"""
Generate AI response to identify appropriate tool for user query.
Args:
user_query: The user's input query that needs to be resolved
tools_description: Description of available MCP server tools
config: Client configuration object
Returns:
Dictionary containing:
- user_query: The original user query
- tool_identified: Name of the identified tool
- arguments: Dictionary of arguments for the tool
Raises:
AIQueryError: If AI processing fails or returns invalid response
"""
try:
logger.info(f"Processing query with AI: {user_query}")
client = genai.Client(api_key=config.gemini_api_key)
prompt_template = get_tool_identifier_prompt()
formatted_prompt = prompt_template.format(
user_query=user_query,
tools_description=tools_description
)
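        # generate_content on the synchronous client is a blocking call; it runs
        # inline within this coroutine rather than being awaited.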
response = client.models.generate_content(
model='gemini-2.0-flash-001',
contents=formatted_prompt
)
if not response or not response.text:
raise AIQueryError("AI model returned empty response")
# Clean up response text
raw_text = response.text.strip()
raw_text = raw_text.replace("```json", "").replace("```", "")
# Parse JSON response
try:
data = json.loads(raw_text)
except json.JSONDecodeError as e:
logger.error(f"Failed to parse AI response as JSON: {raw_text}")
raise AIQueryError(f"AI response is not valid JSON: {e}")
# Validate response structure
required_fields = ["user_query", "tool_identified", "arguments"]
for field in required_fields:
if field not in data:
raise AIQueryError(f"AI response missing required field: {field}")
# Handle legacy string arguments format
if isinstance(data["arguments"], str):
logger.warning("AI returned string arguments, attempting to parse")
args_list = [arg.strip() for arg in data["arguments"].split(",")]
if len(args_list) == 2:
data["arguments"] = {"symbol1": args_list[0], "symbol2": args_list[1]}
elif len(args_list) == 1:
data["arguments"] = {"symbol": args_list[0]}
else:
raise AIQueryError("Could not parse string arguments format")
logger.info(f"AI identified tool: {data['tool_identified']} with args: {data['arguments']}")
return data
    except genai.errors.APIError as e:
        raise AIQueryError(f"Google AI API error: {e}")
except Exception as e:
if isinstance(e, AIQueryError):
raise
raise AIQueryError(f"Unexpected error in AI processing: {e}")
async def execute_mcp_query(user_input: str, config: ClientConfig) -> str:
"""
Execute a single MCP query with comprehensive error handling.
Args:
user_input: The user's query to be processed
config: Client configuration object
Returns:
Response string from the MCP server
Raises:
MCPConnectionError: If MCP server connection fails
AIQueryError: If AI query processing fails
"""
logger.info(f"Executing MCP query: {user_input}")
server_params = StdioServerParameters(
command=config.server_command,
args=config.server_args,
cwd=config.server_cwd
)
try:
async with stdio_client(server_params) as (read, write):
logger.info("MCP connection established, creating session...")
async with ClientSession(read, write) as session:
logger.info("MCP session created, initializing...")
await session.initialize()
logger.info("MCP session initialized successfully")
# Get available tools
tools = await session.list_tools()
if not tools.tools:
raise MCPConnectionError("No tools available from MCP server")
# Build tools description
tools_description = ""
for tool in tools.tools:
tool_desc = f"Tool - {tool.name}:\n{tool.description}\n"
if hasattr(tool, 'inputSchema') and tool.inputSchema:
tool_desc += f"Parameters: {tool.inputSchema}\n"
tools_description += tool_desc + "\n"
# Get AI recommendation
request_json = await generate_ai_response(user_input, tools_description, config)
tool_name = request_json["tool_identified"]
arguments = request_json["arguments"]
logger.info(f"Executing tool '{tool_name}' with arguments: {arguments}")
                    # Execute the tool, bounded by the configured timeout so the
                    # asyncio.TimeoutError handler below can actually trigger
                    response = await asyncio.wait_for(
                        session.call_tool(tool_name, arguments=arguments),
                        timeout=config.timeout_seconds,
                    )
if not response.content:
return "Tool executed successfully but returned no content."
result = response.content[0].text
logger.info("Tool execution completed successfully")
return result
except asyncio.TimeoutError:
raise MCPConnectionError(f"MCP operation timed out after {config.timeout_seconds} seconds")
except Exception as e:
if isinstance(e, (MCPConnectionError, AIQueryError)):
raise
logger.error(f"Unexpected error in MCP execution: {e}")
raise MCPConnectionError(f"MCP execution failed: {e}")
async def process_user_query(user_input: str, config: ClientConfig) -> None:
"""
Process a user query with retry logic and comprehensive error handling.
Args:
user_input: The user's query to be processed
config: Client configuration object
"""
print("-" * 50)
print(f"Processing query: {user_input}")
for attempt in range(config.max_retries):
try:
result = await execute_mcp_query(user_input, config)
print(f"Result: {result}")
print("-" * 50)
print("\n")
return
except AIQueryError as e:
print(f"AI processing error: {e}")
if attempt < config.max_retries - 1:
print(f"Retrying... (attempt {attempt + 2}/{config.max_retries})")
await asyncio.sleep(1)
else:
print("Max retries reached for AI processing.")
except MCPConnectionError as e:
print(f"MCP connection error: {e}")
if attempt < config.max_retries - 1:
print(f"Retrying... (attempt {attempt + 2}/{config.max_retries})")
await asyncio.sleep(2)
else:
print("Max retries reached for MCP connection.")
except Exception as e:
print(f"Unexpected error: {e}")
logger.error(f"Unexpected error processing query: {e}")
break
print("-" * 50)
print("\n")
def validate_user_input(user_input: str) -> str:
"""
Validate and sanitize user input.
Args:
user_input: Raw user input string
Returns:
Cleaned and validated input string
Raises:
ValueError: If input is invalid
"""
if not user_input or not user_input.strip():
raise ValueError("Query cannot be empty")
cleaned_input = user_input.strip()
# Basic length validation
if len(cleaned_input) > 500:
raise ValueError("Query too long (max 500 characters)")
return cleaned_input
async def main():
"""
Main function to run the interactive MCP client.
Provides an interactive loop that continuously prompts users for queries
and processes them using the enhanced MCP client system with AI integration.
"""
try:
# Load configuration
config = load_configuration()
logger.info("Enhanced MCP Stock Client started successfully")
print("=" * 60)
print("Enhanced MCP Stock Query System")
print("=" * 60)
print("Ask questions about stock prices, comparisons, or market data.")
print("Examples:")
print(" - What's the price of Apple stock?")
print(" - Compare Microsoft and Google stocks")
print(" - Get market summary")
print("Type 'quit' or 'exit' to stop.")
print("=" * 60)
print()
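        # Interactive read-eval loop; exits on quit/exit/q, Ctrl-C, or EOF.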
while True:
try:
user_input = input("What is your query? → ").strip()
# Check for exit commands
if user_input.lower() in ['quit', 'exit', 'q']:
print("Goodbye!")
break
# Validate input
try:
validated_input = validate_user_input(user_input)
except ValueError as e:
print(f"Invalid input: {e}")
continue
# Process the query
await process_user_query(validated_input, config)
except KeyboardInterrupt:
print("\nGoodbye!")
break
except EOFError:
print("\nGoodbye!")
break
except ConfigurationError as e:
print(f"Configuration error: {e}")
print("Please check your .env file and configuration settings.")
return 1
except Exception as e:
logger.error(f"Fatal error in main: {e}")
print(f"Fatal error: {e}")
return 1
return 0
if __name__ == "__main__":
"""
Entry point for the Enhanced MCP Stock Client.
Runs the interactive client that allows users to query stock information
using natural language. The client uses AI to understand queries and
automatically selects appropriate tools from the MCP server.
Features:
- Interactive command-line interface
- AI-powered query understanding
- Robust error handling and retry logic
- Configuration management
- Input validation and sanitization
- Comprehensive logging
Usage:
python mcp_client.py
Environment Variables:
GEMINI_API_KEY: Required Google Gemini API key
MCP_SERVER_CWD: Optional server working directory (default: current directory)
MCP_MAX_RETRIES: Optional max retry attempts (default: 3)
MCP_TIMEOUT_SECONDS: Optional timeout in seconds (default: 30)
"""
    exit_code = asyncio.run(main())
    raise SystemExit(exit_code)