"""Client for Remediation Workflow API."""
import logging
from dataclasses import dataclass
from typing import Any, Dict, Optional
import httpx
# Module-scoped logger; its name follows this module's import path.
logger: logging.Logger = logging.getLogger(name=__name__)
@dataclass
class RemediationState:
    """State information for remediation workflow.

    Instances are built from a workflow API response: the identifiers come
    from its ``stateIdentifier`` section and the graph data from ``updates``.
    """

    # Workflow thread identifier ("threadId").
    thread_id: str
    # Identifier of the interrupt the workflow is paused on ("interruptId").
    interrupt_id: str
    # Graph structure returned under "graphOfSubgraphs".
    graph_of_subgraphs: Dict[str, Any]
    # JSON pointer (e.g. "/nodes/0/data") into graph_of_subgraphs locating the
    # node awaiting input ("nextInputNode"), or None.
    next_input_node: Optional[str] = None

    def resolve_input_node(self) -> Optional[Dict[str, Any]]:
        """
        Resolve the input node from the next_input_node JSON pointer.

        Each reference token is unescaped per RFC 6901 ("~1" -> "/",
        "~0" -> "~"). Array indices must be plain non-negative decimal
        integers, so "-1" does not wrap around to the last element. For
        leniency, empty tokens (a missing leading slash, trailing or doubled
        slashes) are skipped.

        Returns:
            Optional[Dict]: The resolved node when the pointer leads to a dict.
            Returns None when there is no pointer or graph, when the path
            cannot be followed, or when the target is not a dict.
        """
        if not self.next_input_node or not self.graph_of_subgraphs:
            return None

        current: Any = self.graph_of_subgraphs
        for raw_token in self.next_input_node.split("/"):
            if not raw_token:
                continue
            # RFC 6901: decode "~1" before "~0" so that "~01" becomes "~1",
            # not "/".
            token = raw_token.replace("~1", "/").replace("~0", "~")
            if isinstance(current, dict):
                current = current.get(token)
            elif isinstance(current, list):
                # int() would also accept "-1", "+1" and " 1". Requiring ASCII
                # digits keeps negative indices from wrapping around.
                if not (token.isascii() and token.isdigit()):
                    return None
                index = int(token)
                if index >= len(current):
                    return None
                current = current[index]
            else:
                return None
            if current is None:
                return None
        return current if isinstance(current, dict) else None
