"""Orchestration engine for scenario execution."""
import asyncio
import json
import logging
from pathlib import Path
from typing import AsyncGenerator, Optional
import yaml
from ..clients.remediation_client import RemediationClient
from ..services.eval_service import EvalService
from ..services.executor_service import ExecutorService
from ..services.fault_service import FaultService
from ..utils.artifacts import artifact_path
from ..utils.variables import substitute_in_dict, substitute_variables
from .fsm import ScenarioContext, ScenarioResult, State, StepResult
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
async def run_scenario(
    context: ScenarioContext,
    remediation_client: RemediationClient,
    fault_service: FaultService,
    executor_service: ExecutorService,
    eval_service: EvalService,
    log_dir: str,
) -> AsyncGenerator[StepResult, None]:
    """
    Execute a scenario end-to-end.

    This is the main orchestration engine that drives scenario execution
    through all FSM states. Every step result is recorded on ``context`` and
    then yielded to the caller.

    A failed gating step marks the run as ``State.FAIL`` and skips the
    remaining steps; the stabilize step is the only one whose outcome does not
    gate the run. An unexpected exception is converted into a failed
    ``State.FAIL`` result. Cleanup and artifact saving always run afterwards,
    including when the consumer closes the generator early (in which case the
    cleanup result is recorded on ``context`` but not yielded).

    Args:
        context: Scenario execution context
        remediation_client: Remediation workflow API client
        fault_service: Fault injection service
        executor_service: Command execution service
        eval_service: Evaluation service
        log_dir: Base log directory

    Yields:
        StepResult: Results for each execution step
    """
    scenario = context.scenario
    # Set once the consumer has closed the generator (GeneratorExit). Yielding
    # after that point raises RuntimeError, so cleanup must not yield then.
    consumer_closed = False

    try:
        try:
            # ===== INIT =====
            logger.info(f"Starting scenario: {scenario.meta.id} (run: {context.run_id})")
            # Resolve bindings (merge scenario bindings with context bindings);
            # context bindings win on key collisions.
            bindings = {**scenario.bindings, **context.bindings}
            result = StepResult(
                state=State.INIT,
                success=True,
                message=f"Initialized scenario: {scenario.meta.title}",
            )
            context.add_result(result)
            yield result

            # Ordered steps between INIT and VERIFY:
            # (log message, coroutine factory, abort the run on failure?)
            steps = (
                # ===== PRECHECK =====
                (
                    "Running prechecks",
                    lambda: _run_precheck(context, bindings),
                    True,
                ),
                # ===== FAULT_INJECT =====
                (
                    "Injecting fault",
                    lambda: _inject_fault(context, fault_service, bindings),
                    True,
                ),
                # ===== STABILIZE ===== (reported, but never aborts the run)
                (
                    "Stabilizing system",
                    lambda: _stabilize(context),
                    False,
                ),
                # ===== ASSISTANT_RCA =====
                (
                    "Running RCA assistant",
                    lambda: _run_rca_assistant(context, remediation_client, bindings),
                    True,
                ),
                # ===== EVAL_RCA =====
                (
                    "Evaluating RCA response",
                    lambda: _eval_rca(context, eval_service),
                    True,
                ),
                # ===== ASSISTANT_REMEDY =====
                (
                    "Running remedy assistant",
                    lambda: _run_remedy_assistant(context, remediation_client),
                    True,
                ),
                # ===== EVAL_REMEDY =====
                (
                    "Evaluating remedy response",
                    lambda: _eval_remedy(context, eval_service),
                    True,
                ),
                # ===== EXECUTE_REMEDY =====
                (
                    "Executing remedy commands",
                    lambda: _execute_remedy(context, executor_service),
                    True,
                ),
            )
            for log_message, run_step, abort_on_failure in steps:
                logger.info(log_message)
                result = await run_step()
                context.add_result(result)
                yield result
                if abort_on_failure and not result.success:
                    context.set_final_state(State.FAIL)
                    return

            # ===== VERIFY =====
            logger.info("Verifying system state")
            result = await _verify(context)
            context.add_result(result)
            yield result
            if not result.success:
                context.set_final_state(State.FAIL)
                return

            context.set_final_state(State.PASS)
            result = StepResult(
                state=State.PASS,
                success=True,
                message="Scenario PASSED",
            )
            context.add_result(result)
            yield result
        except Exception as e:
            logger.error(f"Scenario execution failed: {e}", exc_info=True)
            result = StepResult(
                state=State.FAIL,
                success=False,
                message=f"Execution error: {e}",
            )
            context.add_result(result)
            context.set_final_state(State.FAIL)
            yield result
    except GeneratorExit:
        # The consumer stopped iterating at one of the yields above. Remember
        # that so the cleanup below does not try to yield to a closed consumer.
        consumer_closed = True
        raise
    finally:
        try:
            # ===== CLEANUP =====
            logger.info("Running cleanup")
            result = await _cleanup(context, fault_service, executor_service)
            context.add_result(result)
            if not consumer_closed:
                yield result
        finally:
            # Save final artifacts even if cleanup itself raised.
            await _save_artifacts(context, log_dir)
