from __future__ import annotations
import ast
import logging
from contextlib import redirect_stderr
from contextlib import redirect_stdout
from dataclasses import dataclass
from dataclasses import field
from datetime import datetime
from io import StringIO
from pathlib import Path
import django
from asgiref.sync import sync_to_async
from django.apps import apps
logger = logging.getLogger(__name__)
class DjangoShell:
def __init__(self):
logger.debug("Initializing %s", self.__class__.__name__)
if not apps.ready: # pragma: no cover
logger.info("Django not initialized, running django.setup()")
django.setup()
logger.debug("Django setup completed")
else:
logger.debug("Django already initialized, skipping setup")
self.history: list[Result] = []
logger.info("Shell initialized successfully")
def clear_history(self):
"""Clear the execution history.
Removes all entries from the shell history. Useful for starting fresh
or removing exploratory code before exporting.
"""
logger.info("Clearing shell history - previous entries: %s", len(self.history))
self.history = []
def export_history(
self,
filename: str | None = None,
) -> str:
"""Export shell session history as a Python script.
Generates a Python script containing all successfully executed code
from the session. Failed executions are excluded. Import statements
are deduplicated and placed at the top. Output and execution results
are not included in the export.
Args:
filename: Relative path to save the script. If None, returns the
script content as a string. Absolute paths are rejected.
Returns:
The Python script as a string if filename is None, otherwise a
confirmation message with preview of the exported file.
Raises:
ValueError: If an absolute path is provided for filename.
"""
logger.info(
"Exporting history - entries: %s, filename: %s",
len(self.history),
filename or "None",
)
if not self.history:
return "# No history to export\n"
imports_set = set()
steps = []
successful_codes = [
result.code for result in self.history if not isinstance(result, ErrorResult)
]
for step_num, code in enumerate(successful_codes, start=1):
try:
tree = ast.parse(code)
for node in ast.walk(tree):
if isinstance(node, (ast.Import, ast.ImportFrom)):
imports_set.add(ast.unparse(node))
except SyntaxError:
pass
steps.append(f"# Step {step_num}")
steps.append(code)
steps.append("")
script_parts = [
"# Django Shell Session Export",
f"# Generated: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}",
"",
]
if imports_set:
script_parts.extend(sorted(imports_set))
script_parts.append("")
script_parts.extend(steps)
script = "\n".join(script_parts)
if filename:
# Security: only allow relative paths
if Path(filename).is_absolute():
raise ValueError("Absolute paths not allowed for security reasons")
if not filename.endswith(".py"):
filename += ".py"
filepath = Path(filename)
filepath.write_text(script)
logger.info("Exported history to file: %s", filepath)
line_count = len(script.split("\n"))
preview_lines = script.split("\n")[:20]
preview = "\n".join(preview_lines)
if line_count > 20:
preview += f"\n... ({line_count - 20} more lines)"
return f"Exported {line_count} lines to {filename}\n\n{preview}"
return script
async def execute(self, code: str) -> Result:
"""Execute Python code in the Django shell context (async).
Async wrapper around the synchronous _execute() method. Delegates
execution to a thread pool via sync_to_async to avoid Django's
SynchronousOnlyOperation errors when called from async contexts.
Args:
code: Python code to execute.
Returns:
StatementResult or ErrorResult depending on execution outcome.
"""
return await sync_to_async(self._execute)(code)
def _execute(self, code: str) -> Result:
"""Execute Python code in the Django shell context (synchronous).
Executes code in a fresh global namespace for stateless behavior,
ensuring code changes take effect and no stale modules persist between
executions. Captures stdout/stderr and saves results to history.
Args:
code: Python code to execute.
Returns:
StatementResult if execution succeeds.
ErrorResult if execution raises an exception.
"""
code_preview = (code[:100] + "..." if len(code) > 100 else code).replace(
"\n", "\\n"
)
logger.info("Executing code: %s", code_preview)
stdout = StringIO()
stderr = StringIO()
with redirect_stdout(stdout), redirect_stderr(stderr):
try:
logger.debug(
"Code to execute: %s",
code[:200] + "..." if len(code) > 200 else code,
)
exec(code, {})
logger.debug("Code executed successfully")
return self.save_result(
StatementResult(
code=code,
stdout=stdout.getvalue(),
stderr=stderr.getvalue(),
)
)
except Exception as e:
logger.error(
"Exception during code execution: %s - Code: %s",
f"{type(e).__name__}: {e}",
code_preview,
)
logger.debug("Full traceback for error:", exc_info=True)
return self.save_result(
ErrorResult(
code=code,
exception=e,
stdout=stdout.getvalue(),
stderr=stderr.getvalue(),
)
)
def save_result(self, result: Result) -> Result:
self.history.append(result)
return result
@dataclass
class StatementResult:
code: str
stdout: str
stderr: str
timestamp: datetime = field(default_factory=datetime.now)
def __post_init__(self):
logger.debug("%s created", self.__class__.__name__)
if self.stdout:
logger.debug("%s.stdout: %s", self.__class__.__name__, self.stdout[:200])
if self.stderr:
logger.debug("%s.stderr: %s", self.__class__.__name__, self.stderr[:200])
@dataclass
class ErrorResult:
code: str
exception: Exception
stdout: str
stderr: str
timestamp: datetime = field(default_factory=datetime.now)
def __post_init__(self):
logger.debug(
"%s created - exception type: %s",
self.__class__.__name__,
type(self.exception).__name__,
)
logger.debug("%s.message: %s", self.__class__.__name__, str(self.exception))
if self.stdout:
logger.debug("%s.stdout: %s", self.__class__.__name__, self.stdout[:200])
if self.stderr:
logger.debug("%s.stderr: %s", self.__class__.__name__, self.stderr[:200])
Result = StatementResult | ErrorResult
django_shell = DjangoShell()