"""Tiny LLM for testing MCP server performance.
This module provides a minimal mock LLM implementation for testing purposes.
It simulates model behavior (latency, token counts, confidence, tool calls)
and is lightweight enough to run in GitHub Actions CI.
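
Example (minimal usage sketch)::

    import asyncio

    client = TinyLLMClient()
    result = asyncio.run(
        client.chat_completion([{"role": "user", "content": "List all hosts"}])
    )
    print(result["choices"][0]["message"]["content"])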
"""
import asyncio
import logging
import random
import time
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
logger = logging.getLogger(__name__)
@dataclass
class TinyLLMResponse:
"""Response from the tiny LLM."""
content: str
tokens_used: int
processing_time: float
model_name: str
confidence: float
class TinyLLM:
"""Tiny LLM implementation for testing purposes.
This is a mock LLM that simulates real LLM behavior with:
- Configurable response times (to simulate slow models)
- Token counting
- Confidence scoring
- Tool calling simulation
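
    Example (minimal sketch)::

        llm = TinyLLM(base_delay=0.1, max_delay=0.5)
        response = await llm.generate_response("Find all virtual machines")
        print(response.content, response.tokens_used, response.confidence)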
"""
def __init__(
self,
model_name: str = "tiny-test-llm",
base_delay: float = 0.5,
max_delay: float = 2.0,
token_per_char: float = 0.25,
):
"""Initialize the tiny LLM.
Args:
model_name: Name of the model
base_delay: Base delay in seconds for response simulation
max_delay: Maximum delay in seconds
token_per_char: Tokens per character ratio for estimation
"""
self.model_name = model_name
self.base_delay = base_delay
self.max_delay = max_delay
self.token_per_char = token_per_char
self.logger = logging.getLogger(self.__class__.__name__)
# Simple response templates for different types of queries
self.response_templates = {
"hosts": [
"I found {count} host(s) in your infrastructure. Here are the details:",
"Based on your query, I located {count} server(s) that match your criteria:",
"Here are the {count} host(s) I discovered in your network:",
],
"vms": [
"I identified {count} virtual machine(s) in your environment:",
"Found {count} VM(s) that match your search criteria:",
"Here are the {count} virtual machine(s) I found:",
],
"ips": [
"I located {count} IP address(es) in your network:",
"Found {count} IP(s) that match your query:",
"Here are the {count} IP address(es) I discovered:",
],
"vlans": [
"I found {count} VLAN(s) in your infrastructure:",
"Located {count} VLAN(s) that match your criteria:",
"Here are the {count} VLAN(s) I identified:",
],
"search": [
"I searched your infrastructure and found {count} matching item(s):",
"Based on your search query, I located {count} result(s):",
"Here are the {count} item(s) that match your search:",
],
"default": [
"I processed your request and found {count} relevant item(s):",
"Based on your query, I located {count} matching result(s):",
"Here are the {count} item(s) I found for you:",
],
}
async def generate_response(
self,
prompt: str,
tools: Optional[List[Dict[str, Any]]] = None,
max_tokens: int = 100,
) -> TinyLLMResponse:
"""Generate a response from the tiny LLM.
Args:
prompt: Input prompt
tools: Available tools (optional)
            max_tokens: Maximum tokens to generate (accepted for API
                compatibility; not enforced by this mock)
Returns:
TinyLLMResponse with generated content
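
        Example (sketch with an illustrative tool definition)::

            tools = [{"name": "list_hosts", "category": "host"}]
            response = await llm.generate_response("Show all hosts", tools=tools)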
"""
start_time = time.time()
# Simulate processing delay (like a real slow LLM)
delay = self._calculate_delay(prompt, tools)
await self._simulate_delay(delay)
# Generate response content
content = self._generate_content(prompt, tools)
# Calculate tokens used
tokens_used = int(len(content) * self.token_per_char)
# Calculate confidence based on prompt complexity
confidence = self._calculate_confidence(prompt, tools)
processing_time = time.time() - start_time
self.logger.info(
f"TinyLLM generated response: {tokens_used} tokens, {processing_time:.2f}s"
)
return TinyLLMResponse(
content=content,
tokens_used=tokens_used,
processing_time=processing_time,
model_name=self.model_name,
confidence=confidence,
)
def _calculate_delay(
self, prompt: str, tools: Optional[List[Dict[str, Any]]]
) -> float:
"""Calculate response delay based on prompt complexity."""
base_delay = self.base_delay
# Add delay based on prompt length
length_factor = len(prompt) / 1000 # 1 second per 1000 chars
# Add delay based on number of tools
tool_factor = len(tools) * 0.1 if tools else 0
# Add random variation
random_factor = random.uniform(0.5, 1.5)
total_delay = (
base_delay + length_factor + tool_factor
) * random_factor
return min(total_delay, self.max_delay)
async def _simulate_delay(self, delay: float):
"""Simulate processing delay."""
await asyncio.sleep(delay)
def _generate_content(
self, prompt: str, tools: Optional[List[Dict[str, Any]]]
) -> str:
"""Generate response content based on prompt."""
# Determine query type
query_type = self._classify_query(prompt)
# Get response template
templates = self.response_templates.get(
query_type, self.response_templates["default"]
)
template = random.choice(templates)
# Generate mock data count
count = random.randint(1, 10)
# Generate tool calls if tools are available
tool_calls = []
if tools:
# Select 1-3 random tools to "call"
selected_tools = random.sample(
tools, min(len(tools), random.randint(1, 3))
)
for tool in selected_tools:
tool_calls.append(
{
"tool_name": tool["name"],
"parameters": {
"limit": random.randint(5, 20),
"include_certainty": True,
},
"reasoning": f"Need to query {tool.get('category', 'data')} for user request",
}
)
# Build response
response = template.format(count=count)
if tool_calls:
response += "\n\nI'll use the following tools to gather detailed information:\n"
for i, tool_call in enumerate(tool_calls, 1):
response += f"{i}. {tool_call['tool_name']} - {tool_call['reasoning']}\n"
# Add some additional context
response += f"\n\nThis response was generated by {self.model_name} and contains {count} relevant items from your infrastructure."
return response
def _classify_query(self, prompt: str) -> str:
"""Classify the query type based on keywords."""
prompt_lower = prompt.lower()
if any(
word in prompt_lower
for word in ["host", "server", "device", "machine"]
):
return "hosts"
elif any(
word in prompt_lower
for word in ["vm", "virtual", "container", "instance"]
):
return "vms"
elif any(
word in prompt_lower
for word in ["ip", "address", "network", "subnet"]
):
return "ips"
elif any(
word in prompt_lower for word in ["vlan", "segment", "broadcast"]
):
return "vlans"
elif any(
word in prompt_lower
for word in ["search", "find", "look", "locate"]
):
return "search"
else:
return "default"
def _calculate_confidence(
self, prompt: str, tools: Optional[List[Dict[str, Any]]]
) -> float:
"""Calculate confidence score based on prompt and tools."""
base_confidence = 0.7
# Increase confidence for longer prompts (more context)
length_bonus = min(len(prompt) / 1000, 0.2)
# Increase confidence if tools are available
tool_bonus = 0.1 if tools else 0
# Add some randomness
random_factor = random.uniform(0.9, 1.1)
confidence = (
base_confidence + length_bonus + tool_bonus
) * random_factor
return min(confidence, 1.0)
async def generate_tool_calls(
self, prompt: str, tools: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
"""Generate tool calls based on prompt and available tools."""
# Simulate delay
await asyncio.sleep(random.uniform(0.1, 0.3))
# Select tools based on prompt content
selected_tools = []
query_type = self._classify_query(prompt)
for tool in tools:
tool_category = tool.get("category", "")
if (
(query_type == "hosts" and "host" in tool_category)
or (query_type == "vms" and "vm" in tool_category)
or (query_type == "ips" and "ip" in tool_category)
or (query_type == "vlans" and "vlan" in tool_category)
or (query_type == "search")
):
selected_tools.append(tool)
# Limit to 3 tools max
selected_tools = selected_tools[:3]
# Generate tool calls
tool_calls = []
for tool in selected_tools:
tool_calls.append(
{
"tool_name": tool["name"],
"parameters": {
"limit": random.randint(5, 20),
"include_certainty": True,
},
"reasoning": f"Need to query {tool.get('category', 'data')} for user request",
}
)
return tool_calls
class TinyLLMClient:
"""Client for interacting with the tiny LLM."""
def __init__(self, model_name: str = "tiny-test-llm"):
"""Initialize the tiny LLM client."""
self.llm = TinyLLM(model_name=model_name)
self.logger = logging.getLogger(self.__class__.__name__)
async def chat_completion(
self,
messages: List[Dict[str, str]],
tools: Optional[List[Dict[str, Any]]] = None,
max_tokens: int = 100,
) -> Dict[str, Any]:
"""Simulate OpenAI chat completion API."""
# Extract user message
user_message = ""
for message in messages:
if message.get("role") == "user":
user_message = message.get("content", "")
break
if not user_message:
raise ValueError("No user message found in messages")
# Generate response
response = await self.llm.generate_response(
prompt=user_message, tools=tools, max_tokens=max_tokens
)
# Generate tool calls
tool_calls = []
if tools:
tool_calls = await self.llm.generate_tool_calls(
user_message, tools
)
return {
"id": f"chatcmpl-{random.randint(100000, 999999)}",
"object": "chat.completion",
"created": int(time.time()),
"model": self.llm.model_name,
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": response.content,
"tool_calls": tool_calls,
},
"finish_reason": "tool_calls" if tool_calls else "stop",
}
],
"usage": {
"prompt_tokens": int(len(user_message) * 0.25),
"completion_tokens": response.tokens_used,
"total_tokens": int(len(user_message) * 0.25)
+ response.tokens_used,
},
"response_time": response.processing_time,
"confidence": response.confidence,
}
async def function_calling(
self, prompt: str, functions: List[Dict[str, Any]]
) -> Dict[str, Any]:
"""Simulate function calling."""
# Generate tool calls
tool_calls = await self.llm.generate_tool_calls(prompt, functions)
return {
"function_calls": tool_calls,
"reasoning": f"Selected {len(tool_calls)} functions based on user query",
"confidence": random.uniform(0.8, 0.95),
}
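
if __name__ == "__main__":
    # Minimal end-to-end sketch: run the mock client once and print the result.
    # The tool definition below is illustrative; any dict with a "name" (and
    # optional "category") key works.
    async def _demo() -> None:
        client = TinyLLMClient()
        result = await client.chat_completion(
            [{"role": "user", "content": "Find all VMs on VLAN 10"}],
            tools=[{"name": "list_vms", "category": "vm"}],
        )
        print(result["choices"][0]["message"]["content"])
        print("usage:", result["usage"])

    asyncio.run(_demo())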