from omnicoreagent import OmniAgent, MemoryRouter, EventRouter, logger
from utils.jwt_util import create_server_token
import asyncio
import os
from typing import Optional, Dict, Any
# Generate token dynamically
AUTH_TOKEN = create_server_token("tutoring-agent-mcp-client")
# Get RAG server URL from environment variable (Heroku) or use default (local Docker)
RAG_SERVER_URL = os.getenv("RAG_MCP_SERVER_URL", "http://rag_mcp_server:9000/mcp")
MCP_TOOLS = [
    {
        "name": "turtor_rag",
        "transport_type": "streamable_http",
        "url": RAG_SERVER_URL,
        "headers": {"Authorization": f"Bearer {AUTH_TOKEN}"},
    }
]
class TutoringRagAgent:
    def __init__(self):
        self.memory_router = None
        self.event_router = None
        self.agent = None
        self.mcp_client = None
        self.base_system_instruction = None
        self._agent_lock = asyncio.Lock()  # Prevent concurrent agent runs

    async def initialized(self):
        """Initialize the TutoringRagAgent server."""
        # Create memory and event routers
        self.memory_router = MemoryRouter(memory_store_type="redis")
        self.event_router = EventRouter(event_store_type="in_memory")
        # Import and store base system instruction
        from mcp_host.mcp_agent.system_prompt import base_system_instruction
        self.base_system_instruction = base_system_instruction
        # Create the OmniAgent with base instruction
        self.agent = OmniAgent(
            name="TutoringRagAgent",
            system_instruction=self.base_system_instruction,
            model_config={
                "provider": "openai",
                "model": "gpt-4.1",
                "temperature": 0.3,
                "max_context_length": 5000,
                "top_p": 0.7,
            },
            mcp_tools=MCP_TOOLS,
            agent_config={
                "agent_name": "TutoringRagAgent",
                "max_steps": 15,
                "tool_call_timeout": 60,
                "request_limit": 0,
                "total_tokens_limit": 0,
                # --- Memory Retrieval Config ---
                "memory_config": {"mode": "sliding_window", "value": 100},
                "memory_results_limit": 5,
                "memory_similarity_threshold": 0.5,
                # --- Tool Retrieval Config ---
                "enable_tools_knowledge_base": False,
                "tools_results_limit": 10,
                "tools_similarity_threshold": 0.1,
                "memory_tool_backend": "local",
            },
            memory_router=self.memory_router,
            event_router=self.event_router,
            debug=True,
        )
        # Connect to MCP servers and log results
        logger.info("Connecting to MCP servers...")
        await self.agent.connect_mcp_servers()
        self.mcp_client = self.agent.mcp_client
        # CRITICAL FIX: Wrap MCP client to preprocess tool arguments
        # This ensures query arrays are converted to strings BEFORE sending to RAG server
        await self._wrap_mcp_client_for_preprocessing()
        # Log available tools
        if hasattr(self.mcp_client, "sessions"):
            logger.info(f"MCP Sessions: {list(self.mcp_client.sessions.keys())}")
            for server_name, session_info in self.mcp_client.sessions.items():
                logger.info(
                    f"Server '{server_name}' tools: {session_info.get('tools', [])}"
                )
        logger.info("MCP servers connected successfully")

    async def _wrap_mcp_client_for_preprocessing(self):
        """
        Wrap the MCP client's call_tool method to preprocess arguments.
        Specifically converts query arrays to strings for knowledge_base_retrieval.
        """
        if not hasattr(self.mcp_client, "sessions"):
            logger.warning("MCP client has no sessions attribute, skipping preprocessing wrapper")
            return
        logger.info("Wrapping MCP client for argument preprocessing...")
        # Wrap each session's call_tool method
        for server_name, session_info in self.mcp_client.sessions.items():
            if "session" in session_info:
                session = session_info["session"]
                original_call_tool = session.call_tool

                async def wrapped_call_tool(
                    tool_name: str,
                    arguments: dict,
                    _original=original_call_tool,
                    _server=server_name,
                ):
                    """Wrapper that preprocesses tool arguments before calling."""
                    # Preprocess knowledge_base_retrieval query parameter
                    if tool_name == "knowledge_base_retrieval" and "query" in arguments:
                        query_value = arguments["query"]
                        if isinstance(query_value, list):
                            # Convert array to comma-separated string
                            converted_query = ", ".join(str(item) for item in query_value)
                            arguments = {**arguments, "query": converted_query}
                            logger.info("MCP HOST PREPROCESSING: Converted query array to string")
                            logger.info(f" Server: {_server}")
                            logger.info(f" Tool: {tool_name}")
                            logger.info(f" Original: {query_value}")
                            logger.info(f" Converted: {converted_query}")
                    # Call the original method with preprocessed arguments
                    return await _original(tool_name, arguments)

                # Replace the method
                session.call_tool = wrapped_call_tool
                logger.info(f"Wrapped call_tool for server: {server_name}")

    async def handle_query(
        self,
        query: str,
        session_id: Optional[str] = None,
        user_context: Optional[Dict[str, Any]] = None,
    ) -> dict:
        """Handle a user query and return the agent's response.

        Args:
            query: The user's question
            session_id: Session identifier for conversation continuity
            user_context: Dictionary containing user information (user_id, email, name)
        """
        # Acquire lock to prevent concurrent agent execution
        async with self._agent_lock:
            try:
                logger.info(f"Agent received query: {query[:100]}...")
                logger.info(f"Session ID: {session_id}")
                logger.info(f"User context: {user_context}")
                # ============= FIX: Inject user_id and chat history into system instruction =============
                if user_context and "user_id" in user_context:
                    user_id = user_context["user_id"]
                    logger.info(f"Injecting user_id into system instruction: {user_id}")
                    # Build chat history context if available
                    chat_history_context = ""
                    if "chat_history" in user_context and user_context["chat_history"]:
                        history_messages = user_context["chat_history"]
                        logger.info(f"Including {len(history_messages)} historical messages")
                        history_text = "\n".join([
                            f"[{msg.get('timestamp', 'Unknown time')}] {msg['role'].upper()}: {msg['content'][:200]}{'...' if len(msg['content']) > 200 else ''}"
                            for msg in history_messages
                        ])
                        chat_history_context = f"""
<previous_conversation_history>
The student has had previous conversations with you. Here are their recent messages for context:
{history_text}
Remember these previous interactions when responding. Reference them if relevant, but focus on the current question.
</previous_conversation_history>
"""
                    # Create a modified system instruction with the actual user_id and history
                    modified_instruction = f"""{self.base_system_instruction}
<current_session_context>
CURRENT USER ID: {user_id}
CRITICAL: When calling knowledge_base_retrieval, you MUST use this EXACT user_id value:
user_id="{user_id}"
DO NOT use placeholder text. USE THE VALUE ABOVE.
Current user information:
- User ID: {user_id}
- Name: {user_context.get("name", "Unknown")}
- Email: {user_context.get("email", "Unknown")}
</current_session_context>
{chat_history_context}
"""
                    # Update the agent's system instruction for this query
                    self.agent.system_instruction = modified_instruction
                    logger.info("System instruction updated with user_id and chat history")
                else:
                    # Fall back to the unmodified base instruction so a previous
                    # user's context does not leak into this query
                    self.agent.system_instruction = self.base_system_instruction
                    logger.warning("No user context provided - using base instruction")
                # =======================================================================
                # Run the agent with timeout
                logger.info("Starting agent.run()...")
                result = await asyncio.wait_for(
                    self.agent.run(query, session_id), timeout=45.0
                )
                logger.info("Agent.run() completed")
                return result
            except asyncio.TimeoutError:
                logger.error("Agent.run() timed out after 45 seconds")
                return {
                    "response": "I apologize, but processing your request took too long. Please try a simpler question.",
                    "session_id": session_id or "timeout_session",
                }
            except Exception as e:
                logger.error(f"Failed to process query: {e}", exc_info=True)
                return {
                    "response": f"I apologize, but I encountered an error: {str(e)}",
                    "session_id": session_id or "error_session",
                }

    async def get_session_history(self, session_id: str) -> list[dict]:
        """Get conversation history for a session."""
        try:
            return await self.agent.get_session_history(session_id)
        except Exception as e:
            logger.error(f"Failed to get session history: {e}")
            return []

    async def clear_session_memory(self, session_id: str) -> bool:
        """Clear memory for a specific session."""
        try:
            await self.agent.clear_session_history(session_id)
            logger.info(f"Cleared memory for session: {session_id}")
            return True
        except Exception as e:
            logger.error(f"Failed to clear session memory: {e}")
            return False

    def get_agent_info(self) -> dict:
        """Get information about the agent configuration."""
        return {
            "agent_name": self.agent.name,
            "memory_store_type": "redis",
            "memory_store_info": self.memory_router.get_memory_store_info(),
            "event_store_type": self.agent.get_event_store_type(),
            "debug_mode": self.agent.debug,
        }
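

# --- Illustrative usage sketch (not the production entry point) ---
# A minimal example of how this agent might be driven from an async context.
# It assumes a reachable Redis instance for the MemoryRouter, a running RAG MCP
# server at RAG_MCP_SERVER_URL, and valid OpenAI credentials in the environment.
# The session_id and user_context values below are hypothetical placeholders.
async def _example_usage():
    agent = TutoringRagAgent()
    await agent.initialized()
    result = await agent.handle_query(
        query="Can you summarise what we covered about photosynthesis?",
        session_id="example-session-1",
        user_context={
            "user_id": "user-123",
            "name": "Ada",
            "email": "ada@example.com",
        },
    )
    print(result)


if __name__ == "__main__":
    asyncio.run(_example_usage())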