# langgraph_agent.py
import asyncio
import json
import os
import subprocess
import sys
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional

from dotenv import load_dotenv
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
from langchain_core.tools import tool
from langchain_openai import ChatOpenAI
from langgraph.graph import END, StateGraph
from langgraph.prebuilt import ToolNode

load_dotenv()
@dataclass
class AgentState:
    """State carried through the LangGraph workflow.

    Fields now have safe defaults so an empty ``AgentState()`` can be
    created; existing positional construction (as in ``run()``) is
    unchanged.
    """

    # Conversation history (Human/AI/Tool messages, in order).
    messages: List[Any] = field(default_factory=list)
    # The user's original query text.
    user_query: str = ""
    # Name of the most recently executed workflow step.
    current_step: str = "start"
    # Scratch space for intermediate results keyed by step/tool name.
    results: Dict[str, Any] = field(default_factory=dict)
class MCPClient:
    """Client to interact with the MCP server over stdio using JSON-RPC."""

    def __init__(self, server_script: str = "server.py"):
        # Script launched with the current interpreter in start_server().
        self.server_script = server_script
        self.process: Optional[subprocess.Popen] = None

    async def start_server(self):
        """Start the MCP server process.

        Raises whatever ``subprocess.Popen`` raises if the spawn fails.
        """
        if self.process is not None:
            # Fix: calling start_server() twice used to leak the previous
            # still-running subprocess.
            return
        try:
            self.process = subprocess.Popen(
                [sys.executable, self.server_script],
                stdin=subprocess.PIPE,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                text=True,
                bufsize=0
            )
            print("MCP Server started successfully")
        except Exception as e:
            print(f"Error starting MCP server: {e}")
            raise

    async def stop_server(self):
        """Stop the MCP server process, escalating to kill if it hangs."""
        if self.process:
            self.process.terminate()
            try:
                # Fix: the unbounded wait() could block forever on a server
                # that ignores SIGTERM.
                self.process.wait(timeout=5)
            except subprocess.TimeoutExpired:
                self.process.kill()
                self.process.wait()
            # Reset so a later start_server()/call_tool() can respawn.
            self.process = None
            print("MCP Server stopped")

    async def call_tool(self, tool_name: str, **kwargs) -> str:
        """Call *tool_name* on the MCP server and return its text result.

        Never raises: failures come back as human-readable error strings so
        the agent can surface them in-band.
        """
        if not self.process:
            await self.start_server()
        # JSON-RPC 2.0 envelope for the MCP "tools/call" method.
        request = {
            "jsonrpc": "2.0",
            "id": 1,
            "method": "tools/call",
            "params": {
                "name": tool_name,
                "arguments": kwargs
            }
        }
        try:
            self.process.stdin.write(json.dumps(request) + "\n")
            self.process.stdin.flush()
            response_line = self.process.stdout.readline()
            # Fix: an empty line (server exited / closed stdout) used to fall
            # through and implicitly return None despite the -> str annotation.
            if not response_line:
                return "No response from server"
            response = json.loads(response_line.strip())
            if "result" in response:
                return response["result"]["content"][0]["text"]
            if "error" in response:
                return f"Error: {response['error']['message']}"
            return "No response from server"
        except Exception as e:
            return f"Error calling tool: {str(e)}"