async def _run_precheck(
    context: ScenarioContext, bindings: dict
) -> StepResult:
    """Run the scenario's prechecks.

    The SignalFlow integration is a stub, so this step always succeeds. When
    the scenario defines no prechecks (or no SignalFlow prechecks) the step is
    reported as skipped.

    Args:
        context: Scenario execution context.
        bindings: Resolved variable bindings (not used by the stub).

    Returns:
        A successful ``State.PRECHECK`` step result.
    """
    prechecks = context.scenario.prechecks
    has_signalflow = bool(prechecks) and bool(prechecks.signalflow)

    if has_signalflow:
        # Stub: In production, integrate with SignalFlow API
        logger.info("Running SignalFlow prechecks (stub)")
        outcome = "Prechecks passed (stub)"
    else:
        outcome = "No prechecks configured (skipped)"

    return StepResult(state=State.PRECHECK, success=True, message=outcome)
async def _inject_fault(
    context: ScenarioContext,
    fault_service: FaultService,
    bindings: dict,
) -> StepResult:
    """Inject the scenario's configured fault.

    Fault parameters have their variables substituted from ``bindings`` before
    being handed to the fault service. On success the returned fault id is
    stored on ``context.fault_id``.

    Args:
        context: Scenario execution context.
        fault_service: Fault injection service.
        bindings: Resolved variable bindings.

    Returns:
        A ``State.FAULT_INJECT`` step result mirroring the service's outcome,
        with the fault id in its metadata.
    """
    spec = context.scenario.fault
    resolved_params = substitute_in_dict(spec.params, bindings)

    outcome = await fault_service.inject(spec.type, resolved_params, context.run_id)
    injected, fault_id, detail = outcome

    if injected:
        context.fault_id = fault_id

    return StepResult(
        state=State.FAULT_INJECT,
        success=injected,
        message=detail,
        metadata={"fault_id": fault_id},
    )
async def _stabilize(context: ScenarioContext) -> StepResult:
    """Give the system time to settle after fault injection.

    Stub implementation: it sleeps for a fixed two seconds regardless of the
    configured timeout, which only appears in the log line and result message.

    Args:
        context: Scenario execution context.

    Returns:
        A successful ``State.STABILIZE`` step result.
    """
    timeout = context.scenario.stabilize.wait_for.timeout

    # Stub: In production, poll detector or use timeout
    logger.info(f"Stabilizing for {timeout}s (stub)")
    await asyncio.sleep(2)  # Short sleep for MVP

    return StepResult(
        state=State.STABILIZE,
        success=True,
        message=f"System stabilized (waited {timeout}s)",
    )
