"""Tests for structured protocol implementation."""
import pytest
import json
from src.structured_protocol import (
StructuredProtocol, ProtocolOptimizer, StructuredQuery, RouterDecision,
LLMRequest, LLMResponse, MCPRequest, MCPResponse, FinalResponse,
ToolCategory, ModelLocation, MessageType
)
class TestStructuredProtocol:
"""Test the structured protocol functionality."""
@pytest.fixture
def protocol(self):
"""Create structured protocol for testing."""
return StructuredProtocol()
def test_create_query(self, protocol):
"""Test creating a structured query."""
query = protocol.create_query(
user_id="test_user",
query="Show me all servers",
context={"department": "IT"},
preferences={"format": "json"}
)
assert query.user_id == "test_user"
assert query.query == "Show me all servers"
assert query.context == {"department": "IT"}
assert query.preferences == {"format": "json"}
assert query.id is not None
assert query.timestamp is not None
def test_create_router_decision(self, protocol):
"""Test creating a router decision."""
decision = protocol.create_router_decision(
query_id="test_query_123",
model_location=ModelLocation.LOCAL,
model_name="llama-3.1-8b",
tools_needed=[ToolCategory.HOSTS, ToolCategory.VMS],
priority=1,
estimated_tokens=1000,
estimated_cost=0.01,
reasoning="Query needs host and VM data"
)
assert decision.query_id == "test_query_123"
assert decision.model_location == ModelLocation.LOCAL
assert decision.model_name == "llama-3.1-8b"
assert ToolCategory.HOSTS in decision.tools_needed
assert ToolCategory.VMS in decision.tools_needed
assert decision.priority == 1
assert decision.estimated_tokens == 1000
assert decision.estimated_cost == 0.01
def test_create_llm_request(self, protocol):
"""Test creating an LLM request."""
tools_available = [
{"name": "list_hosts", "description": "List hosts", "category": "hosts"},
{"name": "list_vms", "description": "List VMs", "category": "vms"}
]
request = protocol.create_llm_request(
query_id="test_query_123",
user_query="Show me all servers",
context={"department": "IT"},
tools_available=tools_available,
model_config={"name": "llama-3.1-8b", "location": "local"},
max_tokens=4000,
temperature=0.1,
stream=True
)
assert request.query_id == "test_query_123"
assert request.user_query == "Show me all servers"
assert request.context == {"department": "IT"}
assert request.tools_available == tools_available
assert request.max_tokens == 4000
assert request.temperature == 0.1
assert request.stream is True
def test_create_llm_response(self, protocol):
"""Test creating an LLM response."""
tool_calls = [
{"tool_name": "list_hosts", "parameters": {"limit": 10}}
]
response = protocol.create_llm_response(
query_id="test_query_123",
content="I'll query the infrastructure data for you.",
tool_calls=tool_calls,
confidence=0.9,
reasoning="User needs server information",
tokens_used=150,
processing_time=0.5,
model_used="llama-3.1-8b"
)
assert response.query_id == "test_query_123"
assert "infrastructure data" in response.content
assert response.tool_calls == tool_calls
assert response.confidence == 0.9
assert response.tokens_used == 150
assert response.processing_time == 0.5
assert response.model_used == "llama-3.1-8b"
def test_create_mcp_request(self, protocol):
"""Test creating an MCP request."""
request = protocol.create_mcp_request(
query_id="test_query_123",
tool_name="list_hosts",
parameters={"limit": 10, "include_certainty": True},
context={"department": "IT"},
priority=1,
timeout=30.0
)
assert request.query_id == "test_query_123"
assert request.tool_name == "list_hosts"
assert request.parameters == {"limit": 10, "include_certainty": True}
assert request.context == {"department": "IT"}
assert request.priority == 1
assert request.timeout == 30.0
def test_create_mcp_response(self, protocol):
"""Test creating an MCP response."""
data = [
{"id": 1, "name": "server-01", "status": "Active"},
{"id": 2, "name": "server-02", "status": "Active"}
]
response = protocol.create_mcp_response(
query_id="test_query_123",
tool_name="list_hosts",
data=data,
metadata={"source": "netbox", "count": 2},
confidence=0.95,
processing_time=0.3,
cache_hit=False
)
assert response.query_id == "test_query_123"
assert response.tool_name == "list_hosts"
assert response.data == data
assert response.metadata == {"source": "netbox", "count": 2}
assert response.confidence == 0.95
assert response.processing_time == 0.3
assert response.cache_hit is False
def test_create_final_response(self, protocol):
"""Test creating a final response."""
data = [{"id": 1, "name": "server-01"}]
sources = [{"tool": "list_hosts", "count": 1, "confidence": 0.95}]
response = protocol.create_final_response(
query_id="test_query_123",
user_id="test_user",
answer="Here are the servers you requested.",
data=data,
sources=sources,
confidence=0.95,
processing_time=1.0,
model_used="llama-3.1-8b",
tools_used=["list_hosts"],
cost=0.01
)
assert response.query_id == "test_query_123"
assert response.user_id == "test_user"
assert "servers you requested" in response.answer
assert response.data == data
assert response.sources == sources
assert response.confidence == 0.95
assert response.processing_time == 1.0
assert response.model_used == "llama-3.1-8b"
assert response.tools_used == ["list_hosts"]
assert response.cost == 0.01
def test_serialize_deserialize(self, protocol):
"""Test serialization and deserialization."""
# Create a query
query = protocol.create_query("user123", "Test query")
# Serialize
json_str = protocol.serialize_message(query)
assert isinstance(json_str, str)
# Deserialize
deserialized = protocol.deserialize_message(json_str, "query")
assert isinstance(deserialized, StructuredQuery)
assert deserialized.user_id == query.user_id
assert deserialized.query == query.query
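    # A hedged follow-on sketch: it assumes deserialize_message preserves every
    # field it was given, so re-serializing the result yields an equivalent JSON
    # document. The dict-equality check below is illustrative.
    def test_serialize_roundtrip_is_stable(self, protocol):
        """Sketch: serialization should be stable across a deserialize/serialize cycle."""
        query = protocol.create_query("user123", "Test query")
        first = protocol.serialize_message(query)
        second = protocol.serialize_message(
            protocol.deserialize_message(first, "query")
        )
        assert json.loads(first) == json.loads(second)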
def test_message_handlers(self, protocol):
"""Test message handlers."""
# Test query handler
query = protocol.create_query("user123", "Test query")
result = protocol._handle_query(query)
assert result["type"] == "query_handled"
assert result["query_id"] == query.id
# Test response handler
        response = protocol.create_final_response(
            "test123", "user123", "Test answer", [], [], 0.9, 1.0, "test-model", [], cost=0.0
        )
result = protocol._handle_response(response)
assert result["type"] == "response_handled"
assert result["query_id"] == response.query_id
# Test error handler
error_msg = {"error": "Test error"}
result = protocol._handle_error(error_msg)
assert result["type"] == "error_handled"
assert result["error"] == "Test error"
class TestProtocolOptimizer:
"""Test the protocol optimizer functionality."""
@pytest.fixture
def optimizer(self):
"""Create protocol optimizer for testing."""
return ProtocolOptimizer()
def test_optimize_router_decision_simple_query(self, optimizer):
"""Test router decision for simple query."""
query = StructuredQuery(
id="test123",
user_id="user123",
query="show me all servers",
context={},
preferences={}
)
decision = optimizer.optimize_router_decision(query)
assert decision.query_id == "test123"
assert decision.model_location == ModelLocation.LOCAL # Simple query -> local
assert "llama" in decision.model_name.lower()
assert ToolCategory.HOSTS in decision.tools_needed
assert decision.estimated_tokens > 0
assert decision.estimated_cost >= 0
def test_optimize_router_decision_complex_query(self, optimizer):
"""Test router decision for complex query."""
query = StructuredQuery(
id="test123",
user_id="user123",
query="complex detailed analysis of all infrastructure components",
context={},
preferences={}
)
decision = optimizer.optimize_router_decision(query)
assert decision.query_id == "test123"
assert decision.model_location == ModelLocation.CLOUD # Complex query -> cloud
assert "gpt" in decision.model_name.lower()
        assert isinstance(decision.tools_needed, list)  # Complex queries might not need specific tools
assert decision.estimated_tokens > 0
def test_optimize_router_decision_search_query(self, optimizer):
"""Test router decision for search query."""
query = StructuredQuery(
id="test123",
user_id="user123",
query="search for web servers in production",
context={},
preferences={}
)
decision = optimizer.optimize_router_decision(query)
assert decision.query_id == "test123"
assert ToolCategory.HOSTS in decision.tools_needed
assert ToolCategory.SEARCH in decision.tools_needed
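    # A hedged, table-driven variant of the routing checks above: the expected
    # categories restate what the individual tests already assert, and the
    # query strings are the same ones used there.
    @pytest.mark.parametrize("query_text,expected_category", [
        ("show me all servers", ToolCategory.HOSTS),
        ("search for web servers in production", ToolCategory.SEARCH),
    ])
    def test_optimize_router_decision_tool_selection(self, optimizer, query_text, expected_category):
        """Sketch: keyword-driven queries should include the matching tool category."""
        query = StructuredQuery(
            id="test123",
            user_id="user123",
            query=query_text,
            context={},
            preferences={},
        )
        decision = optimizer.optimize_router_decision(query)
        assert expected_category in decision.tools_needed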
def test_optimize_llm_request(self, optimizer):
"""Test LLM request optimization."""
query = StructuredQuery(
id="test123",
user_id="user123",
query="show me all servers",
context={"department": "IT"},
preferences={}
)
decision = RouterDecision(
query_id="test123",
model_location=ModelLocation.LOCAL,
model_name="llama-3.1-8b",
tools_needed=[ToolCategory.HOSTS],
priority=1,
estimated_tokens=100,
estimated_cost=0.01,
reasoning="Test reasoning"
)
request = optimizer.optimize_llm_request(query, decision)
assert request.query_id == "test123"
assert request.user_query == "show me all servers"
assert request.context == {"department": "IT"}
assert len(request.tools_available) > 0
assert request.model_config["location"] == "local"
assert request.model_config["name"] == "llama-3.1-8b"
def test_get_tools_for_categories(self, optimizer):
"""Test getting tools for specific categories."""
categories = [ToolCategory.HOSTS, ToolCategory.VMS]
tools = optimizer._get_tools_for_categories(categories)
assert len(tools) > 0
tool_names = [tool["name"] for tool in tools]
assert any("host" in name for name in tool_names)
assert any("vm" in name for name in tool_names)
def test_update_metrics(self, optimizer):
"""Test metrics update."""
initial_queries = optimizer.metrics["queries_processed"]
optimizer.update_metrics(
processing_time=1.0,
tokens_used=100,
cost=0.01,
cache_hit=True
)
assert optimizer.metrics["queries_processed"] == initial_queries + 1
assert optimizer.metrics["cache_hits"] == 1
assert optimizer.metrics["total_tokens_used"] == 100
assert optimizer.metrics["total_cost"] == 0.01
def test_get_metrics(self, optimizer):
"""Test getting metrics."""
metrics = optimizer.get_metrics()
assert "queries_processed" in metrics
assert "cache_hits" in metrics
assert "average_processing_time" in metrics
assert "total_tokens_used" in metrics
assert "total_cost" in metrics
class TestDataStructures:
"""Test the data structures."""
def test_structured_query_to_dict(self):
"""Test StructuredQuery to_dict conversion."""
query = StructuredQuery(
id="test123",
user_id="user123",
query="Test query",
context={"key": "value"},
preferences={"format": "json"}
)
data = query.to_dict()
assert data["id"] == "test123"
assert data["user_id"] == "user123"
assert data["query"] == "Test query"
assert data["context"] == {"key": "value"}
assert data["preferences"] == {"format": "json"}
def test_router_decision_to_dict(self):
"""Test RouterDecision to_dict conversion."""
decision = RouterDecision(
query_id="test123",
model_location=ModelLocation.LOCAL,
model_name="llama-3.1-8b",
tools_needed=[ToolCategory.HOSTS, ToolCategory.VMS],
priority=1,
estimated_tokens=100,
estimated_cost=0.01,
reasoning="Test reasoning"
)
data = decision.to_dict()
assert data["query_id"] == "test123"
assert data["model_location"] == "local"
assert data["model_name"] == "llama-3.1-8b"
assert data["tools_needed"] == ["hosts", "virtual_machines"]
assert data["priority"] == 1
assert data["estimated_tokens"] == 100
assert data["estimated_cost"] == 0.01
def test_llm_request_to_dict(self):
"""Test LLMRequest to_dict conversion."""
request = LLMRequest(
query_id="test123",
user_query="Test query",
context={"key": "value"},
tools_available=[{"name": "test_tool", "category": "test"}],
model_config={"name": "test-model"},
max_tokens=1000,
temperature=0.1,
stream=True
)
data = request.to_dict()
assert data["query_id"] == "test123"
assert data["user_query"] == "Test query"
assert data["context"] == {"key": "value"}
assert data["max_tokens"] == 1000
assert data["temperature"] == 0.1
assert data["stream"] is True
def test_mcp_response_to_dict(self):
"""Test MCPResponse to_dict conversion."""
response = MCPResponse(
query_id="test123",
tool_name="test_tool",
data=[{"id": 1, "name": "test"}],
metadata={"source": "test"},
confidence=0.9,
processing_time=1.0,
cache_hit=True
)
data = response.to_dict()
assert data["query_id"] == "test123"
assert data["tool_name"] == "test_tool"
assert data["data"] == [{"id": 1, "name": "test"}]
assert data["metadata"] == {"source": "test"}
assert data["confidence"] == 0.9
assert data["processing_time"] == 1.0
assert data["cache_hit"] is True
def test_final_response_to_dict(self):
"""Test FinalResponse to_dict conversion."""
response = FinalResponse(
query_id="test123",
user_id="user123",
answer="Test answer",
data=[{"id": 1}],
sources=[{"tool": "test"}],
confidence=0.9,
processing_time=1.0,
model_used="test-model",
tools_used=["test_tool"],
cost=0.01
)
data = response.to_dict()
assert data["query_id"] == "test123"
assert data["user_id"] == "user123"
assert data["answer"] == "Test answer"
assert data["data"] == [{"id": 1}]
assert data["sources"] == [{"tool": "test"}]
assert data["confidence"] == 0.9
assert data["processing_time"] == 1.0
assert data["model_used"] == "test-model"
assert data["tools_used"] == ["test_tool"]
assert data["cost"] == 0.01
class TestEnums:
"""Test the enum classes."""
def test_message_type_enum(self):
"""Test MessageType enum."""
assert MessageType.QUERY.value == "query"
assert MessageType.RESPONSE.value == "response"
assert MessageType.ERROR.value == "error"
assert MessageType.HEARTBEAT.value == "heartbeat"
assert MessageType.METADATA.value == "metadata"
def test_tool_category_enum(self):
"""Test ToolCategory enum."""
assert ToolCategory.HOSTS.value == "hosts"
assert ToolCategory.VMS.value == "virtual_machines"
assert ToolCategory.IPS.value == "ip_addresses"
assert ToolCategory.VLANS.value == "vlans"
assert ToolCategory.SEARCH.value == "search"
def test_model_location_enum(self):
"""Test ModelLocation enum."""
assert ModelLocation.LOCAL.value == "local"
assert ModelLocation.CLOUD.value == "cloud"
assert ModelLocation.EDGE.value == "edge"
assert ModelLocation.HYBRID.value == "hybrid"
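# A hedged sketch assuming ToolCategory, ModelLocation, and MessageType are
# standard enum.Enum subclasses, so looking members up by value returns the
# same members the assertions above compare against.
def test_enum_value_lookup_sketch():
    """Sketch: enum members should be recoverable from their string values."""
    assert ToolCategory("virtual_machines") is ToolCategory.VMS
    assert ModelLocation("cloud") is ModelLocation.CLOUD
    assert MessageType("heartbeat") is MessageType.HEARTBEAT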