class LangGraphAgent:
    """LangGraph agent that answers user queries via MCP-backed tools."""

    def __init__(self):
        # Low temperature keeps tool-routing decisions near-deterministic.
        self.llm = ChatOpenAI(
            model="gpt-4o-mini",
            temperature=0.1,
            api_key=os.getenv("OPENAI_API_KEY")
        )
        self.mcp_client = MCPClient()
        # Populated by _build_graph(): the same LLM with the tool schemas
        # bound so it is able to emit tool_calls.
        self.llm_with_tools = None
        self.graph = self._build_graph()

    def _build_graph(self) -> StateGraph:
        """Build the LangGraph workflow: agent node <-> tools node."""

        @tool
        async def web_search_tool(query: str) -> str:
            """Search the web for information about the given query"""
            return await self.mcp_client.call_tool("web_search", query=query)

        @tool
        async def roll_dice_tool(notation: str, num_rolls: int = 1) -> str:
            """Roll the dice with the given notation"""
            return await self.mcp_client.call_tool("roll_dice", notation=notation, num_rolls=num_rolls)

        @tool
        async def get_stock_data_tool(symbol: str, date: Optional[str] = None) -> str:
            """Get Open, High, Low, Close (OHLC) data for a stock symbol"""
            return await self.mcp_client.call_tool("get_stock_ohlc", symbol=symbol, date=date)

        tools = [web_search_tool, roll_dice_tool, get_stock_data_tool]
        # Fix: without bind_tools() the model never sees the tool schemas,
        # so it can never emit tool_calls and the "tools" node is unreachable.
        self.llm_with_tools = self.llm.bind_tools(tools)
        tool_node = ToolNode(tools)

        workflow = StateGraph(AgentState)
        workflow.add_node("agent", self._agent_node)
        workflow.add_node("tools", tool_node)
        # Conditional edge prevents an infinite agent<->tools loop: we only
        # visit "tools" while the last AI message actually requests a call.
        workflow.add_conditional_edges(
            "agent",
            self._should_continue,
            {
                "continue": "tools",
                "end": END
            }
        )
        workflow.add_edge("tools", "agent")
        workflow.set_entry_point("agent")
        return workflow.compile()

    def _should_continue(self, state: AgentState) -> str:
        """Route to the tools node while the last message carries tool calls."""
        # getattr() is safe for message types that lack a tool_calls attribute.
        if state.messages and getattr(state.messages[-1], "tool_calls", None):
            return "continue"
        return "end"

    async def _agent_node(self, state: AgentState) -> AgentState:
        """Main agent node: ask the LLM what to do next."""
        system_prompt = """You are a helpful AI assistant with access to several tools:
1. web_search_tool(query): Search the web for information
2. roll_dice_tool(notation, num_rolls): Roll dice with given notation
3. get_stock_data_tool(symbol, date): Get stock OHLC data
Analyze the user's request and determine which tool(s) to use. If the user asks about:
- Stock prices or market data → use get_stock_data_tool
- Rolling dice or random numbers → use roll_dice_tool
- General information or web search → use web_search_tool
Always provide helpful and accurate responses based on the tool results."""
        # Fix: send the whole conversation (including ToolMessages produced by
        # the tools node) instead of only the original query, so tool results
        # actually reach the model on the second pass.
        messages = [{"role": "system", "content": system_prompt}] + list(state.messages)
        try:
            # Prefer the tool-bound LLM so the model can request tool calls.
            llm = self.llm_with_tools or self.llm
            response = await llm.ainvoke(messages)
            state.messages.append(response)
            state.current_step = "agent_response"
        except Exception as e:
            # Surface the failure in-band; _should_continue then routes to end.
            error_msg = f"Error in agent node: {str(e)}"
            state.messages.append(AIMessage(content=error_msg))
            state.current_step = "error"
        return state

    async def run(self, user_query: str) -> str:
        """Run the agent on *user_query* and return the final text answer.

        Never raises: failures are returned as error strings.
        """
        initial_state = AgentState(
            messages=[HumanMessage(content=user_query)],
            user_query=user_query,
            current_step="start",
            results={}
        )
        try:
            await self.mcp_client.start_server()
            result = await self.graph.ainvoke(initial_state)
            # NOTE(review): depending on the langgraph version, ainvoke may
            # return a dict-like state; this assumes attribute access works
            # for the dataclass schema — confirm against the pinned version.
            if result.messages:
                last_message = result.messages[-1]
                if hasattr(last_message, 'content'):
                    return last_message.content
                return str(last_message)
            return "No response generated"
        except Exception as e:
            return f"Error running agent: {str(e)}"
        finally:
            # Always tear down the MCP server subprocess.
            await self.mcp_client.stop_server()
class SimpleAgentApp:
    """Thin CLI wrapper around the LangGraph agent."""

    def __init__(self):
        self.agent = LangGraphAgent()

    async def process_query(self, query: str) -> str:
        """Run one query through the agent, echoing progress to stdout."""
        print(f"\nProcessing query: {query}")
        print("=" * 50)
        try:
            result = await self.agent.run(query)
            print(f"Response: {result}")
            return result
        except Exception as e:
            # Report the failure text both to stdout and to the caller.
            error_msg = f"Error: {str(e)}"
            print(error_msg)
            return error_msg

    async def interactive_mode(self):
        """Read queries from stdin until the user quits or hits Ctrl-C."""
        print("LangGraph Agent with MCP Tools")
        print("Available tools: web_search, roll_dice, get_stock_data")
        print("Type 'quit' to exit\n")
        while True:
            try:
                query = input("Enter your query: ").strip()
                # Guard clauses: skip blanks, honour the quit words.
                if not query:
                    continue
                if query.lower() in ('quit', 'exit', 'q'):
                    print("Goodbye!")
                    break
                await self.process_query(query)
                print()
            except KeyboardInterrupt:
                print("\nGoodbye!")
                break
            except Exception as e:
                print(f"Unexpected error: {e}")
async def main():
    """Entry point: validate configuration, then run the example queries."""
    # The LLM key is mandatory; bail out early without it.
    if not os.getenv("OPENAI_API_KEY"):
        print("Error: OPENAI_API_KEY environment variable is required")
        return
    # Tool-specific keys are optional: warn but keep going.
    for var, feature in (("POLYGON_API_KEY", "Stock data"),
                         ("TAVILY_API_KEY", "Web search")):
        if not os.getenv(var):
            print(f"Warning: {var} not set. {feature} tool will not work.")
    app = SimpleAgentApp()
    # One demonstration query per available tool (plus a second stock query).
    example_queries = (
        "What's the current stock price of Apple (AAPL)?",
        "Roll 2d6 dice",
        "Search for information about artificial intelligence trends",
        "Get stock data for Microsoft (MSFT)",
    )
    print("Running example queries...")
    for query in example_queries:
        await app.process_query(query)
        print()
    # Uncomment the line below to run in interactive mode
    # await app.interactive_mode()
if __name__ == "__main__":
    # Run the demo workflow only when executed as a script (not on import).
    asyncio.run(main())