from __future__ import annotations
import re
import shlex
import subprocess
import logging
from dataclasses import dataclass
from .models import GuestExecResult
from .util import LabError, run_command
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Free-form arguments may contain only ASCII letters, digits and the
# punctuation `_ . / : = , + - @`; whitespace, quotes and shell
# metacharacters do not match.
SAFE_ARG_PATTERN = re.compile(r"^[A-Za-z0-9_./:=,+\-@]+$")

# One or more ASCII digits; used to check a PID before it is interpolated
# into a /proc path.
PID_PATTERN = re.compile(r"^[0-9]+$")

# Command names that callers are allowed to request. Each name maps to itself.
ALLOWLIST = {
    name: name
    for name in ("ps", "cat_proc_maps", "sha256sum", "tar", "true", "sleep")
}
@dataclass
class GuestConnection:
    """Connection parameters for a guest plus builders for ssh/scp argv prefixes.

    The fields are stored as given; no validation is performed here.
    """

    # Guest host name or address (the ``host`` in ``user@host``).
    host: str
    # SSH port; rendered as ``-p`` for ssh and ``-P`` for scp.
    port: int
    # Login user (the ``user`` in ``user@host``).
    user: str
    # Identity file handed to ssh/scp via ``-i``.
    private_key_path: str
    # File handed to ssh/scp as ``UserKnownHostsFile``.
    known_hosts_path: str

    def _shared_options(self) -> list[str]:
        """Return the ``-o`` options that ssh and scp invocations have in common."""
        # NOTE(review): StrictHostKeyChecking=no turns off host key
        # verification even though a known-hosts file is supplied - confirm
        # this is intended for the guests this module talks to.
        return [
            "-o",
            "StrictHostKeyChecking=no",
            "-o",
            f"UserKnownHostsFile={self.known_hosts_path}",
        ]

    def ssh_base_args(self) -> list[str]:
        """Return a new ssh argv ending in ``user@host``.

        The caller appends the remote command to the returned list.
        """
        destination = f"{self.user}@{self.host}"
        return [
            "ssh",
            "-i",
            self.private_key_path,
            "-p",
            str(self.port),
            *self._shared_options(),
            destination,
        ]

    def scp_base_args(self) -> list[str]:
        """Return a new scp argv containing only the options.

        No source or destination operand is included; the caller appends them.
        """
        return [
            "scp",
            "-i",
            self.private_key_path,
            "-P",
            str(self.port),
            *self._shared_options(),
        ]
def validate_safe_args(args: list[str]) -> None:
    """Reject the argument list if any element fails ``SAFE_ARG_PATTERN``.

    Args:
        args: Arguments to check, in order.

    Raises:
        LabError: With code ``unsafe_argument`` for the first argument that
            does not fully match ``SAFE_ARG_PATTERN``; the offending value is
            reported under ``details["argument"]``.
    """
    # Lazily scan for the first offender so later arguments are not examined
    # once one has been found.
    offender = next(
        (candidate for candidate in args if SAFE_ARG_PATTERN.fullmatch(candidate) is None),
        None,
    )
    if offender is None:
        return
    raise LabError(
        code="unsafe_argument",
        message="Argument contains unsupported characters for allowlisted execution.",
        details={"argument": offender},
    )
