"""
Context Window Management for Orchestration
Prevents context overflow and optimizes token usage for any LLM
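
Typical usage (an illustrative sketch; token counts are heuristic estimates,
and error_text / source_code stand in for caller-supplied strings):

    manager = ContextManager(model="claude-3-opus")
    manager.add_context(error_text, priority=1, content_type="error")
    manager.add_context(source_code, priority=5, content_type="code")
    context_items, meta = manager.get_optimized_context()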
"""
import json
import re
from typing import Dict, List, Tuple, Any, Optional
from dataclasses import dataclass
from datetime import datetime
import hashlib
from pathlib import Path


@dataclass
class ContextItem:
    """Represents a single item in the context window"""
    content: Any
    priority: int  # 1-10, 1 is highest
    content_type: str  # 'error', 'code', 'analysis', 'history', 'file'
    tokens: int
    timestamp: datetime
    metadata: Optional[Dict] = None

    def __post_init__(self):
        if self.metadata is None:
            self.metadata = {}


class ContextManager:
    """
    Manages the context window for LLM interactions.
    Prevents overflow and optimizes token usage.
    """

    # Model context limits (in tokens)
    MODEL_LIMITS = {
        "claude-3-opus": 200000,
        "claude-3-sonnet": 200000,
        "claude-3-haiku": 200000,
        "claude-2.1": 200000,
        "claude-2": 100000,
        "gpt-4-turbo": 128000,
        "gpt-4-32k": 32768,
        "gpt-4": 8192,
        "gpt-3.5-turbo": 16384,
        "gpt-3.5": 4096,
        "local": 4096,
        "default": 100000
    }

    # Reserve tokens for the response
    RESPONSE_BUFFER = {
        "claude-3-opus": 4000,
        "claude-3-sonnet": 4000,
        "gpt-4-turbo": 4000,
        "gpt-4": 2000,
        "default": 2000
    }

    def __init__(self, model: str = "claude-3-opus",
                 target_utilization: float = 0.9,
                 enable_compression: bool = True):
        """
        Initialize the context manager.

        Args:
            model: LLM model name
            target_utilization: Target fraction of the context window to use
                (0.9 = 90%)
            enable_compression: Whether to enable smart compression
        """
        self.model = model
        self.max_tokens = self.MODEL_LIMITS.get(model, self.MODEL_LIMITS["default"])
        self.response_buffer = self.RESPONSE_BUFFER.get(model, self.RESPONSE_BUFFER["default"])
        self.target_tokens = int((self.max_tokens - self.response_buffer) * target_utilization)
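        # Worked example with the defaults above (assuming model="claude-3-opus"):
        # (200_000 - 4_000) * 0.9 = 176_400 tokens budgeted for context.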
        self.enable_compression = enable_compression
        self.context_items: List[ContextItem] = []
        self.current_tokens = 0
        self.compression_stats = {
            "original_tokens": 0,
            "compressed_tokens": 0,
            "items_compressed": 0,
            "items_dropped": 0
        }

    def estimate_tokens(self, content: Any) -> int:
        """
        Estimate the token count for content.

        More accurate than a simple chars/4 estimate: uses different
        ratios based on content type.
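
        Illustrative doctest (heuristic estimate, not real tokenizer output)::

            >>> ContextManager().estimate_tokens("word " * 80)  # 400 chars of prose
            100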
"""
if content is None:
return 0
if isinstance(content, str):
# More accurate estimation based on content type
if self._is_code(content):
# Code typically has more tokens due to syntax
return len(content) // 3
else:
# Regular text
return len(content) // 4
elif isinstance(content, dict):
# JSON representation
json_str = json.dumps(content, separators=(',', ':'))
return len(json_str) // 4
elif isinstance(content, list):
return sum(self.estimate_tokens(item) for item in content)
else:
# Fallback to string representation
return len(str(content)) // 4

    def add_context(self, content: Any, priority: int = 5,
                    content_type: str = "text", metadata: Optional[Dict] = None) -> bool:
        """
        Add content to the context with a priority.

        Args:
            content: Content to add
            priority: 1-10, where 1 is highest priority
            content_type: Type of content, used to pick a compression strategy
            metadata: Additional metadata

        Returns:
            True once the content has been added; oversized items are
            chunked or compressed first, and the window is re-optimized
            if the addition pushes it over the target
        """
        tokens = self.estimate_tokens(content)
        # Check if this single item exceeds the limit
        if tokens > self.target_tokens:
            if content_type == "file" and isinstance(content, str):
                # Large files are split into prioritized chunks
                return self._add_large_file(content, priority, metadata)
            else:
                # Otherwise compress aggressively
                content = self._compress_content(content, content_type, target_ratio=0.3)
                tokens = self.estimate_tokens(content)
        # Create the context item
        item = ContextItem(
            content=content,
            priority=priority,
            content_type=content_type,
            tokens=tokens,
            timestamp=datetime.now(),
            metadata=metadata or {}
        )
        # Add to context
        self.context_items.append(item)
        self.current_tokens += tokens
        # Optimize if needed
        if self.current_tokens > self.target_tokens:
            self.optimize_context()
        return True

    def optimize_context(self) -> Tuple[int, int]:
        """
        Optimize the context to fit within limits.

        Returns:
            (items_compressed, items_dropped)
        """
        if self.current_tokens <= self.target_tokens:
            return (0, 0)
        items_compressed = 0
        items_dropped = 0
        # Sort by priority (lower number = higher priority), newest first
        sorted_items = sorted(self.context_items,
                              key=lambda x: (x.priority, -x.timestamp.timestamp()))
        # First pass: compress low-priority items, walking from the
        # lowest-priority end so the most expendable items degrade first
        if self.enable_compression:
            for item in reversed(sorted_items):
                if self.current_tokens <= self.target_tokens:
                    break
                if item.priority >= 7 and item.content_type != "error":
                    original_tokens = item.tokens
                    compressed = self._compress_content(
                        item.content,
                        item.content_type,
                        target_ratio=0.2 if item.priority >= 9 else 0.4
                    )
                    new_tokens = self.estimate_tokens(compressed)
                    if new_tokens < original_tokens:
                        item.content = compressed
                        item.tokens = new_tokens
                        self.current_tokens -= (original_tokens - new_tokens)
                        items_compressed += 1
                        self.compression_stats["items_compressed"] += 1
        # Second pass: drop the lowest-priority items (oldest first) if still over
        while self.current_tokens > self.target_tokens and sorted_items:
            lowest_priority = max(sorted_items,
                                  key=lambda x: (x.priority, -x.timestamp.timestamp()))
            # Never drop priority 1-3 items (errors, critical info)
            if lowest_priority.priority <= 3:
                break
            # Remove the item
            self.context_items.remove(lowest_priority)
            self.current_tokens -= lowest_priority.tokens
            sorted_items.remove(lowest_priority)
            items_dropped += 1
            self.compression_stats["items_dropped"] += 1
        return (items_compressed, items_dropped)

    def _compress_content(self, content: Any, content_type: str,
                          target_ratio: float = 0.5) -> Any:
        """
        Compress content based on type.

        Args:
            content: Content to compress
            content_type: Type of content
            target_ratio: Target compression ratio (0.5 = 50% of original)
        """
        if not self.enable_compression:
            return content
        if content_type == "error":
            # Never compress errors
            return content
        elif content_type == "code" and isinstance(content, str):
            # Remove comments and excess whitespace
            lines = content.split('\n')
            compressed_lines = []
            for line in lines:
                # Remove comments (simple approach)
                if '#' in line:
                    line = line[:line.index('#')].rstrip()
                if line.strip():  # Keep non-empty lines
                    compressed_lines.append(line.strip())
            compressed = '\n'.join(compressed_lines)
            # If still too large, truncate the middle
            if len(compressed) > len(content) * target_ratio:
                target_len = int(len(content) * target_ratio)
                compressed = self._truncate_middle(compressed, target_len)
            return compressed
        elif content_type == "analysis" and isinstance(content, dict):
            # Keep only the important fields
            important_fields = ["summary", "errors", "issues", "score", "grade",
                                "recommendations", "priority", "action_plan"]
            compressed = {}
            for field in important_fields:
                if field in content:
                    compressed[field] = content[field]
            # Add truncated versions of the other fields
            for key, value in content.items():
                if key not in compressed and key not in ["raw_data", "debug", "trace"]:
                    if isinstance(value, str) and len(value) > 100:
                        compressed[key] = value[:100] + "..."
                    elif isinstance(value, list) and len(value) > 5:
                        compressed[key] = value[:5] + ["..."]
                    else:
                        compressed[key] = value
            return compressed
        elif content_type == "history":
            # Summarize historical data
            if isinstance(content, list):
                # Keep the first and last items, eliding the middle
                if len(content) <= 5:
                    return content
                return [
                    content[0],
                    f"... {len(content) - 2} items ...",
                    content[-1]
                ]
        elif isinstance(content, str):
            # Generic string compression - truncate the middle
            if len(content) > 1000:
                target_len = int(len(content) * target_ratio)
                return self._truncate_middle(content, target_len)
        return content

    def _truncate_middle(self, text: str, max_length: int) -> str:
        """
        Truncate the middle of text, preserving the start and end.
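
        Illustrative doctest (hypothetical input)::

            >>> s = ContextManager()._truncate_middle("x" * 100, 60)
            >>> "truncated 40 chars" in s
            True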
"""
if len(text) <= max_length:
return text
# Keep equal parts from start and end
part_length = (max_length - 20) // 2 # Reserve 20 chars for ellipsis
start = text[:part_length]
end = text[-part_length:]
return f"{start}\n... truncated {len(text) - max_length} chars ...\n{end}"

    def _add_large_file(self, content: str, priority: int, metadata: Optional[Dict]) -> bool:
        """
        Add a large file by chunking it intelligently.
        """
        # Chunk size based on priority
        chunk_size = 25000 if priority <= 3 else 15000 if priority <= 6 else 10000
        # For code files, try to chunk by functions/classes
        if self._is_code(content):
            chunks = self._chunk_code_intelligently(content, chunk_size)
        else:
            # Simple chunking for other content
            chunks = self._chunk_text(content, chunk_size)
        # Later chunks get a larger priority number, i.e. lower priority
        # (the first chunk stays the most important)
        for i, chunk in enumerate(chunks):
            chunk_priority = min(priority + i, 10)
            self.add_context(
                chunk,
                priority=chunk_priority,
                content_type="file_chunk",
                metadata={
                    **(metadata or {}),
                    "chunk": i + 1,
                    "total_chunks": len(chunks)
                }
            )
        return True

    def _is_code(self, content: str) -> bool:
        """
        Check if content appears to be code.
        """
        code_indicators = [
            r'def\s+\w+\s*\(',
            r'class\s+\w+',
            r'import\s+\w+',
            r'function\s*\(',
            r'const\s+\w+\s*=',
            r'if\s*\(',
            r'for\s*\(',
        ]
        for pattern in code_indicators:
            if re.search(pattern, content[:1000]):  # Check the first 1000 chars
                return True
        return False

    def _chunk_code_intelligently(self, code: str, chunk_size: int) -> List[str]:
        """
        Chunk code by logical boundaries (functions, classes).
        """
        lines = code.split('\n')
        chunks = []
        current_chunk = []
        current_size = 0
        for line in lines:
            # Calculate indentation
            indent = len(line) - len(line.lstrip())
            # Check if we're starting a new top-level block
            if indent == 0 and line.strip() and (
                line.strip().startswith('def ') or
                line.strip().startswith('class ') or
                line.strip().startswith('async def ')
            ):
                # Save the current chunk if it's already substantial
                if current_size > chunk_size * 0.7:
                    chunks.append('\n'.join(current_chunk))
                    current_chunk = []
                    current_size = 0
            current_chunk.append(line)
            current_size += len(line)
            # Start a new chunk once the size limit is reached
            if current_size >= chunk_size:
                chunks.append('\n'.join(current_chunk))
                current_chunk = []
                current_size = 0
        # Add the remainder
        if current_chunk:
            chunks.append('\n'.join(current_chunk))
        return chunks

    def _chunk_text(self, text: str, chunk_size: int) -> List[str]:
        """Simple text chunking with overlap."""
        chunks = []
        overlap = min(200, chunk_size // 4)  # Characters of overlap, max 25% of chunk
        step = max(1, chunk_size - overlap)  # Ensure the step is at least 1
        for i in range(0, len(text), step):
            chunk = text[i:i + chunk_size]
            chunks.append(chunk)
        return chunks

    def get_optimized_context(self) -> Tuple[List[Dict], Dict]:
        """
        Get the optimized context, ready for the LLM.

        Returns:
            (context_items, metadata)
        """
        # Optimize first
        self.optimize_context()
        # Prepare context items, highest priority (and newest) first
        context_list = []
        for item in sorted(self.context_items,
                           key=lambda x: (x.priority, -x.timestamp.timestamp())):
            context_entry = {
                "type": item.content_type,
                "priority": item.priority,
                "content": item.content
            }
            if item.metadata:
                context_entry["metadata"] = item.metadata
            context_list.append(context_entry)
        # Prepare metadata
        metadata = {
            "model": self.model,
            "total_tokens": self.current_tokens,
            "max_tokens": self.max_tokens,
            "target_tokens": self.target_tokens,
            "utilization": f"{(self.current_tokens / self.max_tokens) * 100:.1f}%",
            "items": len(self.context_items),
            "compression_stats": self.compression_stats
        }
        return context_list, metadata

    def get_remaining_tokens(self) -> int:
        """Get the number of tokens remaining in the context window"""
        return max(0, self.target_tokens - self.current_tokens)

    def clear_context(self, keep_priority: Optional[int] = None):
        """
        Clear the context, optionally keeping high-priority items.

        Args:
            keep_priority: Keep items with priority <= this value
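
        Illustrative doctest::

            >>> cm = ContextManager()
            >>> cm.add_context("a" * 40, priority=2)
            True
            >>> cm.add_context("b" * 40, priority=8)
            True
            >>> cm.clear_context(keep_priority=3)
            >>> len(cm.context_items)
            1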
"""
if keep_priority is None:
self.context_items = []
self.current_tokens = 0
else:
# Keep high priority items
kept_items = [
item for item in self.context_items
if item.priority <= keep_priority
]
self.context_items = kept_items
self.current_tokens = sum(item.tokens for item in kept_items)

    def create_sliding_window(self, window_size: int = 10):
        """
        Maintain a sliding window of recent context.
        Older items are summarized or dropped.
        """
        if len(self.context_items) <= window_size:
            return
        # Sort by timestamp
        sorted_items = sorted(self.context_items, key=lambda x: x.timestamp)
        # Keep recent items
        recent = sorted_items[-window_size:]
        old = sorted_items[:-window_size]
        # Group old items by type for summarization
        summaries = {}
        for item in old:
            if item.content_type not in summaries:
                summaries[item.content_type] = []
            summaries[item.content_type].append(item)
        # Reset the window before re-adding; current_tokens must be reset
        # as well so add_context() does not optimize against a stale count
        self.context_items = []
        self.current_tokens = 0
        # Create summary context items
        for content_type, items in summaries.items():
            summary = {
                "type": "summary",
                "content_type": content_type,
                "items_count": len(items),
                "total_tokens": sum(item.tokens for item in items),
                "priority_range": f"{min(item.priority for item in items)}-{max(item.priority for item in items)}",
                "time_range": f"{min(item.timestamp for item in items).isoformat()} to {max(item.timestamp for item in items).isoformat()}"
            }
            # Add key information based on type
            if content_type == "error":
                summary["errors"] = [item.content for item in items if item.priority <= 3]
            elif content_type == "analysis":
                summary["key_findings"] = [
                    item.content.get("summary", "")
                    for item in items
                    if isinstance(item.content, dict)
                ][:5]
            self.add_context(
                summary,
                priority=9,  # Low priority for summaries
                content_type="history_summary"
            )
        # Add the recent items back
        for item in recent:
            self.context_items.append(item)
        # Recalculate tokens
        self.current_tokens = sum(item.tokens for item in self.context_items)

    def get_statistics(self) -> Dict:
        """Get detailed statistics about context usage"""
        priority_distribution = {}
        type_distribution = {}
        for item in self.context_items:
            # Priority distribution (tokens per priority level)
            priority_bucket = f"priority_{item.priority}"
            priority_distribution[priority_bucket] = priority_distribution.get(priority_bucket, 0) + item.tokens
            # Type distribution (tokens per content type)
            type_distribution[item.content_type] = type_distribution.get(item.content_type, 0) + item.tokens
        return {
            "model": self.model,
            "current_tokens": self.current_tokens,
            "max_tokens": self.max_tokens,
            "target_tokens": self.target_tokens,
            "utilization": f"{(self.current_tokens / self.max_tokens) * 100:.1f}%",
            "items_count": len(self.context_items),
            "priority_distribution": priority_distribution,
            "type_distribution": type_distribution,
            "compression_stats": self.compression_stats,
            "average_item_tokens": self.current_tokens / len(self.context_items) if self.context_items else 0
        }
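

if __name__ == "__main__":
    # Minimal usage sketch (assumption: run as a script; the sample strings
    # below are hypothetical stand-ins for real orchestration data).
    manager = ContextManager(model="claude-3-opus")
    manager.add_context("Traceback: ValueError in parse()", priority=1,
                        content_type="error")
    manager.add_context("def parse(data):\n    return data.split(',')",
                        priority=5, content_type="code")
    manager.add_context({"summary": "Parser fails on empty input"},
                        priority=4, content_type="analysis")
    context_items, meta = manager.get_optimized_context()
    print(json.dumps(meta, indent=2))
    print(f"Remaining tokens: {manager.get_remaining_tokens()}")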