workflow.py • 9.95 kB
"""Workflow engine for multi-agent collaboration."""
import re
import logging
from typing import Any
from pathlib import Path
from datetime import datetime
from pydantic import BaseModel, Field
import yaml
from .orchestrator import OrchestratorRegistry
from .logging_config import delegation_logger
logger = logging.getLogger(__name__)
class WorkflowStep(BaseModel):
    """One step of a workflow: a task handed to a named agent.

    Fields:
        id: Identifier of the step.
        agent: Name of the agent the task is assigned to.
        task: Text of the task.
        output: Optional variable name under which the step's output is
            stored; ``None`` by default.
        condition: Optional condition controlling whether the step is
            executed; ``None`` by default.
        description: Free-form description; empty by default.
    """

    id: str
    agent: str
    task: str
    # Variable name to store output
    output: str | None = None
    # Conditional execution
    condition: str | None = None
    description: str = ""
class WorkflowDefinition(BaseModel):
    """Definition of a multi-agent workflow.

    Fields:
        name: Name of the workflow.
        description: Free-form description; empty by default.
        steps: The workflow's steps, in declaration order.
        metadata: Arbitrary extra key/value data; empty dict by default.
    """

    name: str
    description: str = ""
    steps: list[WorkflowStep]
    metadata: dict[str, Any] = Field(default_factory=dict)

    @classmethod
    def from_yaml(cls, path: Path) -> "WorkflowDefinition":
        """Load a workflow definition from a YAML file.

        Args:
            path: Location of the YAML document to read.

        Returns:
            The model built from the document's top-level mapping.

        Raises:
            ValueError: If the document is empty or its top level is not a
                mapping. Model validation errors raised by the constructor
                propagate unchanged.
        """
        # Explicit encoding: the platform default is not UTF-8 everywhere.
        with open(path, encoding="utf-8") as f:
            data = yaml.safe_load(f)

        # safe_load returns None for an empty file and a list/scalar for
        # other document shapes; ``cls(**data)`` would then fail with an
        # opaque TypeError, so reject those shapes with a clear message.
        if not isinstance(data, dict):
            raise ValueError(
                f"Workflow file {path} must contain a YAML mapping, "
                f"got {type(data).__name__}"
            )
        return cls(**data)

    def to_yaml(self, path: Path) -> None:
        """Save the workflow to a YAML file in block style.

        Args:
            path: Destination file; it is opened in write mode, so any
                existing content is replaced.
        """
        with open(path, "w", encoding="utf-8") as f:
            yaml.dump(self.model_dump(), f, default_flow_style=False)