def build_allowlisted_command(allowed_command: str, args: list[str]) -> list[str]:
    """Turn an allowlisted command name and its arguments into an argv list.

    Args:
        allowed_command: Name of the requested command; must be a key of
            ``ALLOWLIST``.
        args: Caller-supplied arguments, validated according to the command.

    Returns:
        The argv to execute:

        * ``ps`` / ``tar``: the command followed by ``args`` unchanged.
        * ``cat_proc_maps``: ``["cat", "/proc/<pid>/maps"]``.
        * ``sha256sum``: ``["sha256sum", <path>]``.
        * ``true``: ``["/bin/true"]``.
        * ``sleep``: ``["sleep", <seconds>]``.

    Raises:
        LabError: ``command_not_allowlisted`` when the name is not in
            ``ALLOWLIST``; ``invalid_pid`` / ``invalid_arguments`` when the
            arguments do not fit the command; ``unsafe_argument`` (from
            ``validate_safe_args``) for disallowed characters;
            ``command_not_implemented`` when an allowlisted name has no
            handler in this function.
    """
    if allowed_command not in ALLOWLIST:
        raise LabError(
            code="command_not_allowlisted",
            message=f"Command '{allowed_command}' is not allowlisted.",
            details={"allowed": sorted(ALLOWLIST.keys())},
        )

    if allowed_command in {"ps", "tar"}:
        # NOTE(review): arguments starting with "-" pass SAFE_ARG_PATTERN, so
        # callers can select any option of these tools. GNU tar has options
        # that launch helper programs (e.g. --to-command,
        # --checkpoint-action) - confirm whether those must be rejected here.
        validate_safe_args(args)
        return [allowed_command, *args]

    if allowed_command == "cat_proc_maps":
        if len(args) != 1 or not PID_PATTERN.fullmatch(args[0]):
            raise LabError(
                code="invalid_pid",
                message="cat_proc_maps expects one numeric PID argument.",
                details={"args": args},
            )
        return ["cat", f"/proc/{args[0]}/maps"]

    if allowed_command == "sha256sum":
        if len(args) != 1:
            raise LabError(
                code="invalid_arguments",
                message="sha256sum expects exactly one path argument.",
                details={"args": args},
            )
        validate_safe_args(args)
        return ["sha256sum", args[0]]

    if allowed_command == "true":
        if args:
            raise LabError(
                code="invalid_arguments",
                message="true does not accept arguments.",
                details={"args": args},
            )
        return ["/bin/true"]

    if allowed_command == "sleep":
        # str.isdigit() alone also accepts non-ASCII digits such as "٣" or
        # "²"; require ASCII so only plain decimal integers get through,
        # matching the ASCII-only PID check above.
        is_plain_integer = (
            len(args) == 1 and args[0].isascii() and args[0].isdigit()
        )
        if not is_plain_integer:
            raise LabError(
                code="invalid_arguments",
                message="sleep expects one integer argument.",
                details={"args": args},
            )
        return ["sleep", args[0]]

    # Reached only when ALLOWLIST contains a name without a branch above.
    raise LabError(code="command_not_implemented", message=f"Unsupported allowlisted command: {allowed_command}")
def shell_join(args: list[str]) -> str:
    """Return ``args`` as one space-separated string, each element shell-quoted.

    Every element is passed through ``shlex.quote``. An empty list yields an
    empty string.
    """
    quoted_parts = map(shlex.quote, args)
    return " ".join(quoted_parts)
def run_ssh_command(
    connection: GuestConnection,
    command_args: list[str],
    *,
    timeout: float = 60.0,
) -> GuestExecResult:
    """Run a command on the guest over ssh and capture its result.

    ``command_args`` is shell-quoted with ``shell_join`` and appended, after a
    ``--`` separator, to ``connection.ssh_base_args()``.

    Args:
        connection: Connection parameters for the guest.
        command_args: Remote argv; element 0 is reported as ``command`` and
            the remainder as ``args`` in the result. May be empty, in which
            case ``command`` is ``""``.
        timeout: Seconds passed to ``run_command`` as its timeout.

    Returns:
        A ``GuestExecResult`` with ``vm_id=""``. On timeout the result has
        ``timed_out=True``, ``exit_code=-1``, a fixed timeout message in
        ``stderr`` and whatever stdout was captured before the timeout.
        Otherwise it carries the process's stdout, stderr and return code
        with ``timed_out=False``. A non-zero exit code is not an error here
        (``check=False``).
    """
    remote_command = shell_join(command_args)
    ssh_args = [*connection.ssh_base_args(), "--", remote_command]
    command_name = command_args[0] if command_args else ""
    command_rest = command_args[1:]

    logger.debug("ssh_exec host=%s port=%s timeout=%s command=%s", connection.host, connection.port, timeout, remote_command)
    try:
        result = run_command(ssh_args, timeout=timeout, check=False)
    except subprocess.TimeoutExpired as exc:
        logger.warning(
            "ssh_exec_timeout host=%s port=%s timeout=%s command=%s",
            connection.host,
            connection.port,
            timeout,
            remote_command,
        )
        # TimeoutExpired may carry the partial output as bytes even when the
        # subprocess was run in text mode, or None when nothing was captured.
        # Normalise to str so stdout has the same type as on the success path.
        # NOTE(review): UTF-8 is assumed; confirm against run_command's encoding.
        partial_stdout = exc.stdout or ""
        if isinstance(partial_stdout, bytes):
            partial_stdout = partial_stdout.decode("utf-8", errors="replace")
        return GuestExecResult(
            vm_id="",
            command=command_name,
            args=command_rest,
            stdout=partial_stdout,
            stderr=f"SSH command timed out after {timeout}s",
            exit_code=-1,
            timed_out=True,
        )

    logger.debug(
        "ssh_exec_done host=%s port=%s exit_code=%s stdout_len=%s stderr_len=%s command=%s",
        connection.host,
        connection.port,
        result.returncode,
        len(result.stdout),
        len(result.stderr),
        remote_command,
    )
    return GuestExecResult(
        vm_id="",
        command=command_name,
        args=command_rest,
        stdout=result.stdout,
        stderr=result.stderr,
        exit_code=result.returncode,
        timed_out=False,
    )
