# agent_templates.py (22 kB)
"""
Sub-Agent Template Management System
Provides reusable, focused code analysis and generation units
"""
import ast
import hashlib
import json
import re
import time
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Any, Optional, Callable

from .db_wrapper import ThreadSafeDB
from .clean_search import CleanSmartCodeSearch
from .git_analyzer import GitAnalyzer
from .dependency_analyzer import DependencyAnalyzer
from .usage_analyzer import UsageAnalyzer
class AgentTemplate:
    """Represents a single agent template.

    Holds the template's identity (name, type, description), its source
    code, the input/output schemas and the list of dependencies, together
    with usage statistics that all start at zero.
    """

    def __init__(self, name: str, agent_type: str, description: str,
                 template_code: str, input_schema: Dict, output_schema: Dict,
                 dependencies: Optional[List[str]] = None):
        """
        Create a template.

        Args:
            name: Template name
            agent_type: Agent type; stored on the instance as ``type``
            description: Human-readable description
            template_code: Source code of the template
            input_schema: Schema describing the inputs
            output_schema: Schema describing the output
            dependencies: Optional list of dependency names
        """
        self.name = name
        self.type = agent_type
        self.description = description
        self.template_code = template_code
        self.input_schema = input_schema
        self.output_schema = output_schema
        # Copy the list so later mutation of the caller's list cannot
        # silently change this template.
        self.dependencies = list(dependencies) if dependencies else []
        self.usage_count = 0
        self.success_rate = 0.0
        self.avg_execution_ms = 0

    def __repr__(self) -> str:
        return f"{type(self).__name__}(name={self.name!r}, type={self.type!r})"

    def to_dict(self) -> Dict:
        """Return the template's fields as a plain dictionary."""
        return {
            'name': self.name,
            'type': self.type,
            'description': self.description,
            'template_code': self.template_code,
            'input_schema': self.input_schema,
            'output_schema': self.output_schema,
            # Hand out a copy so callers cannot mutate the template's list
            'dependencies': list(self.dependencies),
            'usage_count': self.usage_count,
            'success_rate': self.success_rate,
            'avg_execution_ms': self.avg_execution_ms,
        }
