# src/fctr_okta_mcp/security/subprocess_executor.py
"""
Secure Subprocess Code Executor for Okta MCP Server
Executes generated Python code in an isolated subprocess with:
- Timeout controls
- Real-time output streaming
- Memory limits (where supported)
- Proper cleanup of temp files
Adapted from: https://github.com/fctr-id/okta-ai-agent/blob/main/src/core/orchestration/modern_execution_manager.py
"""
import asyncio
import json
import os
import sys
import time
from pathlib import Path
from typing import Any, Awaitable, Callable, Dict, Optional, Union
from dataclasses import dataclass
from fctr_okta_mcp.security.code_validator import validate_generated_code
from fctr_okta_mcp.utils.logger import get_logger

# Configurable timeout (default 5 minutes)
DEFAULT_TIMEOUT_SECONDS = int(os.environ.get("OKTA_MCP_EXECUTION_TIMEOUT_SECONDS", "300"))
logger = get_logger(__name__)
# Public API
__all__ = [
"ExecutionResult",
"SubprocessExecutor",
"execute_code_subprocess",
]
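
# Usage sketch for the public API (illustrative only, not executed here):
# `generated_code` must define `async def execute_query(client)`; the
# `report_progress` name below is a placeholder for the caller's own callback.
#
#   result = await execute_code_subprocess(
#       code=generated_code,
#       is_test=True,
#       progress_callback=report_progress,
#   )
#   if result.success:
#       ...  # result.results holds the JSON-decoded data from the subprocess
#   else:
#       logger.error(result.error)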
# Local temp directory for scripts (Docker-friendly)
# Use a dedicated subdirectory under the project root (four levels up from this
# file) rather than the system temp directory, to avoid conflicts
TEMP_SCRIPT_DIR = Path(__file__).parent.parent.parent.parent / "temp_scripts"
def _ensure_temp_dir() -> Path:
"""Ensure temp directory exists with proper permissions."""
TEMP_SCRIPT_DIR.mkdir(parents=True, exist_ok=True)
return TEMP_SCRIPT_DIR
@dataclass
class ExecutionResult:
"""Result of code execution"""
success: bool
results: Any
error: Optional[str]
execution_time_ms: float
stdout: str
stderr: str
# Template for the subprocess script
# Results are written to a file instead of stdout to handle large datasets (50k+ records)
EXECUTION_TEMPLATE = '''
# Auto-generated execution script - DO NOT EDIT
import asyncio
import json
import sys
import os
from datetime import datetime, timezone, timedelta
from pathlib import Path
# Force unbuffered output for real-time streaming (critical on Windows)
sys.stdout.reconfigure(line_buffering=True)
sys.stderr.reconfigure(line_buffering=True)
# Add parent directory to path for imports
sys.path.insert(0, "{parent_path}")
# Configure logging to write to the same log files as the main server
from fctr_okta_mcp.utils.logger import setup_logging, get_logger
setup_logging()
logger = get_logger("fctr_okta_mcp.subprocess")
# Result file path passed from parent process
RESULT_FILE = Path("{result_file_path}")
async def main():
"""Main execution wrapper"""
try:
# Import the Okta client
from fctr_okta_mcp.client.base_okta_api_client import OktaAPIClient
print("[SUBPROCESS] Initializing...", flush=True)
# Create client (no MCP context in subprocess, but has env vars)
client = OktaAPIClient(ctx=None, is_test={is_test})
print("[SUBPROCESS] Executing query...", flush=True)
# User's generated code
{user_code}
# Execute the query
results = await execute_query(client)
result_count = len(results) if isinstance(results, list) else 1
print(f"[SUBPROCESS] Complete: {{result_count}} results", flush=True)
# Write results to file (handles large datasets - 50k+ records)
output = {{
"success": True,
"results": results,
"result_count": result_count
}}
with open(RESULT_FILE, 'w', encoding='utf-8') as f:
json.dump(output, f, default=str)
print("__EXECUTION_SUCCESS__", flush=True)
except Exception as e:
import traceback
print(f"[SUBPROCESS] ERROR: {{type(e).__name__}}: {{str(e)}}", flush=True)
logger.error(f"Execution failed: {{type(e).__name__}}: {{str(e)}}\\n{{traceback.format_exc()}}")
# Write error to file
error_output = {{
"success": False,
"error": str(e),
"error_type": type(e).__name__
}}
with open(RESULT_FILE, 'w', encoding='utf-8') as f:
json.dump(error_output, f)
print("__EXECUTION_FAILED__", flush=True)
sys.exit(1)
if __name__ == "__main__":
asyncio.run(main())
'''
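
# For reference, the minimal shape of generated code the template above expects.
# The client call is a placeholder; the real method names come from
# OktaAPIClient and are not asserted here:
#
#   async def execute_query(client):
#       response = await client.make_request("/api/v1/users")  # hypothetical call
#       return response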
class SubprocessExecutor:
"""
Executes generated Python code in an isolated subprocess.
Security features:
- Code runs in separate process (can't affect server)
- Timeout enforcement
- Output streaming for real-time progress
- Temp file cleanup
"""
def __init__(
self,
timeout_seconds: int = DEFAULT_TIMEOUT_SECONDS,
        # Sync or async callback; _notify_progress dispatches on the callback type
        progress_callback: Optional[Callable[[str], Union[None, Awaitable[None]]]] = None
):
self.timeout_seconds = timeout_seconds
self.progress_callback = progress_callback
        # Directory that must be on sys.path for `import fctr_okta_mcp` to resolve:
        # this file lives at src/fctr_okta_mcp/security/, so three levels up is the
        # src/ directory that contains the package
        self.parent_path = str(Path(__file__).parent.parent.parent)
async def execute(
self,
code: str,
is_test: bool = True,
env_vars: Optional[Dict[str, str]] = None
) -> ExecutionResult:
"""
Execute code in a subprocess with streaming output.
Args:
code: Python code containing async def execute_query(client)
is_test: Whether to run in test mode (limits results)
env_vars: Additional environment variables to pass
Returns:
ExecutionResult with success status, results, and output
"""
start_time = time.time()
# Step 1: Security validation (AST-based)
logger.debug("Running security validation...")
security_result = validate_generated_code(code)
if not security_result.is_valid:
return ExecutionResult(
success=False,
results=None,
error=f"Security validation failed: {'; '.join(security_result.violations)}",
execution_time_ms=0,
stdout="",
stderr=""
)
logger.debug("Security validation passed")
# Step 2: Create temp files for script and results
temp_file_path = None
result_file_path = None
try:
# Ensure local temp directory exists
temp_dir = _ensure_temp_dir()
# Create unique file names
timestamp = int(time.time() * 1000)
temp_file_path = temp_dir / f"exec_{timestamp}.py"
result_file_path = temp_dir / f"result_{timestamp}.json"
            # Indent user code by 8 spaces so it nests inside main()'s try block
            # in the generated script
            indented_code = "\n".join(f"        {line}" for line in code.split("\n"))
script_content = EXECUTION_TEMPLATE.format(
parent_path=self.parent_path.replace("\\", "\\\\"),
is_test=str(is_test),
user_code=indented_code,
result_file_path=str(result_file_path).replace("\\", "\\\\")
)
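            # The backslash doubling matters on Windows: str.format() splices the
            # path into Python source text, so e.g. C:\repo\temp_scripts\result_<ts>.json
            # must appear as C:\\repo\\temp_scripts\\result_<ts>.json in the script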
# Write script to temp file
temp_file_path.write_text(script_content, encoding='utf-8')
logger.debug(f"Created temp script: {temp_file_path}")
logger.debug(f"Result file will be: {result_file_path}")
# Step 3: Build environment
env = os.environ.copy()
if env_vars:
env.update(env_vars)
            # Step 4: Execute in subprocess
            logger.debug(f"Starting subprocess with Python: {sys.executable}")
            logger.debug(f"Script path: {temp_file_path}")
process = await asyncio.create_subprocess_exec(
sys.executable,
"-u", # Unbuffered stdout/stderr - critical for streaming progress
str(temp_file_path),
stdin=asyncio.subprocess.DEVNULL, # CRITICAL: Don't inherit stdin from MCP server
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
env=env,
cwd=str(temp_dir) # Set working directory
)
logger.debug(f"Subprocess started with PID: {process.pid}")
stdout_lines = []
stderr_lines = []
execution_success = False
# Stream stdout for progress updates only (results go to file)
async def read_stdout():
nonlocal execution_success
logger.debug("stdout reader started")
while True:
try:
line = await process.stdout.readline()
if not line:
logger.debug("stdout EOF reached")
break
decoded = line.decode('utf-8', errors='replace').rstrip()
# Check for completion markers
if decoded == "__EXECUTION_SUCCESS__":
execution_success = True
logger.debug("Subprocess completed successfully")
continue
elif decoded == "__EXECUTION_FAILED__":
execution_success = False
logger.warning("Subprocess execution failed")
continue
stdout_lines.append(decoded)
                        # Echo subprocess progress markers into our debug log
                        if decoded.startswith("[SUBPROCESS]"):
                            logger.debug(decoded)
# Send progress update if callback provided
if self.progress_callback and decoded:
await self._notify_progress(decoded)
except Exception as e:
logger.error(f"Error reading stdout: {e}")
break
# Stream stderr using explicit readline pattern
async def read_stderr():
logger.debug("stderr reader started")
rate_limit_count = 0
last_rate_limit_log_time = 0
while True:
try:
line = await process.stderr.readline()
if not line:
logger.debug("stderr EOF reached")
break
decoded = line.decode('utf-8', errors='replace').rstrip()
if decoded:
stderr_lines.append(decoded)
# Check if this is a rate limit message
                            is_rate_limit = "rate limit" in decoded.lower()
if is_rate_limit:
rate_limit_count += 1
# Only log rate limit messages every 5 seconds to reduce spam
current_time = time.time()
if current_time - last_rate_limit_log_time >= 5:
last_rate_limit_log_time = current_time
logger.warning(f"[subprocess] Rate limit hit ({rate_limit_count} total)")
# Notify progress to keep MCP connection alive
if self.progress_callback:
await self._notify_progress(f"[rate-limited] Waiting for Okta API ({rate_limit_count} retries)...")
else:
# Non-rate-limit stderr messages
# Only log as warning if it's actually a warning/error (not INFO)
if "| WARNING" in decoded or "| ERROR" in decoded:
logger.warning(f"[subprocess stderr] {decoded}")
if self.progress_callback:
await self._notify_progress(f"[api] {decoded[:100]}")
else:
logger.debug(f"[subprocess stderr] {decoded}")
except Exception as e:
logger.error(f"Error reading stderr: {e}")
break
            # Heartbeat task to keep the MCP connection alive during long operations
            # (MCP clients typically time out after ~60 seconds without progress)
heartbeat_active = True
heartbeat_count = 0
async def heartbeat():
"""Send periodic progress updates to keep MCP connection alive."""
nonlocal heartbeat_count
while heartbeat_active:
await asyncio.sleep(30) # Send heartbeat every 30 seconds
if heartbeat_active and self.progress_callback:
heartbeat_count += 1
await self._notify_progress(f"[heartbeat] Still processing... ({heartbeat_count * 30}s elapsed)")
logger.debug(f"Heartbeat #{heartbeat_count} sent to keep MCP connection alive")
# Run with timeout
logger.debug(f"Waiting for subprocess (timeout: {self.timeout_seconds}s)...")
stdout_task = asyncio.create_task(read_stdout())
stderr_task = asyncio.create_task(read_stderr())
wait_task = asyncio.create_task(process.wait())
heartbeat_task = asyncio.create_task(heartbeat())
            try:
                # Wait for process to complete (heartbeat runs in background)
                await asyncio.wait_for(
                    asyncio.gather(stdout_task, stderr_task, wait_task),
                    timeout=self.timeout_seconds
                )
            except asyncio.TimeoutError:
                process.kill()
                await process.wait()
                return ExecutionResult(
                    success=False,
                    results=None,
                    error=f"Execution timed out after {self.timeout_seconds} seconds",
                    execution_time_ms=(time.time() - start_time) * 1000,
                    stdout="\n".join(stdout_lines),
                    stderr="\n".join(stderr_lines)
                )
            finally:
                # Stop the heartbeat however the wait ends (success, timeout, or an
                # unexpected exception) so the background task never leaks
                heartbeat_active = False
                heartbeat_task.cancel()
                try:
                    await heartbeat_task
                except asyncio.CancelledError:
                    pass
execution_time_ms = (time.time() - start_time) * 1000
# Step 5: Read results from file (handles large datasets - 50k+ records)
result_json = None
if result_file_path and result_file_path.exists():
try:
logger.debug(f"Reading results from file: {result_file_path}")
with open(result_file_path, 'r', encoding='utf-8') as f:
result_json = json.load(f)
logger.debug(f"Successfully loaded results from file")
except json.JSONDecodeError as e:
logger.error(f"Failed to parse result JSON from file: {e}")
except Exception as e:
logger.error(f"Failed to read result file: {e}")
if result_json:
if result_json.get("success"):
result_count = result_json.get("result_count", 0)
logger.debug(f"Execution successful: {result_count} results")
return ExecutionResult(
success=True,
results=result_json.get("results"),
error=None,
execution_time_ms=execution_time_ms,
stdout="\n".join(stdout_lines),
stderr="\n".join(stderr_lines)
)
else:
return ExecutionResult(
success=False,
results=None,
error=result_json.get("error", "Unknown error"),
execution_time_ms=execution_time_ms,
stdout="\n".join(stdout_lines),
stderr="\n".join(stderr_lines)
)
else:
# No result file found - check exit code
if process.returncode == 0:
return ExecutionResult(
success=False,
results=None,
error="No result file created by execution",
execution_time_ms=execution_time_ms,
stdout="\n".join(stdout_lines),
stderr="\n".join(stderr_lines)
)
else:
return ExecutionResult(
success=False,
results=None,
error=f"Process exited with code {process.returncode}",
execution_time_ms=execution_time_ms,
stdout="\n".join(stdout_lines),
stderr="\n".join(stderr_lines)
)
except Exception as e:
logger.error(f"Subprocess execution failed: {e}", exc_info=True)
return ExecutionResult(
success=False,
results=None,
error=f"Execution failed: {str(e)}",
execution_time_ms=(time.time() - start_time) * 1000,
stdout="",
stderr=""
)
finally:
# Cleanup temp files
if temp_file_path and temp_file_path.exists():
try:
temp_file_path.unlink()
logger.debug(f"Cleaned up temp script: {temp_file_path}")
except Exception as e:
logger.warning(f"Failed to cleanup temp script: {e}")
if result_file_path and result_file_path.exists():
try:
result_file_path.unlink()
logger.debug(f"Cleaned up result file: {result_file_path}")
except Exception as e:
logger.warning(f"Failed to cleanup result file: {e}")
async def _notify_progress(self, message: str):
"""Send progress notification via callback"""
if self.progress_callback:
try:
if asyncio.iscoroutinefunction(self.progress_callback):
await self.progress_callback(message)
else:
self.progress_callback(message)
except Exception as e:
logger.warning(f"Progress callback failed: {e}")
# Convenience function
async def execute_code_subprocess(
code: str,
is_test: bool = True,
timeout_seconds: int = DEFAULT_TIMEOUT_SECONDS,
    progress_callback: Optional[Callable[[str], Union[None, Awaitable[None]]]] = None,
env_vars: Optional[Dict[str, str]] = None
) -> ExecutionResult:
"""
Execute generated Python code in an isolated subprocess.
Args:
code: Python code containing async def execute_query(client)
is_test: Whether to run in test mode (limits results to 3)
timeout_seconds: Maximum execution time
        progress_callback: Sync or async callback for progress updates
env_vars: Additional environment variables
Returns:
ExecutionResult with success status and results
"""
executor = SubprocessExecutor(
timeout_seconds=timeout_seconds,
progress_callback=progress_callback
)
return await executor.execute(code, is_test, env_vars)
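
# Manual smoke-test sketch (assumes Okta credentials in the environment and that
# the trivial body below passes the security validator; adjust as needed):
#
#   if __name__ == "__main__":
#       sample = (
#           "async def execute_query(client):\n"
#           "    return [{'status': 'ok'}]\n"
#       )
#       print(asyncio.run(execute_code_subprocess(sample)))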