"""
Enhanced Trigger System for Pentest MCP Server.
Implements intelligent AI suspension/resumption based on output patterns.
"""
import asyncio
import logging
import re
import shlex
import time
import uuid
from typing import Dict, List, Any, Optional, Tuple

from .ssh_manager import SSHManager
from .tmux_manager import TmuxManager
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
class TriggerSystem:
    """Manages trigger-based monitoring for AI suspension with enhanced features.

    ``monitor_session_with_triggers`` polls a tmux pane and evaluates a list of
    trigger configurations against its output, returning as soon as one of
    them matches (or an error / the overall timeout occurs). Supported trigger
    types: ``prompt``, ``regex``, ``timeout``, ``file_exists`` and ``multi``.
    """

    # Comprehensive shell prompt patterns. They are searched against the last
    # non-blank line of the captured pane only (see _check_prompt_trigger).
    PROMPT_PATTERNS = [
        r'[\$#>]\s*$',  # Standard bash/sh prompts
        r'root@.*[#\$]\s*$',  # root@hostname#
        r'kali@.*[$]\s*$',  # kali@hostname$
        r'parrot@.*[$]\s*$',  # parrot@hostname$
        r'backbox@.*[$]\s*$',  # backbox@hostname$
        r'msf6?\s*>',  # Metasploit prompt
        r'\(meterpreter\)\s*>',  # Meterpreter prompt
        r'mysql>\s*$',  # MySQL prompt
        r'ftp>\s*$',  # FTP prompt
        r'>>>\s*$',  # Python prompt
        r'In \[\d+\]:\s*$',  # IPython prompt
        r'\[sudo\]\s*password',  # Sudo password prompt
    ]

    # Error detection patterns, searched case-insensitively in the full
    # captured output once a trigger has matched.
    ERROR_PATTERNS = [
        r'error:',
        r'failed',
        r'exception',
        r'traceback',
        r'command not found',
        r'permission denied',
        r'connection refused',
        r'no route to host',
        r'timeout',
        r'killed',
        r'segmentation fault'
    ]

    # Operators accepted by "multi" triggers (compared case-insensitively).
    MULTI_OPERATORS = ("AND", "OR")

    def __init__(self, ssh_manager: SSHManager, tmux_manager: TmuxManager):
        self.ssh = ssh_manager
        self.tmux = tmux_manager
        # Not read or written anywhere else in this class.
        self.active_monitors = {}
        # session_id -> OutputBuffer; kept across monitoring calls so "new"
        # output is relative to the previous poll of the same session.
        self.output_buffers = {}

    async def monitor_session_with_triggers(
        self,
        session_id: str,
        triggers: List[Dict[str, Any]],
        max_timeout: int = 300,
        poll_interval: float = 1.0
    ) -> Dict[str, Any]:
        """
        Enhanced monitoring with comprehensive trigger support.

        Args:
            session_id: Tmux session identifier
            triggers: List of trigger configurations (dicts with a ``type``
                key). The caller's list is not modified: when no ``timeout``
                trigger is present, an implicit
                ``{"type": "timeout", "timeout_seconds": max_timeout}`` is added
                to a private copy.
            max_timeout: Maximum wait time in seconds
            poll_interval: How often to check output (seconds)

        Returns:
            A dict whose ``status`` is ``"trigger_matched"`` (with ``trigger``,
            ``output``, ``execution_time`` and ``session_id``), ``"timeout"``
            (with ``message``) or ``"error"`` (with ``error``).
        """
        start_time = time.time()

        # Work on a private copy: appending the implicit timeout trigger to the
        # caller's list would make it grow on every call.
        triggers = list(triggers or [])

        # Initialize output buffer for this session
        if session_id not in self.output_buffers:
            self.output_buffers[session_id] = OutputBuffer()
        output_buffer = self.output_buffers[session_id]

        # Add implicit timeout trigger if not present. Non-dict entries are
        # skipped here and reported by validate_triggers below.
        has_timeout = any(
            isinstance(t, dict) and t.get("type") == "timeout" for t in triggers
        )
        if not has_timeout:
            triggers.append({"type": "timeout", "timeout_seconds": max_timeout})

        # Validate triggers before starting
        is_valid, error_msg = self.validate_triggers(triggers)
        if not is_valid:
            return {
                "status": "error",
                "error": f"Invalid trigger configuration: {error_msg}",
                "execution_time": 0
            }

        logger.info(
            "Starting trigger monitoring for session '%s' with %d triggers",
            session_id, len(triggers)
        )

        while True:
            elapsed_time = time.time() - start_time

            # Check overall timeout
            if elapsed_time >= max_timeout:
                return {
                    "status": "timeout",
                    "message": f"Maximum timeout of {max_timeout} seconds reached",
                    "execution_time": elapsed_time
                }

            try:
                # Capture current output using tmux manager
                capture_result = await self.tmux.capture_pane(session_id)
                if capture_result["status"] != "success":
                    return {
                        "status": "error",
                        "error": f"Failed to capture output: {capture_result.get('message', 'Unknown error')}",
                        "execution_time": elapsed_time
                    }

                current_output = capture_result["output"]
                new_output, full_output = output_buffer.get_new_output(current_output)

                # Check all triggers; the first match ends monitoring.
                for trigger in triggers:
                    trigger_result = await self._check_trigger(
                        trigger, new_output, full_output, start_time, session_id
                    )
                    if trigger_result["matched"]:
                        logger.info(
                            "Trigger matched: %s for session '%s'",
                            trigger_result['trigger_name'], session_id
                        )
                        # Attach error diagnostics if the output looks like a failure
                        error_info = self._detect_error_in_output(full_output)
                        if error_info["error_detected"]:
                            trigger_result["error_detected"] = error_info
                        return {
                            "status": "trigger_matched",
                            "trigger": trigger_result,
                            "output": full_output,
                            "execution_time": elapsed_time,
                            "session_id": session_id
                        }

                # Sleep before next check
                await asyncio.sleep(poll_interval)

            except Exception as e:
                logger.exception("Monitoring error for session '%s': %s", session_id, e)
                return {
                    "status": "error",
                    "error": f"Monitoring failed: {str(e)}",
                    "execution_time": time.time() - start_time
                }

    async def _check_trigger(
        self,
        trigger: Dict[str, Any],
        new_output: str,
        full_output: str,
        start_time: float,
        session_id: str
    ) -> Dict[str, Any]:
        """Dispatch *trigger* to the checker for its ``type``.

        Returns a result dict that always has a boolean ``matched`` key; unknown
        types never match.
        """
        trigger_type = trigger.get("type")
        if trigger_type == "prompt":
            return self._check_prompt_trigger(full_output)
        elif trigger_type == "regex":
            return self._check_regex_trigger(trigger, new_output, full_output)
        elif trigger_type == "timeout":
            return self._check_timeout_trigger(start_time, trigger)
        elif trigger_type == "file_exists":
            return await self._check_file_exists_trigger(trigger, session_id)
        elif trigger_type == "multi":
            return await self._check_multi_trigger(trigger, new_output, full_output, start_time, session_id)
        else:
            return {"matched": False, "trigger_name": "unknown"}

    def _check_prompt_trigger(self, full_output: str) -> Dict[str, Any]:
        """Check whether the output currently ends with a shell prompt.

        Only the last non-blank line is examined. A prompt followed by further
        output is not a prompt at rest, and scanning earlier lines (as a
        multi-line search over the tail did) made any recent output line ending
        in ``$``, ``#`` or ``>`` look like a prompt while the command was still
        running.
        """
        # rstrip drops trailing blank padding lines and spaces after the prompt.
        last_line = full_output.rstrip().rsplit('\n', 1)[-1]
        for pattern in self.PROMPT_PATTERNS:
            if re.search(pattern, last_line):
                return {
                    "matched": True,
                    "trigger_name": "prompt_detected",
                    "trigger_type": "prompt",
                    "matched_pattern": pattern
                }
        return {"matched": False}

    def _check_regex_trigger(self, trigger: Dict[str, Any], new_output: str, full_output: str) -> Dict[str, Any]:
        """Check if the trigger's regex matches (case-insensitive, multi-line).

        ``trigger["search_in"]`` selects ``"new"`` output (default) or the
        ``"full"`` capture.
        """
        pattern = trigger["pattern"]
        search_in = trigger.get("search_in", "new")  # "new" or "full"
        text = new_output if search_in == "new" else full_output
        match = re.search(pattern, text, re.IGNORECASE | re.MULTILINE)
        if match:
            return {
                "matched": True,
                "trigger_name": trigger.get("name", "regex_match"),
                "trigger_type": "regex",
                "matched_text": match.group(0),
                "match_position": match.start()
            }
        return {"matched": False}

    def _check_timeout_trigger(self, start_time: float, trigger: Dict[str, Any]) -> Dict[str, Any]:
        """Check whether ``timeout_seconds`` (default 300) have elapsed since *start_time*."""
        timeout_seconds = trigger.get("timeout_seconds", 300)
        elapsed = time.time() - start_time
        if elapsed >= timeout_seconds:
            return {
                "matched": True,
                "trigger_name": "timeout",
                "trigger_type": "timeout",
                "elapsed_seconds": elapsed,
                "timeout_seconds": timeout_seconds
            }
        return {"matched": False}

    async def _check_file_exists_trigger(self, trigger: Dict[str, Any], session_id: str) -> Dict[str, Any]:
        """Check whether ``trigger["path"]`` exists as a regular file remotely.

        A single ``test -f`` / ``stat`` one-liner is typed into the monitored
        tmux session and the pane is captured 0.5 s later. The result line
        carries a per-check random marker, so neither the echoed command line
        nor output left over from an earlier check can produce a match. (A
        fixed token such as "EXISTS" also appears in the echoed command itself
        and therefore always matched.)

        NOTE(review): because the check is typed into the same session, it only
        works while that session is at a shell prompt; with a foreground
        program running, the keystrokes go to that program instead.
        """
        path = trigger["path"]
        try:
            quoted_path = shlex.quote(path)
            marker = f"TRGFILE{uuid.uuid4().hex[:16]}"
            # Splitting the marker with '' keeps it non-contiguous in the echoed
            # command line; the shell joins the pieces into one word for echo.
            echoed_marker = f"{marker[:7]}''{marker[7:]}"
            cmd = (
                f"test -f {quoted_path} && echo {echoed_marker} "
                f"\"$(stat -c '%s' {quoted_path} 2>/dev/null || echo 0)\""
            )
            result = await self.tmux.execute_command(session_id, cmd)

            if result.get("status") == "sent":
                # Wait a moment for command to execute
                await asyncio.sleep(0.5)
                output_result = await self.tmux.capture_pane(session_id)
                # The size is read from the marker line itself, not from the
                # first numeric line anywhere in the pane.
                match = re.search(
                    re.escape(marker) + r'\s+(\d+)',
                    output_result.get("output", "")
                )
                if match:
                    return {
                        "matched": True,
                        "trigger_name": trigger.get("name", "file_exists"),
                        "trigger_type": "file_exists",
                        "file_path": path,
                        "file_size": int(match.group(1))
                    }
        except Exception as e:
            # Best effort: a failed check just means "not matched this poll".
            logger.warning("File existence check failed for %s: %s", path, e)
        return {"matched": False}

    async def _check_multi_trigger(
        self,
        trigger: Dict[str, Any],
        new_output: str,
        full_output: str,
        start_time: float,
        session_id: str
    ) -> Dict[str, Any]:
        """Check multiple conditions combined with AND/OR logic.

        Every condition is evaluated (no short-circuit) so ``conditions_met``
        reports each one. The operator is compared case-insensitively; an
        empty condition list never matches.
        """
        operator = str(trigger.get("operator", "AND")).upper()
        conditions = trigger["conditions"]

        results = []
        for condition in conditions:
            result = await self._check_trigger(condition, new_output, full_output, start_time, session_id)
            results.append(result["matched"])

        if not results:
            matched = False  # all([]) would otherwise report a vacuous match
        elif operator == "AND":
            matched = all(results)
        else:  # OR
            matched = any(results)

        return {
            "matched": matched,
            "trigger_name": trigger.get("name", "multi_condition"),
            "trigger_type": "multi",
            "conditions_met": results,
            "operator": operator
        }

    def _detect_error_in_output(self, output: str) -> Dict[str, Any]:
        """Report the first ERROR_PATTERNS entry found in *output*, with context and a hint."""
        for pattern in self.ERROR_PATTERNS:
            match = re.search(pattern, output, re.IGNORECASE)
            if match:
                context = self._extract_error_context(output, pattern)
                return {
                    "error_detected": True,
                    "error_pattern": pattern,
                    "context": context,
                    "suggestion": self._get_error_suggestion(pattern)
                }
        return {"error_detected": False}

    def _extract_error_context(self, output: str, pattern: str) -> str:
        """Return the first line matching *pattern* plus up to 3 lines either side."""
        lines = output.split('\n')
        for i, line in enumerate(lines):
            if re.search(pattern, line, re.IGNORECASE):
                start = max(0, i - 3)
                end = min(len(lines), i + 4)
                return '\n'.join(lines[start:end])
        return ""

    def _get_error_suggestion(self, pattern: str) -> str:
        """Get helpful suggestions for common errors."""
        suggestions = {
            "command not found": "Tool may not be installed. Try: apt install <tool>",
            "permission denied": "Need elevated privileges. Try: sudo <command>",
            "connection refused": "Port may be closed or firewalled",
            "no route to host": "Target unreachable. Check network/VPN",
            "error:": "Check command syntax and parameters",
            "failed": "Operation failed, check dependencies and permissions"
        }
        for error_pattern, suggestion in suggestions.items():
            if error_pattern in pattern.lower():
                return suggestion
        return "Check command and system status"

    def validate_triggers(self, triggers: List[Dict[str, Any]]) -> Tuple[bool, str]:
        """Validate trigger configuration before execution.

        Returns:
            ``(True, "")`` when every trigger (recursively, for ``multi``) is
            well-formed, otherwise ``(False, reason)``. Never raises for
            malformed input, since callers invoke it outside any try block.
        """
        if not triggers:
            return True, ""

        for trigger in triggers:
            if not isinstance(trigger, dict):
                return False, "Trigger must be a dictionary"
            if "type" not in trigger:
                return False, "Trigger missing 'type' field"

            trigger_type = trigger["type"]

            if trigger_type == "regex":
                if "pattern" not in trigger:
                    return False, "Regex trigger missing 'pattern' field"
                try:
                    re.compile(trigger["pattern"], re.IGNORECASE | re.MULTILINE)
                except (re.error, TypeError) as exc:
                    return False, f"Regex trigger has invalid 'pattern': {exc}"

            elif trigger_type == "timeout":
                if "timeout_seconds" not in trigger:
                    return False, "Timeout trigger missing 'timeout_seconds' field"
                timeout = trigger["timeout_seconds"]
                # bool is an int subclass but never a meaningful timeout.
                if isinstance(timeout, bool) or not isinstance(timeout, (int, float)):
                    return False, "Timeout must be a number"
                if timeout <= 0:
                    return False, "Timeout must be positive"

            elif trigger_type == "file_exists":
                if "path" not in trigger:
                    return False, "File exists trigger missing 'path' field"
                if not isinstance(trigger["path"], str) or not trigger["path"]:
                    return False, "File exists trigger 'path' must be a non-empty string"

            elif trigger_type == "multi":
                if "conditions" not in trigger:
                    return False, "Multi trigger missing 'conditions' field"
                if not isinstance(trigger["conditions"], list):
                    return False, "Multi trigger conditions must be a list"
                if not trigger["conditions"]:
                    return False, "Multi trigger conditions must not be empty"
                operator = trigger.get("operator", "AND")
                if not isinstance(operator, str) or operator.upper() not in self.MULTI_OPERATORS:
                    return False, f"Multi trigger operator must be one of: {', '.join(self.MULTI_OPERATORS)}"

                # Validate sub-conditions
                for condition in trigger["conditions"]:
                    is_valid, error = self.validate_triggers([condition])
                    if not is_valid:
                        return False, f"Invalid condition in multi trigger: {error}"

            elif trigger_type != "prompt":
                return False, f"Unknown trigger type: {trigger_type}"

        return True, ""