def scp_copy_out(
    connection: GuestConnection,
    guest_path: str,
    host_path: str,
    *,
    recursive: bool = False,
    timeout: float = 120.0,
) -> None:
    """Copy a path from the guest to the host with scp.

    Args:
        connection: Connection parameters for the guest.
        guest_path: Source path on the guest, sent as ``user@host:guest_path``.
        host_path: Destination path on the host.
        recursive: When true, ``-r`` is added to the scp invocation.
        timeout: Seconds passed to ``run_command`` as its timeout.

    Raises:
        LabError: With code ``scp_copy_failed`` when scp exits non-zero; the
            details carry both paths plus scp's stderr and stdout. Exceptions
            raised by ``run_command`` itself are not caught here.
    """
    base_args = connection.scp_base_args()
    recursion_flag = ["-r"] if recursive else []
    remote_source = f"{connection.user}@{connection.host}:{guest_path}"
    scp_argv = [*base_args, *recursion_flag, remote_source, host_path]

    logger.debug(
        "scp_copy_out host=%s port=%s timeout=%s guest_path=%s host_path=%s recursive=%s",
        connection.host,
        connection.port,
        timeout,
        guest_path,
        host_path,
        recursive,
    )

    outcome = run_command(scp_argv, timeout=timeout, check=False)
    if outcome.returncode == 0:
        return
    raise LabError(
        code="scp_copy_failed",
        message="Failed to copy file from guest.",
        details={"guest_path": guest_path, "host_path": host_path, "stderr": outcome.stderr, "stdout": outcome.stdout},
    )
def scp_copy_in(
    connection: GuestConnection,
    host_path: str,
    guest_path: str,
    *,
    recursive: bool = False,
    timeout: float = 120.0,
) -> None:
    """Copy a path from the host to the guest with scp.

    Args:
        connection: Connection parameters for the guest.
        host_path: Source path on the host.
        guest_path: Destination path on the guest, sent as
            ``user@host:guest_path``.
        recursive: When true, ``-r`` is added to the scp invocation.
        timeout: Seconds passed to ``run_command`` as its timeout.

    Raises:
        LabError: With code ``scp_copy_in_failed`` when scp exits non-zero;
            the details carry both paths plus scp's stderr and stdout.
            Exceptions raised by ``run_command`` itself are not caught here.
    """
    base_args = connection.scp_base_args()
    recursion_flag = ["-r"] if recursive else []
    remote_target = f"{connection.user}@{connection.host}:{guest_path}"
    scp_argv = [*base_args, *recursion_flag, host_path, remote_target]

    logger.debug(
        "scp_copy_in host=%s port=%s timeout=%s host_path=%s guest_path=%s recursive=%s",
        connection.host,
        connection.port,
        timeout,
        host_path,
        guest_path,
        recursive,
    )

    outcome = run_command(scp_argv, timeout=timeout, check=False)
    if outcome.returncode == 0:
        return
    raise LabError(
        code="scp_copy_in_failed",
        message="Failed to copy file to guest.",
        details={"guest_path": guest_path, "host_path": host_path, "stderr": outcome.stderr, "stdout": outcome.stdout},
    )