"""LangGraph-powered product agent that orchestrates MCP and custom tools."""
from __future__ import annotations
import json
import logging
from typing import Any, Protocol, TypedDict
from app.agent.mcp_client import MCPProductClient
from app.agent.mock_llm import HumanMessage, RuleBasedMockLLM
from app.agent.tools import calculate_discount, format_products, format_statistics
# Module-level logger shared by the agent code in this module.
LOGGER: logging.Logger = logging.getLogger(name=__name__)
try:
    from langgraph.graph import END, StateGraph
except ImportError:  # pragma: no cover - local fallback for offline environments
    import inspect

    END = "__end__"

    class _FallbackCompiledGraph:
        """Minimal executor mimicking LangGraph's compiled ``ainvoke``.

        Starting at the entry point, each node is called with the current
        state. Its return value (awaited first when the node is async) is
        merged into the state as a partial update. Execution then follows the
        single outgoing edge of that node until ``END`` is reached.
        """

        # Upper bound on node executions; protects against cyclic edges.
        _MAX_STEPS = 25

        def __init__(
            self,
            nodes: dict[str, Any],
            edges: dict[str, str],
            entry_point: str,
        ) -> None:
            self._nodes = nodes
            self._edges = edges
            self._entry_point = entry_point

        async def ainvoke(self, initial_state: dict[str, Any]) -> dict[str, Any]:
            """Run the graph on a copy of ``initial_state`` and return the final state.

            Raises:
                RuntimeError: if more than ``_MAX_STEPS`` nodes run without reaching END.
            """
            state = dict(initial_state)
            current = self._entry_point
            steps = 0
            while current != END:
                if steps >= self._MAX_STEPS:
                    raise RuntimeError(
                        f"Fallback graph exceeded {self._MAX_STEPS} steps without reaching END"
                    )
                steps += 1
                update = self._nodes[current](state)
                if inspect.isawaitable(update):
                    update = await update
                if update:
                    state.update(update)
                # A node with no outgoing edge terminates the run.
                current = self._edges.get(current, END)
            return state

    class StateGraph:  # type: ignore[no-redef]
        """Offline stand-in for ``langgraph.graph.StateGraph``.

        Supports the subset this module needs: named nodes, one entry point,
        and at most one outgoing edge per node (no conditional or fan-out edges).
        """

        def __init__(self, state_type: Any) -> None:
            # The state schema only matters to real LangGraph channel setup.
            del state_type
            self._nodes: dict[str, Any] = {}
            self._edges: dict[str, str] = {}
            self._entry_point: str | None = None

        def add_node(self, name: str, fn: Any) -> None:
            """Register ``fn`` (sync or async callable taking the state) as node ``name``."""
            self._nodes[name] = fn

        def set_entry_point(self, name: str) -> None:
            """Mark ``name`` as the first node to execute."""
            self._entry_point = name

        def add_edge(self, source: str, target: str) -> None:
            """Route execution from ``source`` to ``target`` (``END`` stops the run).

            Raises:
                ValueError: if ``source`` already routes to a different node,
                    because this fallback cannot run parallel branches.
            """
            existing = self._edges.get(source)
            if existing is not None and existing != target:
                raise ValueError(
                    f"Fallback graph supports one edge per node: {source!r} -> "
                    f"{existing!r} already set, cannot add {target!r}"
                )
            self._edges[source] = target

        def compile(self) -> _FallbackCompiledGraph:
            """Validate the wiring and return an executable graph.

            Raises:
                RuntimeError: if no valid entry point is set or an edge targets
                    an unknown node.
            """
            if self._entry_point is None or self._entry_point not in self._nodes:
                raise RuntimeError("Fallback graph is not initialized")
            unknown = sorted(
                target
                for target in self._edges.values()
                if target != END and target not in self._nodes
            )
            if unknown:
                raise RuntimeError(f"Fallback graph has edges to unknown nodes: {unknown}")
            return _FallbackCompiledGraph(
                dict(self._nodes), dict(self._edges), self._entry_point
            )
