# llm_optimizer.py
"""LLM-specific optimizations for NetBox MCP Server.
This module provides optimizations specifically designed for local LLMs
accessing the MCP server through an OpenAI API gateway.
"""
import logging
import json
import time
from typing import Dict, Any, List, Optional
from dataclasses import dataclass
from functools import lru_cache
import asyncio
from concurrent.futures import ThreadPoolExecutor
logger = logging.getLogger(__name__)
@dataclass
class LLMResponse:
"""Optimized response structure for LLMs."""
content: str
metadata: Dict[str, Any]
confidence: float
response_time: float
token_count: int
class LLMOptimizer:
"""Optimizer for LLM-specific performance improvements."""
def __init__(self, max_workers: int = 4):
"""Initialize LLM optimizer.
Args:
max_workers: Maximum number of worker threads for parallel processing
"""
self.max_workers = max_workers
self.executor = ThreadPoolExecutor(max_workers=max_workers)
        self._token_estimator = TokenEstimator()
def optimize_for_llm(self, data: List[Dict[str, Any]],
response_type: str = "list") -> LLMResponse:
"""Optimize data specifically for LLM consumption.
Args:
data: Raw data from NetBox
response_type: Type of response (list, detail, search)
Returns:
Optimized LLM response
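
        Example (illustrative sketch; ``devices`` stands in for a list of
        NetBox record dicts)::

            optimizer = LLMOptimizer()
            response = optimizer.optimize_for_llm(devices, response_type="list")
            print(response.content)
            print(response.token_count, response.confidence)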
"""
start_time = time.time()
# Optimize data structure for LLM
optimized_content = self._optimize_data_structure(data, response_type)
# Estimate token count
token_count = self._token_estimator.estimate(optimized_content)
# Add LLM-friendly metadata
metadata = self._create_llm_metadata(data, response_type)
# Calculate confidence score
confidence = self._calculate_confidence(data)
response_time = time.time() - start_time
return LLMResponse(
content=optimized_content,
metadata=metadata,
confidence=confidence,
response_time=response_time,
token_count=token_count
)
def _optimize_data_structure(self, data: List[Dict[str, Any]],
response_type: str) -> str:
"""Optimize data structure for LLM consumption."""
if not data:
return self._create_empty_response(response_type)
if response_type == "list":
return self._optimize_list_response(data)
elif response_type == "detail":
return self._optimize_detail_response(data[0])
elif response_type == "search":
return self._optimize_search_response(data)
else:
return self._optimize_generic_response(data)
def _optimize_list_response(self, data: List[Dict[str, Any]]) -> str:
"""Optimize list responses for LLMs."""
if not data:
return "No results found."
# Create LLM-friendly summary
summary = f"Found {len(data)} items:\n\n"
# Group by type for better LLM understanding
grouped_data = self._group_by_type(data)
for item_type, items in grouped_data.items():
summary += f"## {item_type.title()} ({len(items)} items)\n"
for item in items[:5]: # Limit to 5 items per type
summary += f"- {self._format_item_summary(item)}\n"
if len(items) > 5:
summary += f" ... and {len(items) - 5} more\n"
summary += "\n"
return summary.strip()
def _optimize_detail_response(self, data: Dict[str, Any]) -> str:
"""Optimize detail responses for LLMs."""
if not data:
return "Item not found."
# Create structured detail view
result = f"## {data.get('name', 'Unknown')} Details\n\n"
# Key information first
key_fields = ['id', 'name', 'display', 'status', 'primary_ip4', 'device_role']
for field in key_fields:
if field in data and data[field]:
result += f"**{field.replace('_', ' ').title()}**: {self._format_value(data[field])}\n"
# Additional information
result += "\n### Additional Information\n"
for key, value in data.items():
if key not in key_fields and value:
result += f"- **{key.replace('_', ' ').title()}**: {self._format_value(value)}\n"
return result
def _optimize_search_response(self, data: List[Dict[str, Any]]) -> str:
"""Optimize search responses for LLMs."""
if not data:
return "No matching results found."
result = f"Found {len(data)} matching results:\n\n"
for i, item in enumerate(data, 1):
result += f"{i}. **{item.get('name', 'Unknown')}**\n"
result += f" - Type: {self._get_item_type(item)}\n"
result += f" - Status: {self._get_item_status(item)}\n"
if 'primary_ip4' in item and item['primary_ip4']:
result += f" - IP: {self._format_value(item['primary_ip4'])}\n"
result += "\n"
return result.strip()
def _optimize_generic_response(self, data: List[Dict[str, Any]]) -> str:
"""Optimize generic responses for LLMs."""
if not data:
return "No data available."
return json.dumps(data, indent=2, default=str)
def _group_by_type(self, data: List[Dict[str, Any]]) -> Dict[str, List[Dict[str, Any]]]:
"""Group items by type for better LLM understanding."""
        grouped: Dict[str, List[Dict[str, Any]]] = {}
        for item in data:
            grouped.setdefault(self._get_item_type(item), []).append(item)
return grouped
def _get_item_type(self, item: Dict[str, Any]) -> str:
"""Determine item type for grouping."""
if 'device_role' in item:
return 'device'
elif 'role' in item and 'cluster' in item:
return 'virtual_machine'
elif 'vid' in item:
return 'vlan'
elif 'address' in item:
return 'ip_address'
else:
return 'unknown'
def _get_item_status(self, item: Dict[str, Any]) -> str:
"""Get item status in a readable format."""
status = item.get('status', {})
if isinstance(status, dict):
return status.get('label', 'Unknown')
return str(status)
def _format_item_summary(self, item: Dict[str, Any]) -> str:
"""Format item for summary display."""
name = item.get('name', 'Unknown')
status = self._get_item_status(item)
# Add IP if available
ip_info = ""
if 'primary_ip4' in item and item['primary_ip4']:
ip = self._format_value(item['primary_ip4'])
ip_info = f" ({ip})"
return f"{name} - {status}{ip_info}"
def _format_value(self, value: Any) -> str:
"""Format value for LLM consumption."""
if isinstance(value, dict):
if 'display' in value:
return value['display']
elif 'address' in value:
return value['address']
elif 'name' in value:
return value['name']
else:
return str(value)
        elif isinstance(value, list):
            # Format elements recursively so lists of dicts stay readable
            formatted = ', '.join(self._format_value(v) for v in value[:3])
            return formatted + ('...' if len(value) > 3 else '')
else:
return str(value)
def _create_empty_response(self, response_type: str) -> str:
"""Create appropriate empty response."""
responses = {
'list': "No items found.",
'detail': "Item not found.",
'search': "No matching results found.",
'generic': "No data available."
}
return responses.get(response_type, "No results found.")
def _create_llm_metadata(self, data: List[Dict[str, Any]],
response_type: str) -> Dict[str, Any]:
"""Create metadata optimized for LLM consumption."""
return {
'count': len(data),
'type': response_type,
'timestamp': time.time(),
'optimized_for_llm': True,
            'data_types': sorted({self._get_item_type(item) for item in data}),
'has_confidence_scores': any('certainty_score' in item for item in data)
}
def _calculate_confidence(self, data: List[Dict[str, Any]]) -> float:
"""Calculate overall confidence score."""
if not data:
return 0.0
confidence_scores = []
for item in data:
if 'certainty_score' in item:
confidence_scores.append(item['certainty_score'])
if confidence_scores:
return sum(confidence_scores) / len(confidence_scores)
else:
return 0.8 # Default confidence for data without scores
async def batch_optimize(self, data_batches: List[List[Dict[str, Any]]],
response_types: List[str]) -> List[LLMResponse]:
"""Optimize multiple data batches in parallel."""
        loop = asyncio.get_running_loop()
        tasks = [
            loop.run_in_executor(
                self.executor,
                self.optimize_for_llm,
                data,
                response_type,
            )
            for data, response_type in zip(data_batches, response_types)
        ]
        return await asyncio.gather(*tasks)
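
    # Illustrative usage sketch (must run inside an async context;
    # `device_list` and `vlan_list` are hypothetical NetBox result sets):
    #
    #     responses = await optimizer.batch_optimize(
    #         [device_list, vlan_list],
    #         ["list", "list"],
    #     )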
def close(self):
"""Close the executor."""
self.executor.shutdown(wait=True)
class TokenEstimator:
"""Estimates token count for LLM responses."""
    # Caching is delegated to functools.lru_cache on a staticmethod: applying
    # lru_cache to a plain instance method would key the cache on `self` and
    # keep every instance alive for the lifetime of the cache.
    @staticmethod
    @lru_cache(maxsize=1000)
    def estimate(text: str) -> int:
"""Estimate token count for text.
Uses a simple approximation: ~4 characters per token for English text.
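
        Example (illustrative)::

            TokenEstimator.estimate("hello world")  # 11 chars // 4 -> 2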
"""
if not text:
return 0
# Simple approximation: 4 characters per token
return len(text) // 4
def estimate_json(self, data: Any) -> int:
"""Estimate token count for JSON data."""
json_str = json.dumps(data, default=str)
return self.estimate(json_str)
class LLMResponseFormatter:
"""Formats responses specifically for different LLM contexts."""
@staticmethod
def format_for_chat_completion(response: LLMResponse) -> Dict[str, Any]:
"""Format response for OpenAI chat completion API."""
return {
'role': 'assistant',
'content': response.content,
'metadata': {
'confidence': response.confidence,
'response_time': response.response_time,
'token_count': response.token_count,
**response.metadata
}
}
@staticmethod
def format_for_function_calling(response: LLMResponse) -> Dict[str, Any]:
"""Format response for OpenAI function calling."""
return {
'function_name': 'netbox_query',
'arguments': {
'result': response.content,
'confidence': response.confidence,
'metadata': response.metadata
}
}
@staticmethod
def format_for_streaming(response: LLMResponse) -> List[Dict[str, Any]]:
"""Format response for streaming API."""
chunks = []
content = response.content
# Split content into chunks for streaming
chunk_size = 100 # characters per chunk
for i in range(0, len(content), chunk_size):
chunk = content[i:i + chunk_size]
chunks.append({
'delta': {'content': chunk},
'metadata': response.metadata if i == 0 else None
})
return chunks
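
# Illustrative streaming sketch (`response` is an LLMResponse and `emit_chunk`
# a hypothetical transport callback, e.g. a server-sent-events writer):
#
#     for chunk in LLMResponseFormatter.format_for_streaming(response):
#         emit_chunk(chunk)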
class LLMCache:
"""Intelligent caching for LLM responses."""
def __init__(self, max_size: int = 1000, ttl: int = 300):
"""Initialize LLM cache.
Args:
max_size: Maximum number of cached responses
ttl: Time to live in seconds
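
        Example (illustrative; ``key`` would typically be a hash of the
        query, e.g. tool name plus serialized arguments)::

            cache = LLMCache(max_size=500, ttl=120)
            cached = cache.get(key)
            if cached is None:
                cached = optimizer.optimize_for_llm(data, "list")
                cache.put(key, cached)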
"""
self.max_size = max_size
self.ttl = ttl
        self._cache = {}
        self._access_times = {}
        self._creation_times = {}
        # Counters backing the hit-rate figure reported by stats()
        self._hit_count = 0
        self._access_count = 0
def get(self, key: str) -> Optional[LLMResponse]:
"""Get cached response."""
if key not in self._cache:
return None
# Check TTL
if time.time() - self._creation_times[key] > self.ttl:
self._remove(key)
return None
# Update access time
self._access_times[key] = time.time()
return self._cache[key]
def put(self, key: str, response: LLMResponse) -> None:
"""Cache response."""
        # Evict the least recently used entry only when inserting a new key
        # into a full cache; overwriting an existing key needs no eviction
        if key not in self._cache and len(self._cache) >= self.max_size:
            self._evict_oldest()
self._cache[key] = response
self._access_times[key] = time.time()
self._creation_times[key] = time.time()
def _remove(self, key: str) -> None:
"""Remove item from cache."""
if key in self._cache:
del self._cache[key]
del self._access_times[key]
del self._creation_times[key]
def _evict_oldest(self) -> None:
"""Evict least recently used item."""
if not self._access_times:
return
        oldest_key = min(self._access_times, key=self._access_times.get)
        self._remove(oldest_key)
def clear(self) -> None:
"""Clear all cached responses."""
self._cache.clear()
self._access_times.clear()
self._creation_times.clear()
def stats(self) -> Dict[str, Any]:
"""Get cache statistics."""
return {
'size': len(self._cache),
'max_size': self.max_size,
'ttl': self.ttl,
            'hit_rate': self._hit_count / max(self._access_count, 1)
}
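

# Minimal end-to-end sketch, runnable as a script. The sample record below is
# a hypothetical, heavily trimmed NetBox device payload; real API responses
# carry many more fields.
if __name__ == "__main__":
    sample = [{
        "id": 1,
        "name": "core-sw-01",
        "status": {"value": "active", "label": "Active"},
        "device_role": {"name": "core-switch", "display": "Core Switch"},
        "primary_ip4": {"address": "10.0.0.1/24", "display": "10.0.0.1/24"},
    }]

    optimizer = LLMOptimizer(max_workers=2)
    try:
        response = optimizer.optimize_for_llm(sample, response_type="list")
        print(response.content)
        print(f"~{response.token_count} tokens, "
              f"confidence {response.confidence:.2f}")
        # The same LLMResponse can be re-shaped for different consumers:
        message = LLMResponseFormatter.format_for_chat_completion(response)
        print(message["metadata"]["count"])  # -> 1
    finally:
        optimizer.close()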