class WorkflowContext:
    """Context for workflow execution with variable storage.

    Attributes:
        variables: Mapping of variable name to value; read by
            ``interpolate`` and ``evaluate_condition``.
        history: Step result records, in the order they were added.
    """

    # ``{{ name }}`` placeholder; group 1 is the raw text between the braces.
    _PLACEHOLDER = re.compile(r'\{\{\s*([^}]+)\s*\}\}')
    # Only identifiers made of letters, digits and underscores are substituted.
    _IDENTIFIER = re.compile(r'^[a-zA-Z_][a-zA-Z0-9_]*$')
    # Leading variable of a condition: everything up to a ``|`` or ``}``.
    _CONDITION_VAR = re.compile(r'\{\{\s*([^}|]+)')
    # ``| length`` filter, tolerant of missing/extra spaces after the pipe.
    _LENGTH_FILTER = re.compile(r'\|\s*length\b')
    # Comparison operator followed by an integer threshold. Two-character
    # operators come first so ``>=`` is not read as ``>`` and ``!=`` is not
    # read as ``=``.
    _COMPARISON = re.compile(r'(>=|<=|==|!=|>|<|=)\s*(\d+)')
    _COMPARATORS = {
        '>': lambda length, threshold: length > threshold,
        '<': lambda length, threshold: length < threshold,
        '>=': lambda length, threshold: length >= threshold,
        '<=': lambda length, threshold: length <= threshold,
        '==': lambda length, threshold: length == threshold,
        # A single ``=`` is accepted as a synonym for ``==``.
        '=': lambda length, threshold: length == threshold,
        '!=': lambda length, threshold: length != threshold,
    }
    # Substrings that trigger a warning when found in an interpolated value.
    _DANGEROUS_CHARS = (';', '&&', '||', '`', '$', '|', '>', '<', '\n', '\r')

    def __init__(self):
        self.variables: dict[str, Any] = {}
        self.history: list[dict[str, Any]] = []

    def set(self, name: str, value: Any) -> None:
        """Set a variable in context, replacing any previous value."""
        self.variables[name] = value

    def get(self, name: str, default: Any = None) -> Any:
        """Return the variable ``name``, or ``default`` if it is not set."""
        return self.variables.get(name, default)

    def interpolate(self, template: str) -> str:
        """Replace ``{{ variable_name }}`` placeholders with context values.

        A placeholder is left in the output (normalized to
        ``{{ name }}`` spacing) when its name is not a plain identifier or
        when the variable is unset or ``None``.

        Note: values are NOT shell-escaped. A warning is logged when a value
        contains shell-significant characters, but the value is inserted
        unchanged; callers passing the result to a shell must escape it.

        Args:
            template: Text containing zero or more placeholders.

        Returns:
            The template with every resolvable placeholder substituted by
            ``str(value)``.
        """
        def replace_var(match):
            var_name = match.group(1).strip()

            if not self._IDENTIFIER.match(var_name):
                logger.warning(f"Invalid variable name in template: {var_name}")
                return f"{{{{ {var_name} }}}}"

            value = self.get(var_name)
            if value is None:
                return f"{{{{ {var_name} }}}}"

            str_value = str(value)
            # Warn only: the value is deliberately returned unmodified.
            if any(char in str_value for char in self._DANGEROUS_CHARS):
                logger.warning(
                    f"Variable '{var_name}' contains potentially dangerous "
                    f"characters: {str_value[:50]}"
                )
            return str_value

        return self._PLACEHOLDER.sub(replace_var, template)

    def evaluate_condition(self, condition: str) -> bool:
        """Evaluate a simple condition against the context.

        Supports:
        - ``{{ var_name }}``: true if the variable is set and truthy.
        - ``{{ var_name | length OP N }}``: compare the length of a list,
          string or dict against the integer ``N``, where ``OP`` is one of
          ``>``, ``<``, ``>=``, ``<=``, ``==``, ``=`` or ``!=``.

        Args:
            condition: The condition text.

        Returns:
            ``True`` for an empty condition or one without a ``{{`` variable
            reference. For a length condition, the comparison result, or
            ``False`` when the value is not a list/str/dict or no
            ``OP N`` comparison can be parsed. Otherwise the truthiness of
            the variable.
        """
        if not condition:
            return True

        var_match = self._CONDITION_VAR.search(condition)
        if not var_match:
            return True

        value = self.get(var_match.group(1).strip())

        length_filter = self._LENGTH_FILTER.search(condition)
        if length_filter is None:
            # Default: check truthiness
            return bool(value)

        if not isinstance(value, (list, str, dict)):
            return False

        # Look for the comparison only after the filter so that nothing in
        # the variable part of the condition can be mistaken for it.
        comparison = self._COMPARISON.search(condition, length_filter.end())
        if comparison is None:
            logger.warning(f"Invalid condition format: {condition}")
            return False

        compare = self._COMPARATORS[comparison.group(1)]
        return compare(len(value), int(comparison.group(2)))

    def add_to_history(self, step_id: str, result: dict[str, Any]) -> None:
        """Append a step result record to ``history``.

        The record holds ``step_id``, an ISO-format ``timestamp`` taken from
        ``datetime.now()``, and every key of ``result``. Keys in ``result``
        named ``step_id`` or ``timestamp`` override the generated ones.

        Args:
            step_id: Identifier of the step the record belongs to.
            result: Extra fields to merge into the record.
        """
        self.history.append({
            "step_id": step_id,
            "timestamp": datetime.now().isoformat(),
            **result,
        })
class WorkflowResult(BaseModel):
    """Outcome of one workflow execution.

    Fields:
        workflow_name: Name of the workflow that was executed.
        success: Whether the execution is considered successful.
        steps_completed: Number of steps counted as completed.
        total_steps: Number of steps in the workflow.
        duration: Duration of the execution.
        outputs: Output values keyed by name.
        errors: Error messages; empty list by default.
    """

    workflow_name: str
    success: bool
    steps_completed: int
    total_steps: int
    duration: float
    outputs: dict[str, Any]
    errors: list[str] = Field(default_factory=list)
