# mcp_client.py (10.2 kB)
import json
import os
import re
import sys
from contextlib import AsyncExitStack
from typing import Optional

import httpx
from dotenv import load_dotenv
from mcp import ClientSession, StdioServerParameters
from mcp.client.sse import sse_client
from mcp.client.stdio import stdio_client
from openai import OpenAI
# Module-level side effect: this call runs once, when the module is imported.
load_dotenv()  # load environment variables from .env
def generate_system_prompt(tools):
    """
    Build the assistant's internal system prompt for the given tools.

    The tools are wrapped as ``{"tools": tools}`` and rendered through
    ``SystemPromptGenerator.generate_prompt``; a fixed block of general
    guidelines is then appended to the rendered text. The result is meant
    for the model only and is not shown to the user.

    Args:
        tools: Tool definitions to embed in the prompt (must be JSON
            serializable, since the generator dumps them with ``json``).

    Returns:
        The rendered prompt followed by the guidelines text.
    """
    guidelines = """
**GENERAL GUIDELINES:**
1. Step-by-step reasoning:
   - Analyze tasks systematically.
   - Break down complex problems into smaller, manageable parts.
   - Verify assumptions at each step to avoid errors.
   - Reflect on results to improve subsequent actions.
2. Effective tool usage:
   - Explore:
     - Identify available information and verify its structure.
     - Check assumptions and understand data relationships.
   - Iterate:
     - Start with simple queries or actions.
     - Build upon successes, adjusting based on observations.
   - Handle errors:
     - Carefully analyze error messages.
     - Use errors as a guide to refine your approach.
     - Document what went wrong and suggest fixes.
3. Clear communication:
   - Explain your reasoning and decisions at each step.
   - Share discoveries transparently with the user.
   - Outline next steps or ask clarifying questions as needed.
EXAMPLES OF BEST PRACTICES:
- Working with databases:
  - Check schema before writing queries.
  - Verify the existence of columns or tables.
  - Start with basic queries and refine based on results.
- Processing data:
  - Validate data formats and handle edge cases.
  - Ensure integrity and correctness of results.
- Accessing resources:
  - Confirm resource availability and permissions.
  - Handle missing or incomplete data gracefully.
REMEMBER:
- Be thorough and systematic.
- Each tool call should have a clear and well-explained purpose.
- Make reasonable assumptions if ambiguous.
- Minimize unnecessary user interactions by providing actionable insights.
EXAMPLES OF ASSUMPTIONS:
- Default sorting (e.g., descending order) if not specified.
- Assume basic user intentions, such as fetching top results by a common metric.
"""
    rendered = SystemPromptGenerator().generate_prompt({"tools": tools})
    return rendered + guidelines