async def _run_rca_assistant(
    context: ScenarioContext,
    remediation_client: RemediationClient,
    bindings: dict,
) -> StepResult:
    """Run the RCA assistant interaction.

    Initiates a remediation workflow, stores the returned thread and interrupt
    ids on ``context``, and extracts the RCA text from the returned graph into
    ``context.rca_response``.

    Args:
        context: Scenario execution context.
        remediation_client: Remediation workflow API client.
        bindings: Resolved variable bindings used for prompt substitution.

    Returns:
        A ``State.ASSISTANT_RCA`` step result; failures raised while talking to
        the remediation workflow are reported as an unsuccessful result.
    """
    step = context.scenario.assistant_rca

    # Substitute variables.
    # NOTE(review): the system prompt is resolved but not sent anywhere yet.
    system_prompt = substitute_variables(step.system, bindings)
    user_prompt = substitute_variables(step.user, bindings)

    logger.info("Initiating remediation workflow")
    try:
        analysis_stub = {
            "title": context.scenario.meta.title,
            "summary": user_prompt[:200],
            "nextSteps": "Awaiting RCA analysis",
        }
        workflow = await remediation_client.initiate_remediation(
            incident_id=f"incident-{context.run_id}",
            rca_analysis=analysis_stub,
        )
        context.thread_id = workflow.thread_id
        context.interrupt_id = workflow.interrupt_id

        # Extract RCA response from graph
        extracted = _extract_rca_from_graph(workflow.graph_of_subgraphs)
        context.rca_response = extracted

        preview = extracted[:100] if extracted else ""
        return StepResult(
            state=State.ASSISTANT_RCA,
            success=True,
            message="RCA assistant completed",
            metadata={
                "thread_id": workflow.thread_id,
                "response_preview": preview,
            },
        )
    except Exception as e:
        logger.error(f"RCA assistant failed: {e}")
        return StepResult(
            state=State.ASSISTANT_RCA,
            success=False,
            message=f"RCA assistant failed: {e}",
        )
async def _eval_rca(
    context: ScenarioContext, eval_service: EvalService
) -> StepResult:
    """Score the RCA response against the scenario's RCA expectations.

    Args:
        context: Scenario execution context; ``rca_response`` is the text
            being scored (an empty string when unset).
        eval_service: Evaluation service.

    Returns:
        A ``State.EVAL_RCA`` step result carrying the score and the matched
        references, matched metrics and failed guards in its metadata.
    """
    expectations = context.scenario.assistant_rca
    response_text = context.rca_response or ""

    verdict = await eval_service.score(
        actual_text=response_text,
        expected_references=expectations.expect.references,
        expected_metrics=expectations.expect.metrics,
        guards=expectations.expect.guards or [],
        run_id=context.run_id,
    )
    passed, score, message, refs_hit, metrics_hit, guards_failed = verdict

    details = {
        "matched_references": refs_hit,
        "matched_metrics": metrics_hit,
        "failed_guards": guards_failed,
    }
    return StepResult(
        state=State.EVAL_RCA,
        success=passed,
        message=message,
        score=score,
        metadata=details,
    )
async def _run_remedy_assistant(
    context: ScenarioContext,
    remediation_client: RemediationClient,
) -> StepResult:
    """Run the remedy assistant interaction.

    Fetches the workflow graph for the current thread and stores the extracted
    user command nodes (as JSON) on ``context.remedy_response``. This step is
    best-effort: any error is logged and the step still succeeds, falling back
    to the scenario's configured commands.

    Args:
        context: Scenario execution context (needs ``thread_id``).
        remediation_client: Remediation workflow API client.

    Returns:
        A successful ``State.ASSISTANT_REMEDY`` step result.
    """
    try:
        # Get graph to extract commands
        if not context.thread_id:
            raise ValueError("No thread_id available")

        # In production, would resume workflow with user input
        # For stub, extract command nodes from existing graph
        state_handle = remediation_client.RemediationState(
            thread_id=context.thread_id or "",
            interrupt_id=context.interrupt_id or "",
            graph_of_subgraphs={},
        )
        graph = await remediation_client.get_graph_state(state_handle)
        commands = remediation_client.extract_user_command_nodes(graph)

        if commands:
            context.remedy_response = json.dumps(commands, indent=2)
        else:
            context.remedy_response = "No commands"

        command_count = len(commands)
        return StepResult(
            state=State.ASSISTANT_REMEDY,
            success=True,
            message=f"Remedy assistant completed ({command_count} commands)",
            metadata={"command_count": command_count},
        )
    except Exception as e:
        logger.error(f"Remedy assistant failed: {e}")
        # Continue with configured commands
        context.remedy_response = "Using configured commands"
        return StepResult(
            state=State.ASSISTANT_REMEDY,
            success=True,
            message="Using configured remedy commands",
        )
