# orchestrator.py • 8.84 kB
"""Orchestrator registry and management."""
import asyncio
import inspect
import logging
import os
import re
import shutil
import subprocess
from pathlib import Path
from typing import Any

from .config import OrchestratorConfig
class OrchestratorRegistry:
"""Registry for managing available orchestrators/CLIs."""
def __init__(self):
self.orchestrators: dict[str, OrchestratorConfig] = {}
self._active_sessions: dict[str, asyncio.subprocess.Process] = {}
def register(self, config: OrchestratorConfig) -> None:
"""Register an orchestrator."""
self.orchestrators[config.name] = config
def unregister(self, name: str) -> None:
"""Unregister an orchestrator."""
self.orchestrators.pop(name, None)
def get(self, name: str) -> OrchestratorConfig | None:
"""Get orchestrator configuration."""
return self.orchestrators.get(name)
def list_enabled(self) -> list[str]:
"""List all enabled orchestrators."""
return [name for name, config in self.orchestrators.items() if config.enabled]
@staticmethod
def _resolve_command(cmd: list[str]) -> list[str]:
"""
Resolve command to full path on Windows.
On Windows, asyncio.create_subprocess_exec() doesn't reliably search PATH,
so we need to resolve commands to their full paths using shutil.which().
Args:
cmd: Command list
Returns:
Resolved command (full path on Windows, original on Unix)
"""
if os.name != "nt" or not cmd:
# On Unix systems, PATH search works fine
return cmd
# On Windows, resolve the executable path
resolved = shutil.which(cmd[0])
if resolved:
return [resolved] + cmd[1:]
return cmd
async def execute(
self,
orchestrator_name: str,
task: str,
timeout: int | None = None,
progress_callback: Any = None,
) -> tuple[str, str, int]:
"""
Execute a task using specified orchestrator.
Args:
orchestrator_name: Name of orchestrator to use
task: Task description/query
timeout: Optional timeout in seconds
progress_callback: Optional async callback(line: str) for stdout streaming
Returns:
tuple: (stdout, stderr, return_code)
"""
config = self.get(orchestrator_name)
if not config:
raise ValueError(f"Orchestrator '{orchestrator_name}' not found")
if not config.enabled:
raise ValueError(f"Orchestrator '{orchestrator_name}' is disabled")
# Build command
if isinstance(config.command, list):
cmd = config.command + config.args + [task]
else:
cmd = [config.command] + config.args + [task]
# Resolve command path on Windows
resolved_cmd = self._resolve_command(cmd)
# Execute with timeout
timeout_seconds = timeout or config.timeout
process = None
# Build safe environment with allowlist approach
# Only include essential environment variables
allowed_env_vars = [
'PATH', 'HOME', 'USER', 'LANG', 'LC_ALL', 'TERM',
'PYTHONPATH', 'NODE_PATH', 'OPENROUTER_API_KEY',
'ANTHROPIC_API_KEY', 'OPENAI_API_KEY', 'GOOGLE_API_KEY',
'TMPDIR', 'TEMP', 'TMP', 'USERPROFILE', 'SYSTEMROOT',
]
safe_env = {}
for key in allowed_env_vars:
if key in os.environ:
safe_env[key] = os.environ[key]
# Add config-specified env vars with validation
for key, value in config.env.items():
# Validate env var name (alphanumeric and underscore only)
if not re.match(r'^[A-Z_][A-Z0-9_]*$', key):
import logging
logging.getLogger(__name__).warning(
f"Skipping invalid environment variable name: {key}"
)
continue
safe_env[key] = value
stdout_chunks = []
stderr_chunks = []
async def _read_stream(stream, is_stderr: bool):
while True:
line = await stream.readline()
if not line:
break
text = line.decode("utf-8", errors="replace")
if is_stderr:
stderr_chunks.append(text)
else:
stdout_chunks.append(text)
if on_output:
try:
await on_output(text, is_stderr)
except Exception:
pass # Ignore callback errors
try:
process = await asyncio.create_subprocess_exec(
*resolved_cmd,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
env=safe_env,
)
stdout_chunks = []
stderr_chunks = []
async def read_stream(stream, chunks, callback=None):
while True:
line = await stream.readline()
if not line:
break
decoded_line = line.decode("utf-8", errors="replace")
chunks.append(decoded_line)
if callback:
try:
if asyncio.iscoroutinefunction(callback):
await callback(decoded_line.strip())
else:
callback(decoded_line.strip())
except Exception:
pass # Ignore callback errors to prevent crashing execution
# Create tasks for reading stdout and stderr
stdout_task = asyncio.create_task(
read_stream(process.stdout, stdout_chunks, progress_callback)
)
stderr_task = asyncio.create_task(
read_stream(process.stderr, stderr_chunks)
)
# Wait for everything to finish or timeout
try:
# We wait for the process AND the stream readers
# This ensures we don't timeout if the process is done but streams are still being read
# and conversely, we DO timeout if streams are blocked even if process is done (unlikely but possible)
# or if process is hanging.
await asyncio.wait_for(
asyncio.gather(process.wait(), stdout_task, stderr_task),
timeout=timeout_seconds
)
except asyncio.TimeoutError:
# Timeout occurred - clean up everything
if process:
try:
process.kill()
except ProcessLookupError:
pass
# Cancel stream readers
stdout_task.cancel()
stderr_task.cancel()
# Wait for cancellation to complete
try:
await asyncio.gather(stdout_task, stderr_task, return_exceptions=True)
except Exception:
pass
raise TimeoutError(
f"Orchestrator '{orchestrator_name}' timed out after {timeout_seconds}s"
)
return (
"".join(stdout_chunks),
"".join(stderr_chunks),
process.returncode or 0,
)
except Exception as e:
if process and process.returncode is None:
try:
process.kill()
await process.wait()
except ProcessLookupError:
pass
if isinstance(e, (TimeoutError, RuntimeError)):
raise e
raise RuntimeError(
f"Orchestrator '{orchestrator_name}' failed: {str(e)}"
) from e
def validate_all(self) -> dict[str, bool]:
"""
Validate all registered orchestrators are available.
Returns:
dict: {orchestrator_name: is_available}
"""
results = {}
for name, config in self.orchestrators.items():
cmd = config.command if isinstance(config.command, str) else config.command[0]
try:
subprocess.run(
["which", cmd] if subprocess.os.name != "nt" else ["where", cmd],
capture_output=True,
check=True,
)
results[name] = True
except subprocess.CalledProcessError:
results[name] = False
return results