# structured_protocol.py
"""Structured JSON protocol for optimized LLM-MCP communication.
This module defines a structured protocol for communication between:
User -> Gateway -> Router -> LLM -> MCP -> Gateway -> User
With full control over both sides, we can optimize the entire data flow.
"""
import json
import logging
import time
import uuid
from dataclasses import dataclass, asdict
from enum import Enum
from typing import Any, Dict, List, Optional, Union
logger = logging.getLogger(__name__)
class MessageType(Enum):
"""Types of messages in the structured protocol."""
QUERY = "query"
RESPONSE = "response"
ERROR = "error"
HEARTBEAT = "heartbeat"
METADATA = "metadata"
class ToolCategory(Enum):
"""Categories of MCP tools."""
HOSTS = "hosts"
VMS = "virtual_machines"
IPS = "ip_addresses"
VLANS = "vlans"
SEARCH = "search"
class ModelLocation(Enum):
"""Location of the LLM model."""
LOCAL = "local"
CLOUD = "cloud"
EDGE = "edge"
HYBRID = "hybrid"
@dataclass
class StructuredQuery:
"""Structured query from user to gateway."""
id: str
user_id: str
query: str
context: Optional[Dict[str, Any]] = None
preferences: Optional[Dict[str, Any]] = None
    timestamp: Optional[float] = None
def __post_init__(self):
if self.timestamp is None:
self.timestamp = time.time()
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary for JSON serialization."""
return asdict(self)
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> 'StructuredQuery':
"""Create from dictionary."""
return cls(**data)
@dataclass
class RouterDecision:
"""Router's decision about model and tool usage."""
query_id: str
model_location: ModelLocation
model_name: str
tools_needed: List[ToolCategory]
priority: int = 1
estimated_tokens: int = 0
estimated_cost: float = 0.0
reasoning: str = ""
    timestamp: Optional[float] = None
def __post_init__(self):
if self.timestamp is None:
self.timestamp = time.time()
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary for JSON serialization."""
data = asdict(self)
data['model_location'] = self.model_location.value
data['tools_needed'] = [tool.value for tool in self.tools_needed]
return data
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> 'RouterDecision':
"""Create from dictionary."""
data['model_location'] = ModelLocation(data['model_location'])
data['tools_needed'] = [ToolCategory(tool) for tool in data['tools_needed']]
return cls(**data)
@dataclass
class LLMRequest:
"""Structured request from router to LLM."""
query_id: str
user_query: str
context: Dict[str, Any]
tools_available: List[Dict[str, Any]]
model_config: Dict[str, Any]
max_tokens: int = 4000
temperature: float = 0.1
stream: bool = False
    timestamp: Optional[float] = None
def __post_init__(self):
if self.timestamp is None:
self.timestamp = time.time()
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary for JSON serialization."""
return asdict(self)
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> 'LLMRequest':
"""Create from dictionary."""
return cls(**data)
@dataclass
class LLMResponse:
"""Structured response from LLM."""
query_id: str
content: str
tool_calls: List[Dict[str, Any]]
confidence: float
reasoning: str
tokens_used: int
processing_time: float
model_used: str
    timestamp: Optional[float] = None
def __post_init__(self):
if self.timestamp is None:
self.timestamp = time.time()
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary for JSON serialization."""
return asdict(self)
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> 'LLMResponse':
"""Create from dictionary."""
return cls(**data)
@dataclass
class MCPRequest:
"""Structured request to MCP server."""
query_id: str
tool_name: str
parameters: Dict[str, Any]
context: Dict[str, Any]
priority: int = 1
timeout: float = 30.0
    timestamp: Optional[float] = None
def __post_init__(self):
if self.timestamp is None:
self.timestamp = time.time()
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary for JSON serialization."""
return asdict(self)
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> 'MCPRequest':
"""Create from dictionary."""
return cls(**data)
@dataclass
class MCPResponse:
"""Structured response from MCP server."""
query_id: str
tool_name: str
data: List[Dict[str, Any]]
metadata: Dict[str, Any]
confidence: float
processing_time: float
cache_hit: bool = False
    timestamp: Optional[float] = None
def __post_init__(self):
if self.timestamp is None:
self.timestamp = time.time()
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary for JSON serialization."""
return asdict(self)
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> 'MCPResponse':
"""Create from dictionary."""
return cls(**data)
@dataclass
class FinalResponse:
"""Final structured response to user."""
query_id: str
user_id: str
answer: str
data: List[Dict[str, Any]]
sources: List[Dict[str, Any]]
confidence: float
processing_time: float
model_used: str
tools_used: List[str]
cost: float = 0.0
    timestamp: Optional[float] = None
def __post_init__(self):
if self.timestamp is None:
self.timestamp = time.time()
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary for JSON serialization."""
return asdict(self)
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> 'FinalResponse':
"""Create from dictionary."""
return cls(**data)
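
# Illustrative sketch only: a quick demonstration of how the message dataclasses
# above serialize. RouterDecision is the interesting case because to_dict()
# flattens its enum fields to plain strings and from_dict() restores them, while
# __post_init__ fills in the timestamp automatically. The identifiers used here
# ("q-demo", the tool list) are placeholder values, not fixtures from this module.
def _example_router_decision_round_trip() -> RouterDecision:
    """Round-trip a RouterDecision through its dict representation."""
    decision = RouterDecision(
        query_id="q-demo",
        model_location=ModelLocation.LOCAL,
        model_name="llama-3.1-8b",
        tools_needed=[ToolCategory.HOSTS, ToolCategory.VMS],
    )
    payload = decision.to_dict()  # JSON-safe: enum fields become plain strings
    restored = RouterDecision.from_dict(payload)
    assert restored.model_location is ModelLocation.LOCAL
    assert restored.timestamp == decision.timestamp
    return restored
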
class StructuredProtocol:
"""Main protocol handler for structured communication."""
def __init__(self):
self.message_handlers = {
MessageType.QUERY: self._handle_query,
MessageType.RESPONSE: self._handle_response,
MessageType.ERROR: self._handle_error,
MessageType.HEARTBEAT: self._handle_heartbeat,
MessageType.METADATA: self._handle_metadata,
}
def create_query(self, user_id: str, query: str,
context: Optional[Dict[str, Any]] = None,
preferences: Optional[Dict[str, Any]] = None) -> StructuredQuery:
"""Create a structured query."""
return StructuredQuery(
id=str(uuid.uuid4()),
user_id=user_id,
query=query,
context=context or {},
preferences=preferences or {}
)
def create_router_decision(self, query_id: str,
model_location: ModelLocation,
model_name: str,
tools_needed: List[ToolCategory],
priority: int = 1,
estimated_tokens: int = 0,
estimated_cost: float = 0.0,
reasoning: str = "") -> RouterDecision:
"""Create a router decision."""
return RouterDecision(
query_id=query_id,
model_location=model_location,
model_name=model_name,
tools_needed=tools_needed,
priority=priority,
estimated_tokens=estimated_tokens,
estimated_cost=estimated_cost,
reasoning=reasoning
)
def create_llm_request(self, query_id: str, user_query: str,
context: Dict[str, Any],
tools_available: List[Dict[str, Any]],
model_config: Dict[str, Any],
max_tokens: int = 4000,
temperature: float = 0.1,
stream: bool = False) -> LLMRequest:
"""Create an LLM request."""
return LLMRequest(
query_id=query_id,
user_query=user_query,
context=context,
tools_available=tools_available,
model_config=model_config,
max_tokens=max_tokens,
temperature=temperature,
stream=stream
)
def create_llm_response(self, query_id: str, content: str,
tool_calls: List[Dict[str, Any]],
confidence: float,
reasoning: str,
tokens_used: int,
processing_time: float,
model_used: str) -> LLMResponse:
"""Create an LLM response."""
return LLMResponse(
query_id=query_id,
content=content,
tool_calls=tool_calls,
confidence=confidence,
reasoning=reasoning,
tokens_used=tokens_used,
processing_time=processing_time,
model_used=model_used
)
def create_mcp_request(self, query_id: str, tool_name: str,
parameters: Dict[str, Any],
context: Dict[str, Any],
priority: int = 1,
timeout: float = 30.0) -> MCPRequest:
"""Create an MCP request."""
return MCPRequest(
query_id=query_id,
tool_name=tool_name,
parameters=parameters,
context=context,
priority=priority,
timeout=timeout
)
def create_mcp_response(self, query_id: str, tool_name: str,
data: List[Dict[str, Any]],
metadata: Dict[str, Any],
confidence: float,
processing_time: float,
cache_hit: bool = False) -> MCPResponse:
"""Create an MCP response."""
return MCPResponse(
query_id=query_id,
tool_name=tool_name,
data=data,
metadata=metadata,
confidence=confidence,
processing_time=processing_time,
cache_hit=cache_hit
)
def create_final_response(self, query_id: str, user_id: str,
answer: str,
data: List[Dict[str, Any]],
sources: List[Dict[str, Any]],
confidence: float,
processing_time: float,
model_used: str,
tools_used: List[str],
cost: float = 0.0) -> FinalResponse:
"""Create a final response."""
return FinalResponse(
query_id=query_id,
user_id=user_id,
answer=answer,
data=data,
sources=sources,
confidence=confidence,
processing_time=processing_time,
model_used=model_used,
tools_used=tools_used,
cost=cost
)
def serialize_message(self, message: Union[StructuredQuery, RouterDecision,
LLMRequest, LLMResponse,
MCPRequest, MCPResponse,
FinalResponse]) -> str:
"""Serialize a message to JSON."""
return json.dumps(message.to_dict(), indent=2, default=str)
    def deserialize_message(self, json_str: str, message_type: str) -> Any:
        """Deserialize a JSON message into the matching dataclass."""
        message_classes = {
            "query": StructuredQuery,
            "router_decision": RouterDecision,
            "llm_request": LLMRequest,
            "llm_response": LLMResponse,
            "mcp_request": MCPRequest,
            "mcp_response": MCPResponse,
            "final_response": FinalResponse,
        }
        if message_type not in message_classes:
            raise ValueError(f"Unknown message type: {message_type}")
        return message_classes[message_type].from_dict(json.loads(json_str))
def _handle_query(self, message: StructuredQuery) -> Dict[str, Any]:
"""Handle a query message."""
return {
"type": "query_handled",
"query_id": message.id,
"status": "processing"
}
def _handle_response(self, message: FinalResponse) -> Dict[str, Any]:
"""Handle a response message."""
return {
"type": "response_handled",
"query_id": message.query_id,
"status": "completed"
}
def _handle_error(self, message: Dict[str, Any]) -> Dict[str, Any]:
"""Handle an error message."""
return {
"type": "error_handled",
"error": message.get("error", "Unknown error"),
"status": "failed"
}
def _handle_heartbeat(self, message: Dict[str, Any]) -> Dict[str, Any]:
"""Handle a heartbeat message."""
return {
"type": "heartbeat_handled",
"status": "alive",
"timestamp": time.time()
}
def _handle_metadata(self, message: Dict[str, Any]) -> Dict[str, Any]:
"""Handle a metadata message."""
return {
"type": "metadata_handled",
"metadata": message,
"status": "processed"
}
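
# Illustrative sketch only: a minimal example of how StructuredProtocol might be
# used on the gateway side, where a query is built, serialized to JSON for
# transport, and deserialized again on the receiving hop. The user id and query
# text below are placeholders, not values defined by this module.
def _example_protocol_round_trip() -> StructuredQuery:
    """Create a query, serialize it, and deserialize it back."""
    protocol = StructuredProtocol()
    query = protocol.create_query(
        user_id="demo-user",
        query="List all hosts in VLAN 20",
        preferences={"format": "table"},
    )
    wire_payload = protocol.serialize_message(query)  # JSON string for transport
    restored = protocol.deserialize_message(wire_payload, "query")
    assert restored.id == query.id
    return restored
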
class ProtocolOptimizer:
"""Optimizer for the structured protocol."""
def __init__(self):
self.protocol = StructuredProtocol()
self.cache = {}
self.metrics = {
"queries_processed": 0,
"cache_hits": 0,
"average_processing_time": 0.0,
"total_tokens_used": 0,
"total_cost": 0.0
}
def optimize_router_decision(self, query: StructuredQuery) -> RouterDecision:
"""Optimize router decision based on query analysis."""
# Analyze query to determine best model and tools
query_lower = query.query.lower()
# Determine model location based on query complexity
if any(word in query_lower for word in ["complex", "detailed", "analysis", "report"]):
model_location = ModelLocation.CLOUD
model_name = "gpt-4-turbo"
elif any(word in query_lower for word in ["quick", "simple", "list", "show"]):
model_location = ModelLocation.LOCAL
model_name = "llama-3.1-8b"
else:
model_location = ModelLocation.HYBRID
model_name = "gpt-3.5-turbo"
# Determine tools needed based on query content
tools_needed = []
if any(word in query_lower for word in ["host", "server", "device", "machine"]):
tools_needed.append(ToolCategory.HOSTS)
if any(word in query_lower for word in ["vm", "virtual", "container", "instance"]):
tools_needed.append(ToolCategory.VMS)
if any(word in query_lower for word in ["ip", "address", "network", "subnet"]):
tools_needed.append(ToolCategory.IPS)
if any(word in query_lower for word in ["vlan", "segment", "broadcast"]):
tools_needed.append(ToolCategory.VLANS)
if any(word in query_lower for word in ["search", "find", "look", "locate"]):
tools_needed.append(ToolCategory.SEARCH)
# Estimate tokens and cost
estimated_tokens = len(query.query.split()) * 2 # Rough estimation
estimated_cost = estimated_tokens * 0.0001 # Rough cost estimation
return self.protocol.create_router_decision(
query_id=query.id,
model_location=model_location,
model_name=model_name,
tools_needed=tools_needed,
priority=1,
estimated_tokens=estimated_tokens,
estimated_cost=estimated_cost,
reasoning=f"Query analysis: {len(query.query)} chars, {len(tools_needed)} tools needed"
)
def optimize_llm_request(self, query: StructuredQuery,
decision: RouterDecision) -> LLMRequest:
"""Optimize LLM request based on router decision."""
# Get available tools for the selected categories
tools_available = self._get_tools_for_categories(decision.tools_needed)
# Configure model based on location
model_config = {
"location": decision.model_location.value,
"name": decision.model_name,
"max_tokens": decision.estimated_tokens + 1000, # Add buffer
"temperature": 0.1 if "analysis" in query.query.lower() else 0.3,
"stream": decision.model_location == ModelLocation.LOCAL
}
return self.protocol.create_llm_request(
query_id=query.id,
user_query=query.query,
context=query.context,
tools_available=tools_available,
model_config=model_config,
max_tokens=model_config["max_tokens"],
temperature=model_config["temperature"],
stream=model_config["stream"]
)
def _get_tools_for_categories(self, categories: List[ToolCategory]) -> List[Dict[str, Any]]:
"""Get available tools for the specified categories."""
tools = []
for category in categories:
if category == ToolCategory.HOSTS:
tools.extend([
{"name": "list_hosts", "description": "List all hosts", "category": "hosts"},
{"name": "get_host", "description": "Get specific host", "category": "hosts"},
{"name": "search_hosts", "description": "Search hosts", "category": "hosts"}
])
elif category == ToolCategory.VMS:
tools.extend([
{"name": "list_vms", "description": "List all VMs", "category": "vms"},
{"name": "get_vm", "description": "Get specific VM", "category": "vms"},
{"name": "list_vm_interfaces", "description": "List VM interfaces", "category": "vms"}
])
elif category == ToolCategory.IPS:
tools.extend([
{"name": "list_ips", "description": "List all IPs", "category": "ips"},
{"name": "get_ip", "description": "Get specific IP", "category": "ips"},
{"name": "search_ips", "description": "Search IPs", "category": "ips"}
])
elif category == ToolCategory.VLANS:
tools.extend([
{"name": "list_vlans", "description": "List all VLANs", "category": "vlans"},
{"name": "get_vlan", "description": "Get specific VLAN", "category": "vlans"},
{"name": "list_vlan_ips", "description": "List VLAN IPs", "category": "vlans"}
])
        # Note: ToolCategory.SEARCH has no dedicated entries here, so a decision
        # that needs only SEARCH currently yields an empty tool list.
        return tools
def update_metrics(self, processing_time: float, tokens_used: int, cost: float, cache_hit: bool = False):
"""Update performance metrics."""
self.metrics["queries_processed"] += 1
if cache_hit:
self.metrics["cache_hits"] += 1
# Update average processing time
current_avg = self.metrics["average_processing_time"]
total_queries = self.metrics["queries_processed"]
self.metrics["average_processing_time"] = ((current_avg * (total_queries - 1)) + processing_time) / total_queries
self.metrics["total_tokens_used"] += tokens_used
self.metrics["total_cost"] += cost
def get_metrics(self) -> Dict[str, Any]:
"""Get current performance metrics."""
return self.metrics.copy()
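
# Minimal end-to-end sketch, not invoked by the protocol itself: exercise the
# optimizer path (query -> router decision -> LLM request) and print the
# serialized decision. The query text below is an arbitrary example.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    optimizer = ProtocolOptimizer()
    demo_query = optimizer.protocol.create_query(
        user_id="demo-user",
        query="Show a detailed report of all VMs and their IP addresses",
    )
    demo_decision = optimizer.optimize_router_decision(demo_query)
    demo_request = optimizer.optimize_llm_request(demo_query, demo_decision)
    print(optimizer.protocol.serialize_message(demo_decision))
    print(f"Offering {len(demo_request.tools_available)} tools to {demo_decision.model_name}")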