"""Code validation and sandboxed execution for custom Python tools."""
from __future__ import annotations
import ast
import signal
import textwrap
from typing import Any
from . import config, database
# ---------------------------------------------------------------------------
# Static validation via AST
# ---------------------------------------------------------------------------
class _CodeValidator(ast.NodeVisitor):
    """Walk a parsed module and record every disallowed construct in ``errors``.

    The visitor never raises: call ``visit(tree)`` and then inspect ``errors``
    (an empty list means nothing was flagged). It rejects:

    * imports whose top-level package is not in ``config.ALLOWED_IMPORTS``,
      any relative import, and imported names that are blocked or dunder;
    * names and attributes listed in ``config.BLOCKED_IDENTIFIERS``;
    * dunder names/attributes (``__class__``, ``__globals__``, ...) and frame
      attributes (``gi_frame``, ``f_back``, ...). These are the standard ways
      out of an ``exec`` sandbox, e.g. ``().__class__.__base__.__subclasses__()``
      or walking a running generator's ``gi_frame.f_back.f_globals`` up to the
      host module's real globals and builtins;
    * direct calls to ``eval``, ``exec``, ``compile`` and ``__import__``.
    """

    # Attributes that hand out interpreter frames or a frame's namespaces.
    _FRAME_ATTRS = frozenset({
        "gi_frame", "cr_frame", "ag_frame", "tb_frame",
        "f_back", "f_globals", "f_locals", "f_builtins",
    })
    # Builtins whose direct call is rejected regardless of BLOCKED_IDENTIFIERS.
    _BLOCKED_CALLS = ("eval", "exec", "compile", "__import__")

    def __init__(self) -> None:
        # One human-readable message per offending node, in visit order.
        self.errors: list[str] = []

    # --- helpers -------------------------------------------------------------
    @staticmethod
    def _is_dunder(name: str) -> bool:
        """Return True for ``__name__``-style identifiers."""
        return len(name) > 4 and name.startswith("__") and name.endswith("__")

    def _attr_blocked(self, name: str) -> bool:
        """Return True if sandboxed code must not look up attribute *name*."""
        return (
            name in config.BLOCKED_IDENTIFIERS
            or self._is_dunder(name)
            or name in self._FRAME_ATTRS
        )

    # --- imports -------------------------------------------------------------
    def visit_Import(self, node: ast.Import) -> None:
        for alias in node.names:
            if alias.name.split(".")[0] not in config.ALLOWED_IMPORTS:
                self.errors.append(f"Import not allowed: {alias.name}")
        self.generic_visit(node)

    def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
        if node.level:
            # ``from . import x`` has module=None and used to skip the allowlist
            # entirely; relative imports resolve against the host package.
            self.errors.append(
                f"Relative import not allowed: {'.' * node.level}{node.module or ''}"
            )
        elif node.module and node.module.split(".")[0] not in config.ALLOWED_IMPORTS:
            self.errors.append(f"Import not allowed: {node.module}")
        # ``from m import name`` is an attribute lookup on ``m`` that
        # visit_Attribute never sees (e.g. ``from pandas import __builtins__``).
        for alias in node.names:
            if self._attr_blocked(alias.name):
                self.errors.append(f"Blocked import name: {alias.name}")
        self.generic_visit(node)

    # --- dangerous names and attributes ---------------------------------------
    def visit_Name(self, node: ast.Name) -> None:
        # Dunder names (``__builtins__``, ``__import__``, ...) are rejected too;
        # none of them are meaningful inside the sandbox namespace.
        if node.id in config.BLOCKED_IDENTIFIERS or self._is_dunder(node.id):
            self.errors.append(f"Blocked identifier: {node.id}")
        self.generic_visit(node)

    def visit_Attribute(self, node: ast.Attribute) -> None:
        if self._attr_blocked(node.attr):
            self.errors.append(f"Blocked attribute: {node.attr}")
        self.generic_visit(node)

    def visit_MatchClass(self, node: ast.AST) -> None:
        # ``case C(attr=p)`` (Python 3.10+) looks attributes up by name without
        # producing an ast.Attribute node, so it needs the same check.
        for attr in getattr(node, "kwd_attrs", ()):
            if self._attr_blocked(attr):
                self.errors.append(f"Blocked attribute: {attr}")
        self.generic_visit(node)

    # --- direct eval/exec-style calls -------------------------------------------
    def visit_Call(self, node: ast.Call) -> None:
        if isinstance(node.func, ast.Name) and node.func.id in self._BLOCKED_CALLS:
            self.errors.append(f"Blocked call: {node.func.id}()")
        self.generic_visit(node)
def validate_code(source: str) -> list[str]:
    """Statically check *source* and report every disallowed construct.

    Returns
    -------
    list[str]
        One message per problem found; an empty list means the code passed.
        Source that cannot be parsed yields a single ``"Syntax error: ..."``
        entry and no further checks are run.
    """
    try:
        module_ast = ast.parse(source)
    except SyntaxError as err:
        return [f"Syntax error: {err}"]
    else:
        checker = _CodeValidator()
        checker.visit(module_ast)
        return checker.errors
# ---------------------------------------------------------------------------
# Execution
# ---------------------------------------------------------------------------
def _timeout_handler(signum: int, frame: Any) -> None:
    """SIGALRM handler that aborts the running custom tool with ``TimeoutError``.

    ``signum`` and ``frame`` are supplied by the ``signal`` module and ignored.
    """
    message = "Custom tool execution timed out."
    raise TimeoutError(message)