class RemediationClient:
    """
    HTTP client for Remediation Workflow API.

    Supports:
    - InitiateEnsemble: Start a new remediation workflow
    - ResumeEnsemble: Resume workflow with input

    The client owns an ``httpx.AsyncClient``. Call :meth:`close` when done,
    or use the instance as an async context manager so the client is closed
    even if a request raises::

        async with RemediationClient(base_url) as client:
            state = await client.initiate_remediation(incident_id, rca)
    """

    def __init__(self, base_url: str, timeout: float = 300):
        """
        Initialize remediation client.

        Args:
            base_url: Base URL for HTTP API; trailing "/" characters are
                stripped.
            timeout: Request timeout in seconds, passed to httpx.AsyncClient.
        """
        self.base_url = base_url.rstrip("/")
        self.timeout = timeout
        self.client = httpx.AsyncClient(timeout=timeout)

    async def __aenter__(self) -> "RemediationClient":
        """Return this client for use in an ``async with`` block."""
        return self

    async def __aexit__(self, *exc_info: Any) -> None:
        """Close the HTTP client when leaving an ``async with`` block."""
        await self.close()

    async def close(self) -> None:
        """Close the HTTP client."""
        await self.client.aclose()

    async def _post_workflow(self, payload: Dict[str, Any]) -> Dict[str, Any]:
        """
        POST a request to the workflow endpoint and return the decoded body.

        Args:
            payload: Complete request body (apiMethod, apiVersion, payload).

        Returns:
            Dict: The decoded JSON response object.

        Raises:
            httpx.HTTPError: If request fails
            ValueError: If the response is not a JSON object or contains an
                "error" key
        """
        # Lazy %-args: the payload and response are only stringified when a
        # DEBUG record is actually emitted.
        logger.debug("Payload: %s", payload)
        response = await self.client.post(
            f"{self.base_url}/api/workflow", json=payload
        )
        response.raise_for_status()
        data = response.json()
        logger.debug("Response: %s", data)
        # Any non-object body would otherwise fail later with AttributeError
        # on .get().
        if not isinstance(data, dict):
            raise ValueError(f"Unexpected API response: {data!r}")
        if "error" in data:
            raise ValueError(f"API error: {data['error']}")
        return data

    @staticmethod
    def _parse_state(
        data: Dict[str, Any],
        default_thread_id: str = "",
        default_interrupt_id: str = "",
    ) -> RemediationState:
        """
        Build a RemediationState from a workflow response body.

        Args:
            data: Decoded response object.
            default_thread_id: Used when "stateIdentifier" has no "threadId".
            default_interrupt_id: Used when "stateIdentifier" has no
                "interruptId".

        Returns:
            RemediationState: State extracted from the response.
        """
        # "or {}" also covers sections the API sends as JSON null, which a
        # plain .get(key, {}) would return as None.
        state_id = data.get("stateIdentifier") or {}
        updates = data.get("updates") or {}
        return RemediationState(
            thread_id=state_id.get("threadId", default_thread_id),
            interrupt_id=state_id.get("interruptId", default_interrupt_id),
            graph_of_subgraphs=updates.get("graphOfSubgraphs") or {},
            next_input_node=updates.get("nextInputNode"),
        )

    async def initiate_remediation(
        self, incident_id: str, rca_analysis: Dict[str, str]
    ) -> RemediationState:
        """
        Initiate a new remediation workflow.

        Args:
            incident_id: Unique incident identifier
            rca_analysis: RCA analysis with title, summary, nextSteps

        Returns:
            RemediationState: Initial state of remediation workflow. Missing
            identifiers in the response default to "".

        Raises:
            httpx.HTTPError: If request fails
            ValueError: If the response is not a JSON object or reports an
                error
        """
        payload = {
            "apiMethod": "InitiateEnsemble",
            "apiVersion": "1",
            "ensembleName": "REMEDIATION",
            "payload": {
                "incidentId": incident_id,
                "rcaAnalysis": rca_analysis,
            },
        }
        logger.info("Initiating remediation for incident: %s", incident_id)
        data = await self._post_workflow(payload)
        return self._parse_state(data)

    async def resume_remediation(
        self,
        thread_id: str,
        interrupt_id: str,
        node_id: str,
        user_input: str,
    ) -> RemediationState:
        """
        Resume remediation workflow with user input.

        Args:
            thread_id: Thread identifier
            interrupt_id: Interrupt identifier
            node_id: Node identifier to resume
            user_input: User input text

        Returns:
            RemediationState: Updated state of remediation workflow. If the
            response omits identifiers, the ones passed in are kept.

        Raises:
            httpx.HTTPError: If request fails
            ValueError: If the response is not a JSON object or reports an
                error
        """
        payload = {
            "apiMethod": "ResumeEnsemble",
            "apiVersion": "1",
            "payload": {
                "messageType": "node_input",
                "stateIdentifier": {
                    "threadId": thread_id,
                    "interruptId": interrupt_id,
                },
                "nodeId": node_id,
                "inputProperties": {"input": user_input},
            },
        }
        logger.info(
            "Resuming remediation: thread=%s, node=%s", thread_id, node_id
        )
        data = await self._post_workflow(payload)
        return self._parse_state(data, thread_id, interrupt_id)

    async def get_graph_state(self, state: RemediationState) -> Dict[str, Any]:
        """
        Get the current graph state.

        No request is made; this returns the graph already held by *state*.

        Args:
            state: Current remediation state

        Returns:
            Dict: Graph state
        """
        return state.graph_of_subgraphs

    def extract_user_command_nodes(
        self, graph: Dict[str, Any]
    ) -> list[Dict[str, Any]]:
        """
        Extract user command nodes from graph.

        A dict counts as a command node when its "type" is "user_command" or
        its "nodeType" is "command". The whole structure is walked through
        dicts and lists, and matches are collected in depth-first pre-order.
        A matching node's own children are searched too, so nested matches
        are included.

        Args:
            graph: Graph of subgraphs

        Returns:
            List of command nodes (the original dict objects, not copies)
        """
        found: list[Dict[str, Any]] = []

        def visit(obj: Any) -> None:
            if isinstance(obj, dict):
                if (
                    obj.get("type") == "user_command"
                    or obj.get("nodeType") == "command"
                ):
                    found.append(obj)
                for value in obj.values():
                    visit(value)
            elif isinstance(obj, list):
                for item in obj:
                    visit(item)

        visit(graph)
        return found