async def _eval_remedy(
    context: ScenarioContext, eval_service: EvalService
) -> StepResult:
    """Score the remedy response against the scenario's remedy expectations.

    Args:
        context: Scenario execution context; ``remedy_response`` is the text
            being scored (an empty string when unset).
        eval_service: Evaluation service.

    Returns:
        A ``State.EVAL_REMEDY`` step result carrying the score and the matched
        references, matched metrics and failed guards in its metadata.
    """
    expectations = context.scenario.assistant_remedy
    response_text = context.remedy_response or ""

    verdict = await eval_service.score(
        actual_text=response_text,
        expected_references=expectations.expect.references,
        expected_metrics=expectations.expect.metrics,
        guards=expectations.expect.guards or [],
        run_id=context.run_id,
    )
    passed, score, message, refs_hit, metrics_hit, guards_failed = verdict

    details = {
        "matched_references": refs_hit,
        "matched_metrics": metrics_hit,
        "failed_guards": guards_failed,
    }
    return StepResult(
        state=State.EVAL_REMEDY,
        success=passed,
        message=message,
        score=score,
        metadata=details,
    )
async def _execute_remedy(
    context: ScenarioContext,
    executor_service: ExecutorService,
) -> StepResult:
    """Execute the scenario's configured remedy commands in order.

    Every command is run inside the configured sandbox even if an earlier one
    failed; the step succeeds only when all commands succeed.

    Args:
        context: Scenario execution context.
        executor_service: Command execution service.

    Returns:
        A ``State.EXECUTE_REMEDY`` step result whose metadata lists the name,
        success flag, exit code and artifact of each command.
    """
    plan = context.scenario.execute_remedy
    sandbox = plan.sandbox

    if sandbox.policies:
        deny_patterns = sandbox.policies.get("deny_patterns", [])
    else:
        deny_patterns = []

    outcomes = []
    for command in plan.commands:
        logger.info(f"Executing command: {command.name}")
        run_result = await executor_service.run(
            command=command.cmd,
            args=command.args,
            service_account=sandbox.service_account,
            namespace=sandbox.namespace,
            deny_patterns=deny_patterns,
            run_id=context.run_id,
        )
        # Only part of the 6-tuple is reported; output streams and the
        # executor's message are not included in the step metadata.
        ok, exit_code, _stdout, _stderr, artifact, _message = run_result
        outcomes.append(
            {
                "name": command.name,
                "success": ok,
                "exit_code": exit_code,
                "artifact": artifact,
            }
        )
        if not ok:
            logger.error(f"Command failed: {command.name}")

    total = len(plan.commands)
    succeeded = sum(entry["success"] for entry in outcomes)
    return StepResult(
        state=State.EXECUTE_REMEDY,
        success=all(entry["success"] for entry in outcomes),
        message=f"Executed {total} commands ({succeeded} succeeded)",
        metadata={"command_results": outcomes},
    )
async def _verify(context: ScenarioContext) -> StepResult:
    """Verify system state after the remedy has been executed.

    The SignalFlow/detector integration is a stub, so this step always
    succeeds. When the scenario configures neither SignalFlow checks nor a
    detector-clear check the step is reported as skipped.

    Args:
        context: Scenario execution context.

    Returns:
        A successful ``State.VERIFY`` step result.
    """
    checks = context.scenario.verify
    configured = bool(checks) and bool(checks.signalflow or checks.detector_clear)

    if configured:
        # Stub: In production, integrate with SignalFlow and detectors
        logger.info("Running verification checks (stub)")
        outcome = "Verification passed (stub)"
    else:
        outcome = "No verification configured (skipped)"

    return StepResult(state=State.VERIFY, success=True, message=outcome)