class WorkflowEngine:
    """Engine for executing multi-agent workflows."""

    def __init__(self, registry: OrchestratorRegistry):
        """Store the registry used to dispatch step tasks to agents."""
        self.registry = registry

    async def execute(
        self,
        workflow: WorkflowDefinition,
        initial_context: dict[str, Any] | None = None
    ) -> WorkflowResult:
        """Execute a workflow's steps sequentially.

        Steps whose condition evaluates false are skipped. Execution stops at
        the first step that fails or raises.

        Args:
            workflow: Workflow definition
            initial_context: Initial variables for context

        Returns:
            WorkflowResult with execution details. ``success`` is true when
            no step reported an error; skipped steps do not count against it,
            but they are not included in ``steps_completed`` either.
        """
        logger.info(f"Starting workflow: {workflow.name}")
        start = datetime.now()

        context = WorkflowContext()
        for key, value in (initial_context or {}).items():
            context.set(key, value)

        steps_completed = 0
        errors: list[str] = []

        for step in workflow.steps:
            logger.info(f"Executing step: {step.id} (agent: {step.agent})")

            if step.condition and not context.evaluate_condition(step.condition):
                logger.info(f"Skipping step {step.id}: condition not met")
                continue

            error_msg = await self._run_step(step, context)
            if error_msg is not None:
                logger.error(error_msg)
                errors.append(error_msg)
                # Stop on first error (can be made configurable)
                break

            # Counted only after the step's bookkeeping finished, so a step
            # can never be both "completed" and listed in ``errors``.
            steps_completed += 1

        duration = (datetime.now() - start).total_seconds()
        total_steps = len(workflow.steps)

        result = WorkflowResult(
            workflow_name=workflow.name,
            # The loop stops at the first error, so "no errors" means every
            # step either completed or was skipped by its condition.
            # Comparing steps_completed with the step count would report a
            # workflow with a skipped conditional step as failed.
            success=not errors,
            steps_completed=steps_completed,
            total_steps=total_steps,
            duration=duration,
            outputs=context.variables,
            errors=errors,
        )

        logger.info(
            f"Workflow {workflow.name} completed: "
            f"{steps_completed}/{total_steps} steps in {duration:.2f}s"
        )
        return result

    async def _run_step(
        self,
        step: WorkflowStep,
        context: WorkflowContext
    ) -> str | None:
        """Run one step and record it in the context history.

        Returns:
            ``None`` when the step succeeded, otherwise the error message
            describing the failure.
        """
        task = context.interpolate(step.task)
        logger.debug(f"Task after interpolation: {task}")

        try:
            stdout, stderr, returncode = await self.registry.execute(
                step.agent,
                task
            )

            if returncode != 0:
                context.add_to_history(step.id, {
                    "agent": step.agent,
                    "success": False,
                    "error": stderr,
                })
                return f"Step {step.id} failed: {stderr}"

            if step.output:
                # Parse output - for now just store stdout
                context.set(step.output, stdout.strip())
                logger.info(f"Stored output in variable: {step.output}")

            context.add_to_history(step.id, {
                "agent": step.agent,
                "success": True,
                "output": stdout,
            })
            return None
        except Exception as e:
            # Step boundary: any failure while running the step is reported
            # as a step error instead of aborting the whole engine.
            context.add_to_history(step.id, {
                "agent": step.agent,
                "success": False,
                "error": str(e),
            })
            return f"Step {step.id} error: {str(e)}"

    def load_workflow(self, path: Path) -> WorkflowDefinition:
        """Load workflow from file."""
        return WorkflowDefinition.from_yaml(path)

    def list_workflows(self, directory: Path) -> list[WorkflowDefinition]:
        """Load every ``*.yaml`` workflow directly inside ``directory``.

        Files that fail to load are skipped with a warning. Files are
        processed in sorted path order so the result is deterministic.

        Args:
            directory: Directory to scan (not recursive).

        Returns:
            The successfully loaded workflow definitions.
        """
        workflows = []
        for yaml_file in sorted(directory.glob("*.yaml")):
            try:
                workflows.append(self.load_workflow(yaml_file))
            except Exception as e:
                # Best-effort listing: one broken file must not hide the rest.
                logger.warning(f"Failed to load workflow {yaml_file}: {e}")
        return workflows