def _guarded_import(
    name: str,
    globals_: Any = None,
    locals_: Any = None,
    fromlist: Any = (),
    level: int = 0,
) -> Any:
    """Sandbox stand-in for ``__import__`` that only admits allowlisted modules.

    The interpreter calls this for every ``import`` / ``from ... import``
    statement executed inside the sandbox. Relative imports and modules whose
    top-level package is not in ``config.ALLOWED_IMPORTS`` raise
    ``ImportError``; this mirrors the static check in ``validate_code``.
    """
    if level != 0 or name.split(".")[0] not in config.ALLOWED_IMPORTS:
        raise ImportError(f"Import not allowed: {name}")
    return __import__(name, globals_, locals_, fromlist, level)


def _build_safe_globals() -> dict[str, Any]:
    """Return a fresh sandbox namespace: whitelisted builtins plus preloaded modules.

    A new dict is built on every call so one tool's module-level assignments
    never leak into the next tool's execution.
    """
    import collections
    import datetime
    import json
    import math
    import re
    import statistics

    import numpy
    import pandas

    safe_builtins: dict[str, Any] = {
        # safe subset of builtins
        "True": True,
        "False": False,
        "None": None,
        "abs": abs,
        "all": all,
        "any": any,
        "bool": bool,
        "dict": dict,
        "enumerate": enumerate,
        "filter": filter,
        "float": float,
        "frozenset": frozenset,
        "int": int,
        "isinstance": isinstance,
        "len": len,
        "list": list,
        "map": map,
        "max": max,
        "min": min,
        "print": print,
        "range": range,
        "reversed": reversed,
        "round": round,
        "set": set,
        "sorted": sorted,
        "str": str,
        "sum": sum,
        "tuple": tuple,
        "type": type,
        "zip": zip,
        "ValueError": ValueError,
        "TypeError": TypeError,
        "KeyError": KeyError,
        "IndexError": IndexError,
        "Exception": Exception,
        # Without this entry every ``import`` statement fails with
        # "ImportError: __import__ not found", even for allowlisted modules.
        "__import__": _guarded_import,
    }
    return {
        "__builtins__": safe_builtins,
        "pandas": pandas,
        "pd": pandas,
        "numpy": numpy,
        "np": numpy,
        "datetime": datetime,
        "json": json,
        "math": math,
        "statistics": statistics,
        "re": re,
        "collections": collections,
    }


def execute_tool_code(
    source: str,
    params: dict[str, Any] | None = None,
    timeout: int = config.EXEC_TIMEOUT_SECONDS,
) -> Any:
    """Validate and execute a custom tool's ``run()`` function.

    Parameters
    ----------
    source : str
        Python source that defines a ``run(db, **kwargs)`` function. Common
        leading indentation is removed before validation.
    params : dict, optional
        Keyword arguments forwarded to ``run()``.
    timeout : int
        Max execution seconds, covering module-level code plus the ``run()``
        call. Enforced with ``SIGALRM``, so only on platforms that have it and
        only when called from the main thread; otherwise no limit is applied.

    Returns
    -------
    The return value of ``run()``.

    Raises
    ------
    ValueError
        If validation fails or the source does not define a callable ``run``.
    TimeoutError
        If execution exceeds ``timeout`` (see Notes).

    Notes
    -----
    The timeout is delivered as a ``TimeoutError`` raised inside the tool's
    own code. Because the sandbox exposes ``Exception``, tool code that
    catches ``Exception`` can swallow it.
    """
    # Dedent *before* validating. Previously the dedent ran after validation,
    # so indented source was always rejected as a syntax error, and the text
    # that was validated was not the text that was executed.
    source = textwrap.dedent(source)
    errors = validate_code(source)
    if errors:
        raise ValueError("Code validation failed:\n" + "\n".join(f" - {e}" for e in errors))

    # A single dict serves as both globals and locals. With separate dicts,
    # top-level helpers and constants land in the locals dict, which functions
    # such as ``run()`` cannot see (their globals are the sandbox dict), so
    # they fail with NameError.
    namespace = _build_safe_globals()
    code_obj = compile(source, "<custom_tool>", "exec")

    def db_query(sql: str, params_inner: dict | None = None) -> list[dict]:
        """DB helper handed to ``run()``; each statement goes through ``database.validate_sql`` first."""
        database.validate_sql(sql)
        return database.execute_query(sql, params_inner)

    # Arm the timeout. signal.signal raises AttributeError where SIGALRM does
    # not exist (e.g. Windows) and ValueError when called off the main thread
    # (e.g. from a web-server worker). In both cases run without a timeout
    # instead of failing the call.
    handler_installed = False
    old_handler: Any = None
    try:
        old_handler = signal.signal(signal.SIGALRM, _timeout_handler)
        handler_installed = True
        signal.alarm(timeout)
    except (OSError, AttributeError, ValueError):
        pass

    try:
        exec(code_obj, namespace)  # noqa: S102 — intentional sandboxed exec
        run_fn = namespace.get("run")
        if not callable(run_fn):
            raise ValueError("Custom tool must define a `run()` function.")
        return run_fn(db_query, **dict(params or {}))
    finally:
        if handler_installed:
            signal.alarm(0)
            # signal.signal returns None for a handler that was not installed
            # from Python; such a handler cannot be restored from here.
            if old_handler is not None:
                signal.signal(signal.SIGALRM, old_handler)