class ProductClientProtocol(Protocol):
    """Structural interface of the MCP product client.

    Any object exposing these coroutine methods can be injected into the
    agent in place of the real client, which keeps tests free of MCP I/O.
    """

    async def list_products(self, category: str | None = None) -> list[dict[str, Any]]:
        """Return product records, optionally restricted to ``category``."""

    async def get_product(self, product_id: int) -> dict[str, Any]:
        """Return the product record identified by ``product_id``."""

    async def add_product(
        self,
        name: str,
        price: float,
        category: str,
        in_stock: bool = True,
    ) -> dict[str, Any]:
        """Create a product from the given fields and return its stored record."""

    async def get_statistics(self) -> dict[str, Any]:
        """Return catalog statistics (used to answer average-price questions)."""
class AgentState(TypedDict, total=False):
    """State passed between the agent's graph nodes.

    Every key is optional (``total=False``): the run starts with only
    ``query`` and each node adds its own key as a partial update.
    """

    # Raw user request handed to ``ProductAgent.run``.
    query: str
    # Decoded LLM plan with "action" and "params" keys.
    plan: dict[str, Any]
    # Outcome of the executed tool, or a dict holding an "error" message.
    tool_result: Any
    # Final user-facing text.
    response: str
class ProductAgent:
    """Agent that interprets natural language and executes tool calls.

    The workflow graph runs three nodes in order:

    1. ``analyze`` - ask the LLM for a JSON plan ``{"action": ..., "params": ...}``;
    2. ``execute`` - dispatch the plan to the MCP client or a local helper;
    3. ``respond`` - render the tool result as a user-facing message.
    """

    # Case-folded strings treated as False when a plan passes ``in_stock`` as
    # text. Without this, ``bool("false")`` would be True.
    _FALSY_STRINGS = frozenset({"", "0", "false", "no", "n", "off", "нет"})

    def __init__(
        self,
        mcp_client: ProductClientProtocol | None = None,
        llm: RuleBasedMockLLM | None = None,
    ) -> None:
        """Create the agent and compile its workflow graph.

        Args:
            mcp_client: Product client to call; defaults to a new ``MCPProductClient``.
            llm: Planner model; defaults to a new ``RuleBasedMockLLM``.
        """
        # ``is None`` (not ``or``) so an injected object that happens to be
        # falsy, e.g. one defining ``__len__``, is not silently replaced.
        self._mcp_client = mcp_client if mcp_client is not None else MCPProductClient()
        self._llm = llm if llm is not None else RuleBasedMockLLM()
        self._graph = self._build_graph()

    def _build_graph(self) -> Any:
        """Wire analyze -> execute -> respond -> END and return the compiled graph."""
        workflow = StateGraph(AgentState)
        workflow.add_node("analyze", self._analyze_query)
        workflow.add_node("execute", self._execute_tools)
        workflow.add_node("respond", self._format_response)
        workflow.set_entry_point("analyze")
        workflow.add_edge("analyze", "execute")
        workflow.add_edge("execute", "respond")
        workflow.add_edge("respond", END)
        return workflow.compile()

    @staticmethod
    def _normalize_plan(raw: Any) -> dict[str, Any]:
        """Coerce a decoded plan into one with a ``str`` action and ``dict`` params.

        Extra keys of a dict plan are preserved. Anything that is not a dict
        becomes the ``unknown`` plan, so later nodes can always use ``.get``.
        """
        if not isinstance(raw, dict):
            return {"action": "unknown", "params": {}}
        action = raw.get("action", "unknown")
        params = raw.get("params")
        return {
            **raw,
            "action": action if isinstance(action, str) else "unknown",
            "params": params if isinstance(params, dict) else {},
        }

    @classmethod
    def _as_bool(cls, value: Any) -> bool:
        """Interpret a plan flag, treating text such as "false", "0" or "нет" as False."""
        if isinstance(value, str):
            return value.strip().casefold() not in cls._FALSY_STRINGS
        return bool(value)

    def _analyze_query(self, state: AgentState) -> AgentState:
        """Graph node: ask the LLM for a plan for ``state["query"]``.

        Returns:
            ``{"plan": plan}``. Output that is not parseable, or that is not a
            JSON object, yields the ``unknown`` plan instead of raising.
        """
        message = HumanMessage(content=state["query"])
        llm_output = self._llm.invoke([message]).content
        try:
            decoded = json.loads(llm_output)
        except (TypeError, ValueError):
            # ValueError covers json.JSONDecodeError; TypeError covers
            # non-string content such as None.
            decoded = None
        return {"plan": self._normalize_plan(decoded)}

    async def _call_tool(self, action: str, params: dict[str, Any]) -> Any:
        """Run the tool selected by ``action``; exceptions propagate to the caller."""
        if action == "list_products":
            return await self._mcp_client.list_products()
        if action == "list_products_by_category":
            return await self._mcp_client.list_products(params.get("category"))
        if action == "get_average_price":
            return await self._mcp_client.get_statistics()
        if action == "add_product":
            return await self._mcp_client.add_product(
                name=str(params["name"]),
                price=float(params["price"]),
                category=str(params["category"]),
                in_stock=self._as_bool(params.get("in_stock", True)),
            )
        if action == "discount_by_id":
            product = await self._mcp_client.get_product(int(params["product_id"]))
            discount_percent = float(params["discount_percent"])
            discounted_price = calculate_discount(
                price=float(product["price"]),
                discount_percent=discount_percent,
            )
            return {
                "product": product,
                "discount_percent": discount_percent,
                "discounted_price": discounted_price,
            }
        return {
            "error": (
                "Не удалось определить действие. "
                "Попробуйте запросы про список, среднюю цену, добавление или скидку."
            )
        }

    async def _execute_tools(self, state: AgentState) -> AgentState:
        """Graph node: execute the planned tool and store its outcome.

        Any exception raised by the tool is logged and converted into
        ``{"error": str(exc)}`` so the response node can report it.
        """
        plan = self._normalize_plan(state.get("plan"))
        try:
            result = await self._call_tool(plan["action"], plan["params"])
        except Exception as exc:  # noqa: BLE001 - tool failures become user-facing errors
            LOGGER.exception("Agent tool execution failed")
            result = {"error": str(exc)}
        return {"tool_result": result}

    @staticmethod
    def _render(action: str, result: Any) -> str:
        """Turn a successful tool result into the user-facing message for ``action``."""
        if action in {"list_products", "list_products_by_category"}:
            return format_products(list(result))
        if action == "get_average_price":
            return format_statistics(dict(result))
        if action == "add_product":
            product = dict(result)
            return (
                "Продукт добавлен: "
                f"{product['name']} (ID={product['id']}), цена {product['price']}, "
                f"категория {product['category']}"
            )
        if action == "discount_by_id":
            payload = dict(result)
            product = payload["product"]
            return (
                f"Скидка {payload['discount_percent']}% для товара ID={product['id']} "
                f"({product['name']}): новая цена {payload['discounted_price']}"
            )
        return (
            "Не понял запрос. Примеры: "
            "'Покажи все продукты в категории Электроника', "
            "'Какая средняя цена продуктов?', "
            "'Добавь новый продукт: Мышка, цена 1500, категория Электроника'."
        )

    def _format_response(self, state: AgentState) -> AgentState:
        """Graph node: build ``{"response": text}`` from the plan and tool result.

        Tool errors are reported as ``"Ошибка: ..."``. A tool result whose
        shape does not match what the action expects (missing key, wrong type)
        is logged and reported the same way instead of escaping ``run()``.
        """
        plan = self._normalize_plan(state.get("plan"))
        result = state.get("tool_result")
        if isinstance(result, dict) and "error" in result:
            return {"response": f"Ошибка: {result['error']}"}
        try:
            response = self._render(plan["action"], result)
        except (KeyError, TypeError, ValueError) as exc:
            LOGGER.exception("Agent response formatting failed")
            response = f"Ошибка: {exc}"
        return {"response": response}

    async def run(self, query: str) -> dict[str, Any]:
        """Run the full graph and return user-facing output.

        Returns:
            A dict with the final ``response`` text, the planned ``action``
            name and the raw ``tool_result``.
        """
        state = await self._graph.ainvoke({"query": query})
        plan = state.get("plan", {})
        return {
            "response": state.get("response", ""),
            "action": plan.get("action"),
            "tool_result": state.get("tool_result"),
        }