async def _cleanup(
    context: ScenarioContext,
    fault_service: FaultService,
    executor_service: ExecutorService,
) -> StepResult:
    """Run cleanup tasks on a best-effort basis.

    Removes the injected fault (if one was recorded on ``context.fault_id``)
    and then runs the scenario's cleanup commands. Commands flagged
    ``when_failed`` only run when the scenario did not pass.

    Each cleanup action is isolated: an exception from one action is logged
    and recorded, and the remaining actions still run, so a single failure
    cannot leave the rest of the environment uncleaned.

    Args:
        context: Scenario execution context.
        fault_service: Fault injection service used to remove the fault.
        executor_service: Command execution service used for cleanup commands.

    Returns:
        A ``State.CLEANUP`` step result. It is successful when no cleanup
        action raised; otherwise it is unsuccessful and its metadata lists
        the failed actions under ``"errors"``.
    """
    cleanup_tasks = context.scenario.cleanup or []
    errors = []

    # Clean up fault
    if context.fault_id:
        try:
            await fault_service.cleanup(context.fault_id, context.run_id)
        except Exception as e:
            logger.exception(f"Fault cleanup failed: {context.fault_id}")
            errors.append({"name": "fault_cleanup", "error": str(e)})

    # Run cleanup commands
    failed = not context.passed
    for cleanup in cleanup_tasks:
        # Check if should run
        if cleanup.when_failed and not failed:
            logger.info(f"Skipping cleanup (when_failed=True): {cleanup.name}")
            continue

        logger.info(f"Running cleanup: {cleanup.name}")
        try:
            await executor_service.run(
                command=cleanup.cmd,
                args=cleanup.args,
                service_account=None,
                namespace=None,
                deny_patterns=[],
                run_id=context.run_id,
            )
        except Exception as e:
            # Keep going: later cleanup tasks must still get their chance.
            logger.exception(f"Cleanup task failed: {cleanup.name}")
            errors.append({"name": cleanup.name, "error": str(e)})

    if errors:
        return StepResult(
            state=State.CLEANUP,
            success=False,
            message=f"Cleanup completed with errors ({len(errors)} failed)",
            metadata={"errors": errors},
        )

    return StepResult(
        state=State.CLEANUP,
        success=True,
        message=f"Cleanup completed ({len(cleanup_tasks)} tasks)",
    )
async def _save_artifacts(context: ScenarioContext, log_dir: str) -> None:
    """Write the run's final artifacts.

    Three files are written via ``artifact_path`` for the current run id:
    ``scenario.yaml`` (the scenario definition), ``transcript.json`` (assistant
    responses and workflow ids) and ``report.json`` (the scenario result).

    Args:
        context: Scenario execution context.
        log_dir: Base log directory.
    """
    run_id = context.run_id

    def write_json(target, payload) -> None:
        # Both JSON artifacts share the same formatting.
        with open(target, "w") as handle:
            json.dump(payload, handle, indent=2)

    # Save scenario YAML
    scenario_file = artifact_path(log_dir, run_id, "scenario.yaml")
    with open(scenario_file, "w") as handle:
        yaml.dump(context.scenario.model_dump(), handle, default_flow_style=False)

    # Save transcript
    transcript_file = artifact_path(log_dir, run_id, "transcript.json")
    write_json(
        transcript_file,
        {
            "run_id": run_id,
            "rca_response": context.rca_response,
            "remedy_response": context.remedy_response,
            "thread_id": context.thread_id,
            "interrupt_id": context.interrupt_id,
        },
    )

    # Save report
    report_file = artifact_path(log_dir, run_id, "report.json")
    report = ScenarioResult.from_context(context)
    write_json(report_file, report.to_dict())

    run_dir = Path(log_dir) / 'runs' / run_id
    logger.info(f"Artifacts saved to {run_dir}")
def _extract_rca_from_graph(graph: dict) -> str:
"""Extract RCA analysis from graph."""
# Stub: In production, parse graph structure
# Look for RCA node or analysis text
if not graph:
return "No RCA analysis available"
# Try to find analysis in graph
def find_text(obj, depth=0):
if depth > 10:
return None
if isinstance(obj, dict):
for key in ["analysis", "rca", "explanation", "text", "content"]:
if key in obj and isinstance(obj[key], str):
return obj[key]
for value in obj.values():
result = find_text(value, depth + 1)
if result:
return result
elif isinstance(obj, list):
for item in obj:
result = find_text(item, depth + 1)
if result:
return result
return None
text = find_text(graph)
return text or "RCA analysis extracted from workflow (stub)"