"""
Query processor: The orchestrator that connects everything.
This is the BRAIN of the system. It:
1. Takes user queries
2. Gets available tools from MCP server (filtered by RBAC)
3. Calls the LLM provider (OpenAI/Claude/Gemini) with tools
4. Executes any tool calls via MCP
5. Feeds results back to LLM
6. Returns final answer
KEY FEATURE: This is completely LLM-agnostic!
It works with ANY provider through the abstraction layer.
"""
from __future__ import annotations
import json
import time
from typing import List, Dict, Any, Optional, Tuple
from contextlib import asynccontextmanager
from uuid import uuid4
from mcp import ClientSession
from mcp.client.streamable_http import streamablehttp_client
from src.config import settings
from src.llm import get_llm_provider, ToolCall, ToolCallStatus, BaseLLMProvider
from src.core import Role, UserContext, RBAC
from src.observability import get_logger, metrics
logger = get_logger(__name__)
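
# A minimal end-to-end usage sketch (the real entry point lives in the web chat
# layer; `answer` and the surrounding wiring here are illustrative, not the
# actual caller):
#
#     async def answer(query: str, user_context: UserContext) -> str:
#         async with connect_to_mcp_server() as session:
#             reply, _history = await process_query(query, session, user_context)
#         return reply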
async def handle_sampling_request(
sampling_request: Dict[str, Any],
provider: BaseLLMProvider,
trace_id: str
) -> str:
"""
Handle MCP sampling request by calling client LLM for formatting.
This is the core of MCP sampling - it takes the raw data and formatting
instructions from the tool, then calls the client's LLM to format it
according to the user's preferences.
Args:
sampling_request: The sampling request from the tool
provider: LLM provider to use for formatting
trace_id: Trace ID for logging
Returns:
Formatted response string
Example:
Tool returns: {"vendors": [...], "format": "table"}
This function: Calls LLM to format as table
Returns: "| Name | Location | ... |"
"""
try:
logger.info(
"sampling_request_started",
trace_id=trace_id,
has_messages=bool(sampling_request.get("messages")),
has_preferences=bool(sampling_request.get("modelPreferences"))
)
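        # Illustrative shape of a sampling request, limited to the fields read
        # below (the concrete values are made up):
        #
        #     {
        #         "messages": [{"role": "user", "content": "Format these vendors as a table: ..."}],
        #         "modelPreferences": {"temperature": 0.2, "max_tokens": 1024},
        #         "systemPrompt": "You format tool output for end users.",
        #     }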
# Extract sampling messages
messages = sampling_request.get("messages", [])
if not messages:
logger.warning("sampling_request_no_messages", trace_id=trace_id)
return "No formatting instructions provided"
# Extract model preferences
model_prefs = sampling_request.get("modelPreferences", {})
temperature = model_prefs.get("temperature", settings.sampling_temperature)
max_tokens = model_prefs.get("max_tokens", settings.sampling_max_tokens)
# Add system prompt if provided
system_prompt = sampling_request.get("systemPrompt")
if system_prompt:
messages.insert(0, {"role": "system", "content": system_prompt})
logger.debug(
"calling_llm_for_sampling",
trace_id=trace_id,
message_count=len(messages),
temperature=temperature,
max_tokens=max_tokens
)
# Call LLM with sampling instructions
# The LLM will format the data according to the instructions
response = await provider.generate(
messages=messages,
tools=None, # No tools for sampling
temperature=temperature,
max_tokens=max_tokens
)
logger.info(
"sampling_request_completed",
trace_id=trace_id,
response_length=len(response.content) if response else 0
)
return response.content if response else "Error formatting response."
except Exception as e:
logger.error(
"sampling_request_failed",
trace_id=trace_id,
error=str(e),
error_type=type(e).__name__
)
        # Fall back to a generic error message so the caller still gets a response
        return "Error formatting response. Please try again."
@asynccontextmanager
async def connect_to_mcp_server():
"""
Connect to the MCP server.
This creates a client connection to the MCP server running
on the configured host/port.
Yields:
ClientSession: Connected MCP client session
"""
mcp_url = f"http://{settings.mcp_server_host}:{settings.mcp_server_port}{settings.mcp_path}"
logger.info("connecting_to_mcp_server", url=mcp_url)
async with streamablehttp_client(mcp_url) as (read_stream, write_stream, _):
async with ClientSession(read_stream, write_stream) as session:
await session.initialize()
logger.info("mcp_connection_established")
yield session
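
# Usage sketch for the connection helper; the body is illustrative, but
# list_tools() is the same ClientSession call used in get_available_tools below:
#
#     async with connect_to_mcp_server() as session:
#         tool_result = await session.list_tools()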
async def get_available_tools(
session: ClientSession,
user_context: UserContext,
) -> List[Dict[str, Any]]:
"""Get tools available based on RBAC and current context.
Get tools available to the user based on their role.
This enforces RBAC by filtering tools based on user permissions.
Args:
session: MCP client session
user_context: User's context (role, org, etc.)
Returns:
List of tools in LLM-compatible format
"""
# Get all tools from MCP server
tool_result = await session.list_tools()
    # Get tool policies
    from src.core.domain_policies import ALL_POLICIES
    # Get current context
    from src.web_chat.context import get_current_context
    context = get_current_context()
    logger.debug("current_page_context", context=context)
if not context:
logger.error("no_context_for_tool_filtering")
raise RuntimeError(
"Cannot get available tools without page context"
)
# Log all available tools for debugging
logger.info("available_mcp_tools", tool_count=len(tool_result.tools), tools=[t.name for t in tool_result.tools])
    logger.info("user_role", role=user_context.role.name, role_value=user_context.role.value)
    logger.debug("tool_policies_loaded", tool_policies=ALL_POLICIES)
# Filter by RBAC and context policy
allowed_tools = []
for tool in tool_result.tools:
tool_name = tool.name
# Step 1: Check RBAC
if not RBAC.is_tool_allowed(user_context.role, tool_name):
logger.debug(
"tool_blocked_by_rbac",
tool=tool_name,
role=user_context.role.name
)
continue
# Step 2: Check tool has policy
policy = ALL_POLICIES.get(tool_name)
if not policy:
logger.warning(
"tool_missing_policy",
tool_name=tool.name
)
continue
# Step 3: Check context policy
if not context.matches_policy(policy):
logger.debug(
"tool_blocked_by_context",
tool=tool_name,
page_type=context.page_type,
allowed_pages=[pt.value for pt in policy.allowed_page_types]
)
continue
# Tool passed all checks - convert to LLM format
input_schema = getattr(tool, "inputSchema", None) or getattr(tool, "input_schema", None) or {}
allowed_tools.append({
"type": "function",
"name": tool.name,
"description": tool.description or "",
"parameters": input_schema,
})
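        # Illustrative shape of one appended entry (the tool name and schema are
        # made-up examples; real values come from the MCP server):
        #
        #     {
        #         "type": "function",
        #         "name": "list_vendors",
        #         "description": "List vendors for the current dealership",
        #         "parameters": {"type": "object", "properties": {"status": {"type": "string"}}},
        #     }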
logger.info(
"tools_filtered",
user_role=user_context.role.name,
page_type=context.page_type.value if context else None,
total_tools=len(tool_result.tools),
rbac_allowed=len([t for t in tool_result.tools if RBAC.is_tool_allowed(user_context.role, t.name)]),
context_allowed=len(allowed_tools),
final_tools=[t["name"] for t in allowed_tools]
)
return allowed_tools
async def execute_tool(
session: ClientSession,
tool_call: ToolCall,
user_context: UserContext,
provider: Optional[BaseLLMProvider] = None,
) -> ToolCall:
"""Execute tool with context validation
Execute a tool via MCP and return the result.
This injects user context into tool arguments and executes
the tool through the MCP server.
Args:
session: MCP client session
tool_call: Tool call to execute
user_context: User's context
Returns:
ToolCall with result populated
"""
start_time = time.time()
try:
# Get and validate context
from src.web_chat.context import get_current_context
from src.core.domain_policies import ALL_POLICIES
context = get_current_context()
if not context:
logger.error(
"no_context_for_tool_execution",
tool_name=tool_call.name
)
raise RuntimeError(
"Cannot execute tool without page context"
)
# Get tool policy
policy = ALL_POLICIES.get(tool_call.name)
if not policy:
logger.error(
"tool_missing_policy_at_execution",
tool_name=tool_call.name
)
raise RuntimeError(
f"Tool {tool_call.name} has no context policy"
)
# Validate context policy
if not context.matches_policy(policy):
logger.error(
"tool_context_mismatch",
tool_name=tool_call.name,
page_type=context.page_type,
entity_type=context.entity_type
)
raise RuntimeError(
f"Tool {tool_call.name} cannot be used in current context"
)
        # Copy the LLM-supplied arguments before injecting anything
        args = dict(tool_call.arguments)
        # If tool requires entity, inject from context
        if policy.requires_entity:
            if not context.has_entity():
                raise RuntimeError(
                    f"Tool {tool_call.name} requires an entity"
                )
            if context.entity_type != policy.entity_type:
                raise RuntimeError(
                    f"Tool {tool_call.name} requires entity type {policy.entity_type}"
                )
            # Override any provided entity ID with the entity from the page context
            args["entity_id"] = context.entity_id
            logger.info(
                "injected_context_entity",
                tool_name=tool_call.name,
                entity_id=context.entity_id
            )
        # Inject user context into arguments
        args["user_id"] = str(user_context.user_id)  # Convert to string
        args["organization_id"] = user_context.organization_id
        args["role"] = user_context.role.value  # Pass role as integer
        args["trace_id"] = tool_call.id
if user_context.platform_id:
args["platform_id"] = user_context.platform_id
if user_context.dealership_id:
args["dealership_id"] = user_context.dealership_id
if user_context.bearer_token:
args["bearer_token"] = user_context.bearer_token
if user_context.email:
args["email"] = user_context.email
if user_context.name:
args["name"] = user_context.name
if user_context.permissions:
args["permissions"] = user_context.permissions
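        # Illustrative final argument payload sent to the MCP tool (all values are
        # made up; optional keys appear only when set on the user context):
        #
        #     {
        #         "entity_id": "vehicle-123",   # injected from page context when required
        #         "user_id": "42",
        #         "organization_id": "org-7",
        #         "role": 3,
        #         "trace_id": "call_abc123",
        #         "bearer_token": "...",
        #     }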
logger.info(
"executing_tool",
tool_name=tool_call.name,
user_id=user_context.user_id,
role=user_context.role.name
)
# Execute via MCP
result = await session.call_tool(tool_call.name, arguments=args)
# Parse result
content = getattr(result, "content", None)
if isinstance(content, list) and content:
# Extract JSON from MCP response
for item in content:
if hasattr(item, "json"):
# Call the json property/method to get the actual data
json_data = item.json
tool_call.result = json_data() if callable(json_data) else json_data
break
elif hasattr(item, "text"):
text_data = item.text
text_str = text_data() if callable(text_data) else text_data
try:
tool_call.result = json.loads(text_str)
except json.JSONDecodeError:
tool_call.result = {"text": text_str}
break
if not tool_call.result:
tool_call.result = {"raw": str(result)}
# CRITICAL: Check for MCP sampling request
# If the tool returned a sampling request, we need to handle it
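        # Illustrative shape of a tool result that requests sampling (only the keys
        # checked below are shown; the raw data payload is tool-specific):
        #
        #     {
        #         "_meta": {"sampling": True},
        #         "sampling_request": {"messages": [...], "modelPreferences": {...}},
        #     }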
if settings.enable_sampling and isinstance(tool_call.result, dict):
meta = tool_call.result.get("_meta", {})
if meta.get("sampling") and provider:
logger.info(
"sampling_detected",
tool_name=tool_call.name,
trace_id=tool_call.id
)
# Extract sampling request
sampling_req = tool_call.result.get("sampling_request")
if sampling_req:
# Call LLM to format the response
formatted_response = await handle_sampling_request(
sampling_request=sampling_req,
provider=provider,
trace_id=tool_call.id
)
# Replace result with formatted response
tool_call.result = {
"meta": {
"tool_name": tool_call.name,
"trace_id": tool_call.id,
"formatted_via_sampling": True
},
"data": {
"formatted_response": formatted_response
}
}
logger.info(
"sampling_completed",
tool_name=tool_call.name,
trace_id=tool_call.id,
response_length=len(formatted_response)
)
tool_call.status = ToolCallStatus.COMPLETED
duration = time.time() - start_time
logger.info(
"tool_executed",
tool_name=tool_call.name,
duration_ms=int(duration * 1000),
success=True
)
except Exception as e:
tool_call.status = ToolCallStatus.FAILED
tool_call.error = str(e)
duration = time.time() - start_time
logger.error(
"tool_execution_failed",
tool_name=tool_call.name,
error=str(e),
error_type=type(e).__name__,
duration_ms=int(duration * 1000)
)
return tool_call
async def process_query(
query: str,
session: ClientSession,
user_context: UserContext,
history: Optional[List[Dict[str, Any]]] = None,
) -> Tuple[str, List[Dict[str, Any]]]:
"""
Process a user query with LLM + MCP tools.
This is the main orchestration function that:
1. Gets available tools (RBAC-filtered)
2. Calls LLM with query + tools
3. Executes any tool calls
4. Returns final answer
KEY: This works with ANY LLM provider (OpenAI, Claude, Gemini)!
Args:
query: User's question
session: MCP client session
user_context: User's context (role, IDs, etc.)
history: Previous conversation history
Returns:
Tuple of (final_answer, updated_conversation_history)
"""
trace_id = str(uuid4())
start_time = time.time()
logger.info(
"query_processing_started",
trace_id=trace_id,
user_id=user_context.user_id,
role=user_context.role.name,
query_length=len(query),
)
try:
# Get LLM provider (OpenAI/Claude/Gemini based on config)
provider = get_llm_provider()
logger.info(
"llm_provider_selected",
trace_id=trace_id,
provider=provider.provider_name,
model=provider.model_name
)
# Get available tools (RBAC-filtered)
tools = await get_available_tools(session, user_context)
# Build conversation
messages = list(history or [])
# Add system message with role context
system_msg = (
f"You are an AI assistant for Updation, a car dealership management platform. "
f"The user is a {user_context.role.name} with access to specific tools. "
f"Use the available tools to answer their questions accurately. "
f"Always provide clear, helpful responses."
)
messages.insert(0, {"role": "system", "content": system_msg})
# Add user query
messages.append({"role": "user", "content": query})
# Multi-turn tool calling loop
max_turns = 10
for turn in range(max_turns):
logger.debug(
"llm_generation_turn",
trace_id=trace_id,
turn=turn + 1,
max_turns=max_turns
)
# Call LLM
llm_start = time.time()
response = await provider.generate(
messages=messages,
tools=tools if tools else None,
)
llm_duration = time.time() - llm_start
# Track metrics
metrics.track_llm_call(
provider=provider.provider_name,
model=provider.model_name,
prompt_tokens=response.usage.get("prompt_tokens", 0),
completion_tokens=response.usage.get("completion_tokens", 0),
duration=llm_duration,
success=True
)
# If no tool calls, we're done
if not response.has_tool_calls:
final_answer = response.content
# Add assistant response to history
messages.append({"role": "assistant", "content": final_answer})
duration = time.time() - start_time
logger.info(
"query_processing_completed",
trace_id=trace_id,
turns=turn + 1,
duration_ms=int(duration * 1000),
answer_length=len(final_answer)
)
return final_answer, messages
# Execute tool calls
logger.info(
"executing_tool_calls",
trace_id=trace_id,
tool_count=len(response.tool_calls)
)
# Add assistant's tool call message
messages.append({
"role": "assistant",
"content": response.content or None,
"tool_calls": [
{
"id": tc.id,
"type": "function", # Required by OpenAI
"function": {
"name": tc.name,
"arguments": json.dumps(tc.arguments),
}
}
for tc in response.tool_calls
]
})
# Execute each tool
for tool_call in response.tool_calls:
executed = await execute_tool(session, tool_call, user_context, provider)
# Add tool result to conversation
tool_result_msg = provider.format_tool_result(executed)
messages.append(tool_result_msg)
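            # At this point the conversation contains, in order: the system prompt,
            # any prior history, the user query, the assistant message carrying
            # tool_calls, and one tool-result message per call; the next loop
            # iteration sends all of it back to the LLM.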
# Max turns exceeded
logger.warning(
"max_turns_exceeded",
trace_id=trace_id,
max_turns=max_turns
)
return (
"I apologize, but I couldn't complete your request after multiple attempts. Please try rephrasing your question.",
messages
)
except Exception as e:
duration = time.time() - start_time
logger.error(
"query_processing_failed",
trace_id=trace_id,
error=str(e),
error_type=type(e).__name__,
duration_ms=int(duration * 1000)
)
return (
f"I encountered an error processing your request: {str(e)}",
messages if 'messages' in locals() else []
)