# autonomous_engine.py • 18.2 kB
"""
Autonomous Penetration Testing Decision Engine
Makes intelligent decisions about next steps based on findings
"""
import asyncio
import ipaddress
import json
import logging
from datetime import datetime
from typing import Dict, List, Any, Optional
from urllib.parse import urlsplit

from lm_studio_client import LMStudioClient
import tools
# Module-level logger named after this module; the engine below uses it for
# progress and error reporting.
logger = logging.getLogger(__name__)
class AutonomousDecisionEngine:
    """AI-powered autonomous penetration testing engine.

    The engine drives a *session* dict through up to four phases
    (reconnaissance, vulnerability identification, exploitation and
    post-exploitation).  The session is mutated in place: tool output is
    appended to ``session["findings"]``, per-phase records are appended to
    ``session["timeline"]`` and discovered subdomains are added to
    ``session["scope"]``.
    """

    # Depth used when the session does not specify one.
    _DEFAULT_DEPTH = "vulnerability_scan"

    # Objectives handed to the planner for each supported depth.
    _OBJECTIVES_BY_DEPTH: Dict[str, List[str]] = {
        "reconnaissance": ["Gather information about target"],
        "vulnerability_scan": ["Identify services", "Find vulnerabilities"],
        "exploitation": [
            "Identify services",
            "Find vulnerabilities",
            "Attempt safe exploitation",
        ],
        "post_exploitation": ["Full penetration test", "Demonstrate impact"],
    }

    # Tool names under which Nmap output may have been recorded.  The
    # reconnaissance phase records "nmap_scan"; "nmap" is accepted as well so
    # findings stored under the name the old filter expected still match.
    _NMAP_TOOL_NAMES = ("nmap_scan", "nmap")

    # Severity keywords in priority order; the first one found decides the
    # base risk score.
    _SEVERITY_SCORES = (("critical", 10), ("high", 7), ("medium", 5), ("low", 2))

    _MAX_SUBDOMAINS = 10        # subdomains added to scope per DNS enumeration
    _MAX_AI_SERVICES = 10       # services sent to the LLM for analysis
    _MAX_SCANNED_SERVICES = 5   # services scanned per Nmap finding
    _MAX_EXPLOIT_ATTEMPTS = 3   # vulnerabilities considered for exploitation
    _MAX_EXPLOIT_RESULTS = 5    # searchsploit hits kept per service

    def __init__(self, lm_studio_client: "LMStudioClient"):
        """Store the LLM client and start with an empty decision history."""
        self.lm_client = lm_studio_client
        self.decision_history: List[Dict[str, Any]] = []

    # ------------------------------------------------------------------
    # Planning and top-level orchestration
    # ------------------------------------------------------------------

    async def create_initial_plan(self, session: Dict[str, Any]) -> Dict[str, Any]:
        """Create the initial penetration testing plan.

        Args:
            session: Session dict; ``target`` is required, ``scope``, ``roe``
                and ``depth`` are optional.

        Returns:
            Whatever ``lm_client.create_attack_plan`` returns.
        """
        target = session["target"]
        depth = session.get("depth", self._DEFAULT_DEPTH)
        # Copy so the planner cannot mutate the class-level table; an unknown
        # depth yields no objectives.
        objectives = list(self._OBJECTIVES_BY_DEPTH.get(depth, []))
        return await self.lm_client.create_attack_plan(
            target=target,
            scope=session.get("scope", [target]),
            objectives=objectives,
            constraints=session.get("roe", {}),
        )

    async def run_pentest(self, session: Dict[str, Any]):
        """Execute the autonomous penetration test for *session*.

        Reconnaissance always runs; later phases run only when the session
        depth includes them.  A missing ``depth`` falls back to the same
        default used by ``create_initial_plan``.  Any exception raised by a
        phase is logged and recorded on the session (``status`` = "error",
        ``error`` = message) instead of propagating.
        """
        logger.info("Starting autonomous pentest for session %s", session["id"])
        try:
            depth = session.get("depth", self._DEFAULT_DEPTH)

            # Phase 1: Reconnaissance
            await self._run_reconnaissance_phase(session)

            # Phase 2: Vulnerability Identification
            if depth in ("vulnerability_scan", "exploitation", "post_exploitation"):
                await self._run_vulnerability_phase(session)

            # Phase 3: Exploitation (if authorized)
            if depth in ("exploitation", "post_exploitation"):
                await self._run_exploitation_phase(session)

            # Phase 4: Post-exploitation (if authorized)
            if depth == "post_exploitation":
                await self._run_post_exploitation_phase(session)

            session["status"] = "completed"
            session["completed_at"] = datetime.now().isoformat()
        except Exception as e:
            # Top-level boundary: record the failure on the session rather
            # than letting it escape.
            logger.error("Error in autonomous pentest: %s", e, exc_info=True)
            session["status"] = "error"
            session["error"] = str(e)

    # ------------------------------------------------------------------
    # Phases
    # ------------------------------------------------------------------

    async def _run_reconnaissance_phase(self, session: Dict[str, Any]):
        """Phase 1: information gathering.

        Domains get DNS enumeration followed by an Nmap scan, IPs get an Nmap
        scan, URLs get Nikto and Gobuster.  The phase ends by asking the LLM
        for a next-step suggestion, which is stored on the timeline entry.
        """
        target = session["target"]
        logger.info("Running reconnaissance phase on %s", target)
        entry = {
            "phase": "reconnaissance",
            "started_at": datetime.now().isoformat(),
        }
        session.setdefault("timeline", []).append(entry)

        # Determine target type (IP, domain, URL)
        target_type = self._determine_target_type(target)

        if target_type == "domain":
            # DNS enumeration
            logger.info("Running DNS enumeration")
            dns_result = await tools.dns_enum(target)
            self._record_finding(session, "dns_enum", dns_result)
            subdomains = dns_result.get("subdomains")
            if subdomains:
                # Add the top subdomains to scope, skipping ones already there.
                scope = session.setdefault("scope", [target])
                for subdomain in subdomains[:self._MAX_SUBDOMAINS]:
                    if subdomain not in scope:
                        scope.append(subdomain)

        # Deliberately a separate ``if``: domains must also be port-scanned.
        # Chaining this as ``elif`` made the "domain" case unreachable.
        if target_type in ("ip", "domain"):
            # Network reconnaissance
            logger.info("Running Nmap scan")
            nmap_result = await tools.nmap_scan(target, scan_type="full")
            self._record_finding(session, "nmap_scan", nmap_result)
            # Analyze services and suggest next steps
            await self._analyze_services(session, nmap_result)
        elif target_type == "url":
            # Web application reconnaissance
            logger.info("Running web reconnaissance")
            nikto_result = await tools.nikto_scan(target)
            self._record_finding(session, "nikto_scan", nikto_result)
            # Directory enumeration
            gobuster_result = await tools.gobuster_scan(
                target,
                wordlist="/usr/share/wordlists/dirb/common.txt",
                extensions=["php", "asp", "aspx", "jsp", "html"],
            )
            self._record_finding(session, "gobuster_scan", gobuster_result)

        # Get AI suggestions for next phase
        suggestion = await self.get_suggestions(
            session, "Reconnaissance complete. What should we do next?"
        )
        entry["completed_at"] = datetime.now().isoformat()
        entry["suggestion"] = suggestion

    async def _run_vulnerability_phase(self, session: Dict[str, Any]):
        """Phase 2: vulnerability identification.

        URL targets get Nuclei, optional WPScan, sqlmap and (for HTTPS) an
        SSL scan.  Other targets get service-specific scans driven by the
        services found in the reconnaissance Nmap findings.
        """
        logger.info("Running vulnerability identification phase")
        entry = {
            "phase": "vulnerability_identification",
            "started_at": datetime.now().isoformat(),
        }
        session.setdefault("timeline", []).append(entry)

        phase = "vulnerability_identification"
        target = session["target"]
        target_type = self._determine_target_type(target)

        # Get findings from reconnaissance
        recon_findings = [
            f for f in session.get("findings", [])
            if f.get("phase") == "reconnaissance"
        ]

        if target_type == "url":
            # Web application vulnerability scanning
            logger.info("Running web vulnerability scans")
            nuclei_result = await tools.nuclei_scan(
                target,
                severity=["critical", "high", "medium"],
            )
            self._record_finding(session, "nuclei_scan", nuclei_result, phase)

            # Check if it's WordPress
            if any("wordpress" in str(f).lower() for f in recon_findings):
                logger.info("WordPress detected, running WPScan")
                wpscan_result = await tools.wpscan(target, enumerate="vp")
                self._record_finding(session, "wpscan", wpscan_result, phase)

            # SQL injection testing on discovered forms/parameters
            sqlmap_result = await tools.sqlmap_scan(target, level=2, risk=2)
            self._record_finding(session, "sqlmap_scan", sqlmap_result, phase)

            # SSL/TLS analysis if HTTPS; honour an explicit port in the URL.
            if target.strip().lower().startswith("https://"):
                hostname, port = self._https_endpoint(target.strip())
                ssl_result = await tools.ssl_scan(hostname, port)
                self._record_finding(session, "ssl_scan", ssl_result, phase)
        else:
            # Network vulnerability scanning
            logger.info("Running network vulnerability scans")
            # Service-specific vulnerability checks based on open ports
            for finding in recon_findings:
                if finding.get("tool") not in self._NMAP_TOOL_NAMES:
                    continue
                services = self._extract_services(finding.get("result"))
                for service in services[:self._MAX_SCANNED_SERVICES]:
                    await self._scan_service_vulnerabilities(session, target, service)

        # Analyze all vulnerabilities and prioritize
        await self._analyze_vulnerabilities(session)
        entry["completed_at"] = datetime.now().isoformat()

    async def _run_exploitation_phase(self, session: Dict[str, Any]):
        """Phase 3: safe exploitation attempts.

        Does nothing unless the rules of engagement set
        ``allow_exploitation``.  Only the highest-scoring exploitable
        findings from the vulnerability phase are attempted.
        """
        logger.info("Running exploitation phase")

        # Check rules of engagement
        if not (session.get("roe") or {}).get("allow_exploitation", False):
            logger.warning("Exploitation not authorized in rules of engagement")
            return

        entry = {
            "phase": "exploitation",
            "started_at": datetime.now().isoformat(),
        }
        session.setdefault("timeline", []).append(entry)

        # Get exploitable vulnerabilities, highest risk first
        candidates = [
            f for f in session.get("findings", [])
            if f.get("phase") == "vulnerability_identification"
            and self._is_exploitable(f)
        ]
        ranked = sorted(candidates, key=self._get_risk_score, reverse=True)

        # Attempt exploitation on top vulnerabilities
        for vuln in ranked[:self._MAX_EXPLOIT_ATTEMPTS]:
            await self._attempt_exploitation(session, vuln)

        entry["completed_at"] = datetime.now().isoformat()

    async def _run_post_exploitation_phase(self, session: Dict[str, Any]):
        """Phase 4: post-exploitation activities (currently a placeholder).

        Does nothing unless the rules of engagement set
        ``allow_post_exploitation``.
        """
        logger.info("Running post-exploitation phase")

        if not (session.get("roe") or {}).get("allow_post_exploitation", False):
            logger.warning("Post-exploitation not authorized")
            return

        entry = {
            "phase": "post_exploitation",
            "started_at": datetime.now().isoformat(),
        }
        session.setdefault("timeline", []).append(entry)

        # Check if we have any successful exploits.  Note that
        # _attempt_exploitation only records a "status" field, so a "success"
        # marker has to come from some other producer of findings.
        successful_exploits = []
        for finding in session.get("findings", []):
            if finding.get("phase") != "exploitation":
                continue
            result = finding.get("result")
            if isinstance(result, dict) and result.get("success"):
                successful_exploits.append(finding)

        if not successful_exploits:
            logger.info("No successful exploits to follow up on")
        # Post-exploitation activities would go here.
        # This is a placeholder for actual post-exploitation modules.

        # Always close the timeline entry, including when there was nothing
        # to follow up on.
        entry["completed_at"] = datetime.now().isoformat()

    # ------------------------------------------------------------------
    # Analysis helpers
    # ------------------------------------------------------------------

    async def _analyze_services(self, session: Dict[str, Any], nmap_result: Dict[str, Any]):
        """Ask the LLM for likely vulnerabilities of each discovered service.

        Non-empty answers are recorded as reconnaissance findings under the
        tool name "vulnerability_identification".
        """
        services = self._extract_services(nmap_result)

        # Use AI to identify vulnerabilities
        for service in services[:self._MAX_AI_SERVICES]:
            vulnerabilities = await self.lm_client.identify_vulnerabilities(service)
            if vulnerabilities:
                self._record_finding(
                    session,
                    "vulnerability_identification",
                    {
                        "service": service,
                        "vulnerabilities": vulnerabilities,
                    },
                    "reconnaissance",
                )

    async def _scan_service_vulnerabilities(
        self,
        session: Dict[str, Any],
        target: str,
        service: Dict[str, Any]
    ):
        """Scan a specific service for vulnerabilities.

        Runs a service-specific tool for SMB, SNMP and HTTP(S), then searches
        for known exploits matching the service name and version.
        """
        phase = "vulnerability_identification"
        service_name = str(service.get("service") or "").lower()
        port = service.get("port")
        logger.info("Scanning service %s on port %s", service_name, port)

        # Service-specific scanning
        if service_name in ("smb", "microsoft-ds"):
            enum_result = await tools.enum4linux(target)
            self._record_finding(session, "enum4linux", enum_result, phase)
        elif service_name == "snmp":
            snmp_result = await tools.snmp_check(target)
            self._record_finding(session, "snmp_check", snmp_result, phase)
        elif service_name in ("http", "https"):
            # IPv6 literals must be bracketed inside a URL, and the port is
            # omitted when the scan did not report one.
            host = f"[{target}]" if ":" in target else target
            url = f"{service_name}://{host}"
            if port:
                url += f":{port}"
            nikto_result = await tools.nikto_scan(url)
            self._record_finding(session, "nikto_scan", nikto_result, phase)

        # Search for known exploits; skip when there is nothing to search for.
        search_query = f"{service_name} {service.get('version') or ''}".strip()
        if not search_query:
            return
        exploits = await tools.searchsploit(search_query)
        if exploits.get("exploits"):
            self._record_finding(
                session,
                "exploit_search",
                {
                    "service": service,
                    "exploits": exploits["exploits"][:self._MAX_EXPLOIT_RESULTS],
                },
                phase,
            )

    async def _analyze_vulnerabilities(self, session: Dict[str, Any]):
        """Have the LLM analyze all vulnerability findings and record the result."""
        vuln_findings = [
            f for f in session.get("findings", [])
            if f.get("phase") == "vulnerability_identification"
        ]
        if not vuln_findings:
            return

        # Use AI to analyze and prioritize
        analysis = await self.lm_client.analyze_scan_results(
            "vulnerability_assessment",
            {
                "target": session["target"],
                "findings": vuln_findings,
            },
        )
        self._record_finding(
            session,
            "vulnerability_analysis",
            analysis,
            "vulnerability_identification",
        )

    async def _attempt_exploitation(self, session: Dict[str, Any], vulnerability: Dict[str, Any]):
        """Ask the LLM for an exploitation approach and record it.

        No exploit is executed here; the recorded finding carries the
        suggested approach with status "attempted".
        """
        logger.info("Attempting exploitation of %s", vulnerability.get("tool"))

        # default=str keeps prompt building from failing on values that are
        # not JSON-native; the text is only read by the LLM.
        vulnerability_json = json.dumps(vulnerability, default=str)

        # Get AI recommendation on exploitation approach
        messages = [
            {
                "role": "system",
                "content": "You are a penetration testing expert. Suggest a safe exploitation approach.",
            },
            {
                "role": "user",
                "content": f"Vulnerability: {vulnerability_json}\n\nSuggest exploitation approach.",
            },
        ]
        suggestion = await self.lm_client.chat_completion(messages)

        self._record_finding(
            session,
            "exploitation_attempt",
            {
                "vulnerability": vulnerability,
                "approach": suggestion,
                "status": "attempted",
            },
            "exploitation",
        )

    # ------------------------------------------------------------------
    # Pure helpers
    # ------------------------------------------------------------------

    def _record_finding(
        self,
        session: Dict[str, Any],
        tool: str,
        result: Dict[str, Any],
        phase: Optional[str] = None
    ):
        """Append a finding to ``session["findings"]``.

        Args:
            session: Session dict to mutate.
            tool: Name of the tool that produced *result*.
            result: Tool output.
            phase: Phase label; defaults to "reconnaissance".
        """
        resolved_phase = phase or "reconnaissance"
        finding = {
            "timestamp": datetime.now().isoformat(),
            "tool": tool,
            "result": result,
            "phase": resolved_phase,
        }
        session.setdefault("findings", []).append(finding)
        # Log the phase actually stored, not the raw (possibly None) argument.
        logger.info("Recorded finding from %s in phase %s", tool, resolved_phase)

    @staticmethod
    def _extract_services(scan_result: Any) -> List[Any]:
        """Return ``scan_result["findings"]["services"]`` or an empty list.

        Tolerates results that are not dicts or lack either key.
        """
        if not isinstance(scan_result, dict):
            return []
        findings = scan_result.get("findings")
        if not isinstance(findings, dict):
            return []
        services = findings.get("services")
        if isinstance(services, (list, tuple)):
            return list(services)
        return []

    @staticmethod
    def _https_endpoint(url: str):
        """Return ``(hostname, port)`` for an HTTPS URL, defaulting to port 443."""
        try:
            parts = urlsplit(url)
            host, port = parts.hostname, parts.port
        except ValueError:
            # Malformed netloc or non-numeric port: fall back to plain
            # string handling and the default port.
            host, port = None, None
        if not host:
            host = url.split("://", 1)[-1].split("/")[0]
        return host, port or 443

    def _determine_target_type(self, target: str) -> str:
        """Classify *target* as "url", "ip" or "domain".

        IPv4/IPv6 addresses and CIDR networks are "ip"; anything starting
        with an http(s) scheme is "url"; everything else is "domain".
        """
        candidate = target.strip()
        if candidate.lower().startswith(("http://", "https://")):
            return "url"
        try:
            ipaddress.ip_address(candidate)
            return "ip"
        except ValueError:
            pass
        try:
            ipaddress.ip_network(candidate, strict=False)
            return "ip"
        except ValueError:
            return "domain"

    def _is_exploitable(self, finding: Dict[str, Any]) -> bool:
        """Return True when the finding carries an explicit vulnerability marker.

        A marker is a truthy ``vulnerable`` flag or a non-empty
        ``vulnerabilities`` / ``exploits`` entry in the result.  Results that
        are not dicts (for example free-text LLM output) are not exploitable.
        """
        result = finding.get("result")
        if not isinstance(result, dict):
            return False
        return bool(
            result.get("vulnerable")
            or result.get("vulnerabilities")
            or result.get("exploits")
        )

    def _get_risk_score(self, finding: Dict[str, Any]) -> int:
        """Calculate a heuristic risk score for a finding.

        The score is the first matching severity keyword found anywhere in
        the result's text form (critical 10, high 7, medium 5, low 2), plus 5
        when the result lists exploits and 3 when it mentions a CVE.
        """
        result = finding.get("result", {})
        text = str(result).lower()  # computed once; searched several times

        score = 0
        # Check severity
        for keyword, points in self._SEVERITY_SCORES:
            if keyword in text:
                score += points
                break
        # Check for known exploits
        if isinstance(result, dict) and result.get("exploits"):
            score += 5
        # Check for CVE
        if "cve" in text:
            score += 3
        return score

    # ------------------------------------------------------------------
    # LLM-driven suggestions
    # ------------------------------------------------------------------

    async def get_suggestions(
        self,
        session: Dict[str, Any],
        context: Optional[str] = None
    ) -> Dict[str, Any]:
        """Get AI-powered suggestions for next steps.

        The suggestion is also appended to ``self.decision_history`` together
        with the session id and the supplied *context*.

        Returns:
            Whatever ``lm_client.suggest_next_action`` returns.
        """
        findings = session.get("findings", [])
        # Sorted so the list handed to the LLM is deterministic.
        completed_tools = sorted({f["tool"] for f in findings})

        suggestion = await self.lm_client.suggest_next_action(
            current_findings=findings,
            completed_scans=completed_tools,
            target_info={
                "target": session["target"],
                "scope": session.get("scope", [session["target"]]),
            },
            rules_of_engagement=session.get("roe", {}),
        )
        self.decision_history.append({
            "timestamp": datetime.now().isoformat(),
            "session_id": session["id"],
            "context": context,
            "suggestion": suggestion,
        })
        return suggestion

    async def auto_execute_suggestion(
        self,
        session: Dict[str, Any],
        suggestion: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Automatically execute the suggested action.

        Args:
            session: Session dict that receives the resulting finding.
            suggestion: Dict with ``recommended_action`` (name of a callable
                in the ``tools`` module) and optional ``parameters``.

        Returns:
            ``{"success": True, "action": ..., "result": ...}`` on success,
            otherwise ``{"success": False, "error": ...}``.
        """
        action = suggestion.get("recommended_action")
        parameters = suggestion.get("parameters") or {}
        logger.info("Auto-executing: %s with parameters: %s", action, parameters)

        # SECURITY: the action name and parameters originate from LLM output
        # and are untrusted.  Private names are rejected here.
        # NOTE(review): nothing in this method checks that the parameters stay
        # within session scope / rules of engagement - confirm that the tools
        # module or the caller enforces this.
        if not isinstance(action, str) or not action or action.startswith("_"):
            return {
                "success": False,
                "error": f"Tool {action} not found",
            }
        if not isinstance(parameters, dict):
            return {
                "success": False,
                "error": f"Invalid parameters for tool {action}",
            }

        try:
            # Get the tool function
            tool_func = getattr(tools, action, None)
            if tool_func is None or not callable(tool_func):
                return {
                    "success": False,
                    "error": f"Tool {action} not found",
                }
            result = await tool_func(**parameters)
            self._record_finding(session, action, result)
            return {
                "success": True,
                "action": action,
                "result": result,
            }
        except Exception as e:
            logger.error("Error auto-executing %s: %s", action, e)
            return {
                "success": False,
                "error": str(e),
            }