class MCPClient:
    """Client for a single MCP server, reachable over SSE or stdio.

    The instance also owns an OpenAI-compatible client (``self.client``)
    configured from the URL and API key given to the constructor.

    Two connection paths exist and they track their resources differently:
    ``connect_to_sse_server`` keeps the entered context managers in
    ``_streams_context`` / ``_session_context``, while
    ``connect_to_stdio_server`` registers them on ``exit_stack``.
    ``cleanup`` releases both.
    """

    def __init__(self, vllm_url, vllm_api_key):
        """Create an unconnected client.

        Args:
            vllm_url: Passed to ``OpenAI`` as ``base_url``.
            vllm_api_key: Passed to ``OpenAI`` as ``api_key``.
        """
        self.stdio = None
        self.write = None
        self._session_context = None
        self._streams_context = None
        self.name = ""
        self.server_params = None
        # Initialize session and client objects
        self.session: Optional[ClientSession] = None
        self.exit_stack = AsyncExitStack()
        # SECURITY NOTE(review): verify=False turns off TLS certificate
        # verification for every request made through this client. Kept as-is
        # to preserve behavior (presumably for self-signed endpoints) --
        # confirm this is intentional before using outside a trusted network.
        self.client = OpenAI(
            base_url=vllm_url,
            api_key=vllm_api_key,
            http_client=httpx.Client(verify=False),
        )

    async def connect_to_sse_server(self, server_url: str):
        """Connect to an MCP server running with SSE transport.

        Opens the SSE streams, starts a ``ClientSession`` on them,
        initializes it, and prints the names of the tools the server lists.

        Args:
            server_url: URL handed to ``sse_client``.
        """
        # Store the context managers so they stay alive; cleanup() exits them.
        self._streams_context = sse_client(url=server_url)
        streams = await self._streams_context.__aenter__()
        self._session_context = ClientSession(*streams)
        self.session: ClientSession = await self._session_context.__aenter__()
        # Initialize
        await self.session.initialize()
        # List available tools to verify connection
        print("Initialized SSE client...")
        print("Listing tools...")
        response = await self.session.list_tools()
        tools = response.tools
        print("\nConnected to server with tools:", [tool.name for tool in tools])

    async def connect_to_stdio_server(self, server_params: "StdioServerParameters", name):
        """Connect to an MCP server over stdio.

        The transport and the session are both entered on ``self.exit_stack``
        so that ``cleanup`` can close them. After initializing the session,
        the names of the tools the server lists are printed.

        Args:
            server_params: Parameters handed to ``stdio_client``.
            name: Label for this server; stored on ``self.name`` and used in
                the printed message.

        Returns:
            The ``(stdio, write)`` pair yielded by ``stdio_client``.
        """
        self.name = name
        self.server_params = server_params
        stdio_transport = await self.exit_stack.enter_async_context(stdio_client(self.server_params))
        self.stdio, self.write = stdio_transport
        self.session = await self.exit_stack.enter_async_context(ClientSession(self.stdio, self.write))
        await self.session.initialize()
        # List available tools
        response = await self.session.list_tools()
        tools = response.tools
        print("\nConnected to server " + name + " with tools:", [tool.name for tool in tools])
        return self.stdio, self.write

    async def cleanup(self):
        """Release every resource opened by either connect method.

        Exits the SSE session context, then the SSE streams context, then
        closes ``exit_stack`` (which holds the stdio transport and session).
        Each step runs even if an earlier one raises. The stored context
        handles and ``self.session`` are reset to ``None`` first, so calling
        this method again does not exit the same context twice.
        """
        session_context, self._session_context = self._session_context, None
        streams_context, self._streams_context = self._streams_context, None
        self.session = None
        try:
            if session_context is not None:
                await session_context.__aexit__(None, None, None)
        finally:
            try:
                if streams_context is not None:
                    await streams_context.__aexit__(None, None, None)
            finally:
                # Closes whatever connect_to_stdio_server registered; a no-op
                # when nothing was registered.
                await self.exit_stack.aclose()
async def load_config(config_path: str, server_name: str) -> "StdioServerParameters":
    """Load the stdio launch parameters for one server from a JSON file.

    The file is expected to hold a JSON object whose ``"mcpServers"`` entry
    maps server names to objects with a required ``"command"`` and optional
    ``"args"`` (default ``[]``) and ``"env"`` (default ``None``).

    Progress and error messages are printed to stdout; each error message
    is printed exactly once before the exception propagates.

    Args:
        config_path: Path of the JSON configuration file.
        server_name: Key to look up under ``"mcpServers"``.

    Returns:
        A ``StdioServerParameters`` built from the matching entry.

    Raises:
        FileNotFoundError: ``config_path`` does not exist.
        json.JSONDecodeError: The file is not valid JSON.
        ValueError: ``server_name`` has no (non-empty) entry in the file.
        KeyError: The entry has no ``"command"`` key.
    """
    try:
        # debug
        print(f"Loading config from {config_path}")
        # JSON text is UTF-8; do not depend on the platform default encoding.
        with open(config_path, "r", encoding="utf-8") as config_file:
            config = json.load(config_file)
        # Retrieve the server configuration
        server_config = config.get("mcpServers", {}).get(server_name)
        if not server_config:
            # Not printed here: the ValueError handler below prints it.
            raise ValueError(f"Server '{server_name}' not found in configuration file.")
        # Construct the server parameters
        result = StdioServerParameters(
            command=server_config["command"],
            args=server_config.get("args", []),
            env=server_config.get("env"),
        )
        # debug
        print(
            f"Loaded config: command='{result.command}', args={result.args}, env={result.env}"
        )
        return result
    except FileNotFoundError as e:
        error_msg = f"Configuration file not found: {config_path}"
        print(error_msg)
        raise FileNotFoundError(error_msg) from e
    except json.JSONDecodeError as e:
        # Must precede the ValueError clause: JSONDecodeError subclasses it.
        error_msg = f"Invalid JSON in configuration file: {e.msg}"
        print(error_msg)
        raise json.JSONDecodeError(error_msg, e.doc, e.pos) from e
    except ValueError as e:
        print(str(e))
        raise