class AgentTemplateManager:
"""Manages sub-agent templates and their execution"""
def __init__(self, db: ThreadSafeDB, project_root: Path = None):
"""
Initialize agent template manager
Args:
db: Thread-safe database connection
project_root: Root directory of project
"""
self.db = db
self.project_root = project_root or Path.cwd()
self.templates: Dict[str, AgentTemplate] = {}
# Initialize helper services
self.search = CleanSmartCodeSearch(str(self.project_root))
self.git = GitAnalyzer(self.project_root)
self.deps = DependencyAnalyzer(str(self.project_root))
self.usage = UsageAnalyzer(self.project_root)
# Load existing templates from database
self.load_templates_from_db()
# Load prebuilt templates if database is empty
if not self.templates:
self.load_prebuilt_templates()
def load_templates_from_db(self):
"""Load templates from database"""
with self.db.get_connection() as conn:
cursor = conn.execute("""
SELECT name, type, description, template_code,
input_schema, output_schema, dependencies,
usage_count, success_rate, avg_execution_ms
FROM agent_templates
ORDER BY usage_count DESC
""")
for row in cursor:
template = AgentTemplate(
name=row[0],
agent_type=row[1],
description=row[2],
template_code=row[3],
input_schema=json.loads(row[4]),
output_schema=json.loads(row[5]),
dependencies=json.loads(row[6]) if row[6] else []
)
template.usage_count = row[7] or 0
template.success_rate = row[8] or 0.0
template.avg_execution_ms = row[9] or 0
self.templates[template.name] = template
def load_prebuilt_templates(self):
"""Load prebuilt agent templates"""
print("Loading prebuilt agent templates...")
# Analysis agents
templates = [
# Import Analysis
AgentTemplate(
name="import_analyzer",
agent_type="analyzer",
description="Analyze imports and find unused ones",
template_code='''
def analyze(code: str, file_path: str = None) -> dict:
imports = []
unused = []
try:
tree = ast.parse(code)
# Extract imports
for node in ast.walk(tree):
if isinstance(node, ast.Import):
for alias in node.names:
imports.append({
'module': alias.name,
'alias': alias.asname,
'line': node.lineno
})
elif isinstance(node, ast.ImportFrom):
module = node.module or ''
for alias in node.names:
imports.append({
'module': f"{module}.{alias.name}" if module else alias.name,
'alias': alias.asname,
'line': node.lineno,
'from_import': True
})
# Find unused imports (simplified)
for imp in imports:
module_name = imp['alias'] or imp['module'].split('.')[-1]
# Check if module is referenced in code (excluding import lines)
code_lines = code.split('\\n')
used = False
for i, line in enumerate(code_lines, 1):
if i != imp['line'] and module_name in line:
used = True
break
if not used:
unused.append(imp)
except SyntaxError as e:
return {'error': str(e), 'imports': [], 'unused': []}
return {
'imports': imports,
'unused': unused,
'count': len(imports),
'unused_count': len(unused)
}
''',
input_schema={"type": "object", "properties": {"code": {"type": "string"}}},
output_schema={"type": "object", "properties": {"imports": {"type": "array"}}}
),
# Complexity Analysis
AgentTemplate(
name="complexity_analyzer",
agent_type="analyzer",
description="Analyze code complexity metrics",
template_code='''
def analyze(code: str, threshold: int = 10) -> dict:
complexity = 1 # Base complexity
nesting_depth = 0
max_nesting = 0
try:
tree = ast.parse(code)
# Calculate cyclomatic complexity
for node in ast.walk(tree):
if isinstance(node, (ast.If, ast.While, ast.For, ast.ExceptHandler)):
complexity += 1
elif isinstance(node, ast.BoolOp):
complexity += len(node.values) - 1
# Find max nesting (simplified)
for node in ast.walk(tree):
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
current_depth = 0
for child in ast.walk(node):
if isinstance(child, (ast.If, ast.While, ast.For)):
current_depth += 1
max_nesting = max(max_nesting, current_depth)
except SyntaxError:
return {'error': 'Invalid Python syntax', 'complexity': 0}
risk = "high" if complexity > threshold else "medium" if complexity > 5 else "low"
return {
'complexity': complexity,
'risk': risk,
'max_nesting': max_nesting,
'exceeds_threshold': complexity > threshold
}
''',
input_schema={"type": "object", "properties": {"code": {"type": "string"}, "threshold": {"type": "integer"}}},
output_schema={"type": "object", "properties": {"complexity": {"type": "integer"}}}
),
# Function Extractor
AgentTemplate(
name="function_extractor",
agent_type="analyzer",
description="Extract all functions from code",
template_code='''
def analyze(code: str) -> dict:
functions = []
try:
tree = ast.parse(code)
for node in ast.walk(tree):
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
func_info = {
'name': node.name,
'line': node.lineno,
'args': [arg.arg for arg in node.args.args],
'is_async': isinstance(node, ast.AsyncFunctionDef),
'has_return': any(isinstance(n, ast.Return) for n in ast.walk(node)),
'decorators': [d.id if isinstance(d, ast.Name) else str(d) for d in node.decorator_list],
'docstring': ast.get_docstring(node) or ''
}
functions.append(func_info)
except SyntaxError:
return {'error': 'Invalid Python syntax', 'functions': []}
return {
'functions': functions,
'count': len(functions),
'async_count': sum(1 for f in functions if f['is_async'])
}
''',
input_schema={"type": "object", "properties": {"code": {"type": "string"}}},
output_schema={"type": "object", "properties": {"functions": {"type": "array"}}}
),
# Test Gap Finder
AgentTemplate(
name="test_gap_finder",
agent_type="analyzer",
description="Find functions without corresponding tests",
template_code='''
def analyze(functions: list, test_files: list) -> dict:
# Extract tested function names from test files
tested = set()
for test_file in test_files:
# Look for test_functionname patterns
import re
pattern = r'def test_([a-zA-Z_][a-zA-Z0-9_]*)'
matches = re.findall(pattern, test_file.get('content', ''))
tested.update(matches)
# Also look for direct function calls in tests
for func in functions:
if func['name'] in test_file.get('content', ''):
tested.add(func['name'])
# Find gaps
untested = []
for func in functions:
# Skip private functions if configured
if func['name'].startswith('_'):
continue
if func['name'] not in tested:
untested.append(func)
coverage = len(tested) / len(functions) if functions else 1.0
return {
'untested': untested,
'tested_count': len(tested),
'total_count': len(functions),
'coverage': coverage,
'coverage_percent': round(coverage * 100, 1)
}
''',
input_schema={"type": "object", "properties": {"functions": {"type": "array"}, "test_files": {"type": "array"}}},
output_schema={"type": "object", "properties": {"untested": {"type": "array"}}}
),
# Duplicate Detector
AgentTemplate(
name="duplicate_detector",
agent_type="analyzer",
description="Find semantically similar code blocks",
template_code='''
def analyze(code: str, threshold: float = 0.85) -> dict:
# This would use the semantic search in real implementation
# For now, return placeholder
duplicates = []
# Simple line-based duplicate detection
lines = code.split('\\n')
seen_blocks = {}
for i in range(len(lines) - 5): # Look for 5+ line duplicates
block = '\\n'.join(lines[i:i+5])
block_hash = hash(block.strip())
if block_hash in seen_blocks and block.strip():
duplicates.append({
'line_start_1': seen_blocks[block_hash],
'line_start_2': i + 1,
'lines': 5,
'similarity': 1.0
})
else:
seen_blocks[block_hash] = i + 1
return {
'duplicates': duplicates,
'count': len(duplicates),
'total_duplicate_lines': len(duplicates) * 5
}
''',
input_schema={"type": "object", "properties": {"code": {"type": "string"}, "threshold": {"type": "number"}}},
output_schema={"type": "object", "properties": {"duplicates": {"type": "array"}}}
),
]
# Generator agents
generator_templates = [
AgentTemplate(
name="test_generator",
agent_type="generator",
description="Generate test cases for functions",
template_code='''
def generate(function_info: dict, examples: list = None) -> str:
func_name = function_info['name']
args = function_info.get('args', [])
test_code = f"""def test_{func_name}():
\"\"\"Test {func_name} function\"\"\"
"""
# Generate basic test structure
if not args:
test_code += f" result = {func_name}()\\n"
test_code += " assert result is not None\\n"
else:
# Generate test with sample arguments
sample_args = ', '.join(['None' if 'self' in arg else '1' for arg in args if arg != 'self'])
test_code += f" result = {func_name}({sample_args})\\n"
test_code += " assert result is not None\\n"
# Add edge cases
test_code += "\\n # Test edge cases\\n"
if args:
test_code += f" with pytest.raises(TypeError):\\n"
test_code += f" {func_name}() # Missing arguments\\n"
return test_code
''',
input_schema={"type": "object", "properties": {"function_info": {"type": "object"}}},
output_schema={"type": "string"}
),
AgentTemplate(
name="docstring_generator",
agent_type="generator",
description="Generate docstrings for functions",
template_code='''
def generate(function_info: dict) -> str:
func_name = function_info['name']
args = function_info.get('args', [])
has_return = function_info.get('has_return', False)
docstring = f'"""\\n {func_name.replace("_", " ").title()}\\n'
if args and args != ['self']:
docstring += "\\n Args:\\n"
for arg in args:
if arg != 'self':
docstring += f" {arg}: Description\\n"
if has_return:
docstring += "\\n Returns:\\n"
docstring += " Description of return value\\n"
docstring += ' """'
return docstring
''',
input_schema={"type": "object", "properties": {"function_info": {"type": "object"}}},
output_schema={"type": "string"}
),
]
# Register all templates
for template in templates + generator_templates:
self.register_template(template)
print(f"Loaded {len(templates) + len(generator_templates)} prebuilt templates")
def register_template(self, template: AgentTemplate):
"""Register a new agent template"""
# Store in memory
self.templates[template.name] = template
# Store in database
with self.db.get_connection() as conn:
conn.execute("""
INSERT OR REPLACE INTO agent_templates
(name, type, description, template_code, input_schema, output_schema, dependencies)
VALUES (?, ?, ?, ?, ?, ?, ?)
""", (
template.name,
template.type,
template.description,
template.template_code,
json.dumps(template.input_schema),
json.dumps(template.output_schema),
json.dumps(template.dependencies)
))
conn.commit()
def get_agent(self, name: str) -> Optional[AgentTemplate]:
"""Get agent template by name"""
return self.templates.get(name)
def list_agents(self, agent_type: Optional[str] = None) -> List[AgentTemplate]:
"""List all agents, optionally filtered by type"""
agents = list(self.templates.values())
if agent_type:
agents = [a for a in agents if a.type == agent_type]
return sorted(agents, key=lambda a: a.usage_count, reverse=True)
def execute_agent(self, agent_name: str, inputs: Dict) -> Any:
"""
Execute an agent template
Args:
agent_name: Name of the agent to execute
inputs: Input parameters
Returns:
Agent execution result
"""
agent = self.get_agent(agent_name)
if not agent:
raise ValueError(f"Agent '{agent_name}' not found")
# Create execution environment with helper functions
exec_globals = {
'ast': ast,
'json': json,
're': re,
'Path': Path,
'datetime': datetime,
# Helper services
'search': self.search,
'git': self.git,
'deps': self.deps,
'usage': self.usage,
# Utility functions
'extract_functions': self._extract_functions,
'find_unused_imports': self._find_unused_imports,
'calculate_complexity': self._calculate_complexity,
}
# Execute template code
exec(agent.template_code, exec_globals)
# Prepare inputs with parameter mapping for compatibility
# Ensure 'code' parameter is available for all agents
mapped_inputs = dict(inputs)
# Handle special parameter mappings based on agent name
if agent_name == 'test_gap_finder':
# This agent expects 'functions' and 'test_files'
if 'code' in inputs and 'functions' not in inputs:
# Extract functions from code if not provided
mapped_inputs['functions'] = []
mapped_inputs['test_files'] = []
elif agent_name == 'duplicate_detector':
# Ensure threshold is set
if 'threshold' not in mapped_inputs:
mapped_inputs['threshold'] = 0.85
# Call the appropriate function based on agent type
try:
if agent.type in ['analyzer', 'validator']:
# Try calling with mapped inputs
result = exec_globals['analyze'](**mapped_inputs)
elif agent.type == 'generator':
result = exec_globals['generate'](**mapped_inputs)
elif agent.type == 'refactor':
result = exec_globals['refactor'](**mapped_inputs)
else:
raise ValueError(f"Unknown agent type: {agent.type}")
except TypeError as e:
# If there's a parameter mismatch, try with just 'code' parameter
if 'code' in inputs and agent.type in ['analyzer', 'validator']:
try:
result = exec_globals['analyze'](inputs['code'])
except:
# If that doesn't work either, return error details
return {'error': f'Parameter mismatch for {agent_name}: {str(e)}', 'inputs_provided': list(inputs.keys())}
else:
raise
# Update usage statistics
self._update_usage_stats(agent_name, success=True)
return result
def _update_usage_stats(self, agent_name: str, success: bool, execution_ms: int = 0):
"""Update agent usage statistics"""
with self.db.get_connection() as conn:
if success:
conn.execute("""
UPDATE agent_templates
SET usage_count = usage_count + 1,
success_rate = (success_rate * usage_count + 1.0) / (usage_count + 1),
avg_execution_ms = (avg_execution_ms * usage_count + ?) / (usage_count + 1),
updated_at = ?
WHERE name = ?
""", (execution_ms, datetime.now().isoformat(), agent_name))
else:
conn.execute("""
UPDATE agent_templates
SET usage_count = usage_count + 1,
success_rate = (success_rate * usage_count) / (usage_count + 1),
updated_at = ?
WHERE name = ?
""", (datetime.now().isoformat(), agent_name))
conn.commit()
# Helper functions for agents
def _extract_functions(self, code: str) -> List[Dict]:
"""Extract function information from code"""
functions = []
try:
tree = ast.parse(code)
for node in ast.walk(tree):
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
functions.append({
'name': node.name,
'line': node.lineno,
'args': [arg.arg for arg in node.args.args]
})
except:
pass
return functions
def _find_unused_imports(self, imports: List[str], code: str) -> List[str]:
"""Find unused imports in code"""
unused = []
for imp in imports:
# Simple check - would be more sophisticated in production
if imp not in code:
unused.append(imp)
return unused
def _calculate_complexity(self, code: str) -> int:
"""Calculate cyclomatic complexity"""
complexity = 1
try:
tree = ast.parse(code)
for node in ast.walk(tree):
if isinstance(node, (ast.If, ast.While, ast.For, ast.ExceptHandler)):
complexity += 1
except:
pass
return complexity