class OutputBuffer:
    """Efficient output buffer management for trigger monitoring.

    Remembers the previous pane capture and, given the next one, works out
    which lines are new. Captures are compared by content rather than by line
    count, so output that scrolls a fixed-height pane (same number of lines,
    different content) is still reported as new.
    """

    def __init__(self):
        # Number of '\n'-separated lines in the most recent capture. No longer
        # used for diffing; kept up to date for any external readers.
        self.last_line_count = 0
        # The most recent capture, verbatim.
        self.full_output = ""

    @staticmethod
    def _content_lines(text: str) -> List[str]:
        """Split *text* into lines, dropping trailing whitespace-only lines."""
        lines = text.split('\n')
        # Captures may be padded with blank lines below the cursor; they carry
        # no output and would otherwise defeat the overlap comparison.
        while lines and not lines[-1].strip():
            lines.pop()
        return lines

    @staticmethod
    def _unseen_lines(previous: List[str], current: List[str]) -> List[str]:
        """Return the lines of *current* that follow its overlap with *previous*.

        Finds the longest run of lines ending *previous* that also starts
        *current*, which covers both plain appends and scrolls. The last line
        of *previous* may since have been extended in place (e.g. more text on
        the prompt line); that line is then reported again in full. With no
        overlap at all, every line of *current* is new.
        """
        if not previous:
            return current
        last = previous[-1]
        for k in range(min(len(previous), len(current)), 0, -1):
            # previous[-k:-1] must equal current[:k-1]; previous[-1] is then
            # compared against current[k-1] separately.
            if previous[len(previous) - k:-1] != current[:k - 1]:
                continue
            boundary = current[k - 1]
            if boundary == last:
                return current[k:]
            if boundary.startswith(last):
                return current[k - 1:]
        return current

    def get_new_output(self, current_output: str) -> Tuple[str, str]:
        """Get only the output that appeared since the last check.

        Args:
            current_output: The latest full capture of the pane.

        Returns:
            ``(new_output, full_output)``: the newly appeared lines joined by
            newlines (empty when nothing changed), and *current_output*
            unchanged.
        """
        new_lines = self._unseen_lines(
            self._content_lines(self.full_output),
            self._content_lines(current_output),
        )
        self.last_line_count = len(current_output.split('\n'))
        self.full_output = current_output
        return '\n'.join(new_lines), self.full_output

    def reset(self):
        """Forget previous captures so the next one is reported entirely as new."""
        self.last_line_count = 0
        self.full_output = ""