class SystemPromptGenerator:
    """
    Render a system prompt from a fixed template, a tools JSON object, and
    optional user-supplied text.
    """

    def __init__(self):
        """
        Set up the prompt template and the default texts used when the
        caller supplies no user system prompt or tool configuration.
        """
        self.template = """
        In this environment you have access to a set of tools you can use to answer the user's question.
        {{ FORMATTING INSTRUCTIONS }}
        String and scalar parameters should be specified as is, while lists and objects should use JSON format. Note that spaces for string values are not stripped. The output is not expected to be valid XML and is parsed with regular expressions.
        Here are the functions available in JSONSchema format:
        {{ TOOL DEFINITIONS IN JSON SCHEMA }}
        {{ USER SYSTEM PROMPT }}
        {{ TOOL CONFIGURATION }}
        """
        self.default_user_system_prompt = "You are an intelligent assistant capable of using tools to solve user queries effectively."
        self.default_tool_config = "No additional configuration is required."

    def generate_prompt(
        self,
        tools: dict,
        user_system_prompt: Optional[str] = None,
        tool_config: Optional[str] = None,
    ) -> str:
        """
        Fill the template with the tools JSON, user prompt, and tool configuration.

        All placeholders are substituted in a single pass over the template,
        so placeholder-looking text inside the inserted values (for example a
        tool description containing ``{{ USER SYSTEM PROMPT }}``) is left
        untouched instead of being substituted again.

        Args:
            tools (dict): The tools JSON; serialized with ``json.dumps(indent=2)``.
            user_system_prompt (str): Instruction text for the assistant. Falls
                back to ``default_user_system_prompt`` when ``None`` or empty.
            tool_config (str): Additional tool configuration text. Falls back
                to ``default_tool_config`` when ``None`` or empty.
        Returns:
            str: The rendered system prompt.
        Raises:
            TypeError: If ``tools`` is not JSON serializable.
        """
        substitutions = {
            # No formatting instructions are emitted; the slot is blanked.
            "{{ FORMATTING INSTRUCTIONS }}": "",
            "{{ TOOL DEFINITIONS IN JSON SCHEMA }}": json.dumps(tools, indent=2),
            "{{ USER SYSTEM PROMPT }}": user_system_prompt or self.default_user_system_prompt,
            "{{ TOOL CONFIGURATION }}": tool_config or self.default_tool_config,
        }
        placeholder = re.compile("|".join(re.escape(key) for key in substitutions))
        # A function replacement inserts each value literally: backslashes in
        # the values are not interpreted and inserted text is never re-scanned.
        return placeholder.sub(lambda match: substitutions[match.group(0)], self.template)
# Names of the environment variables that get_default_environment() copies
# from the current process. The list is chosen once, at import time, based
# on the platform.
if sys.platform == "win32":
    DEFAULT_INHERITED_ENV_VARS = [
        "APPDATA",
        "HOMEDRIVE",
        "HOMEPATH",
        "LOCALAPPDATA",
        "PATH",
        "PROCESSOR_ARCHITECTURE",
        "SYSTEMDRIVE",
        "SYSTEMROOT",
        "TEMP",
        "USERNAME",
        "USERPROFILE",
    ]
else:
    DEFAULT_INHERITED_ENV_VARS = ["HOME", "LOGNAME", "PATH", "SHELL", "TERM", "USER"]
def get_default_environment() -> dict[str, str]:
    """
    Collect the subset of the current environment that should be inherited.

    Only the names in ``DEFAULT_INHERITED_ENV_VARS`` are considered. A name
    is left out when it is unset, when its value is empty, or when its value
    starts with ``"()"``.

    Returns:
        A new dict mapping each retained variable name to its value.
    """
    inherited: dict[str, str] = {}
    for name in DEFAULT_INHERITED_ENV_VARS:
        current = os.environ.get(name)
        if not current:
            # Unset or empty: nothing to inherit.
            continue
        if current.startswith("()"):
            # NOTE(review): presumably a shell function definition exported
            # through the environment -- confirm the intent of this filter.
            continue
        inherited[name] = current
    return inherited
def clean_response(response: str) -> str:
    """
    Strip known model artefacts from a reply.

    Removes every ``[TOOL_CALLS]`` marker. If the reply contains a Markdown
    JSON fence (three backticks followed by ``json``), the opening fence and
    all remaining triple-backtick fences are removed as well; replies without
    a JSON fence keep any other code fences untouched.

    Args:
        response: Raw reply text.

    Returns:
        The reply with the artefacts removed.
    """
    # Remove artefacts from reply here
    response = response.replace("[TOOL_CALLS]", "")
    if "```json" in response:
        # Strip the same backtick fence the guard detects (the previous code
        # replaced single-quote sequences, which left the fences in place).
        response = response.replace("```json", "").replace("```", "")
    return response