"""Claude Agent SDK slash command handlers.
Implements:
- claude_slash_list: list available slash commands (via system/init)
- claude_slash_run: run a specific slash command
Follows the same handler patterns as `handlers/cli.py`:
- validate inputs
- invoke the backend (SDK invoker)
- format response (+ optionally append to handoff_file)
"""
from __future__ import annotations
import asyncio
import logging
from pathlib import Path
from typing import Any
import anyio
from mcp.types import TextContent
from .base import ToolContext, ToolHandler
from ..shared.claude_agent_sdk_invoker import (
ClaudeAgentSDKInvoker,
ClaudeSlashListParams,
ClaudeSlashRunParams,
)
from ..shared.invokers.types import Permission
from ..shared.response_formatter import (
DebugInfo as FormatterDebugInfo,
ResponseData,
format_error_response,
get_formatter,
)
from ..utils.xml_wrapper import build_wrapper
# Public API of this module: the two MCP tool handlers defined below.
__all__ = [
    "ClaudeSlashListHandler",
    "ClaudeSlashRunHandler",
]

logger = logging.getLogger(name=__name__)

# Seconds between keep-alive progress notifications emitted while an SDK call
# is still running (see the handlers' `progress_reporter`).
PROGRESS_REPORT_INTERVAL = 30
def _normalize_workspace(workspace_raw: str) -> Path:
workspace = Path(workspace_raw).expanduser()
return workspace.resolve() if not workspace.is_absolute() else workspace
def _resolve_handoff_path(handoff_file_path: str, workspace: Path) -> Path:
    """Return the resolved absolute path that ``handoff_file_path`` refers to.

    Kept consistent with `handlers/cli.py`:
    - ``expanduser`` runs before the workspace-relative join (so "~" works
      even when the path would otherwise be treated as relative);
    - a path that is still relative is anchored at ``workspace``.
    """
    handoff_path = Path(handoff_file_path).expanduser()
    if not handoff_path.is_absolute():
        handoff_path = workspace / handoff_path
    return handoff_path.resolve()


def _append_to_handoff_file(
    *,
    agent: str,
    continuation_id: str,
    task_note: str,
    task_index: int,
    status: str,
    prompt: str,
    response_markdown: str,
    handoff_file_path: str,
    workspace: Path,
) -> tuple[bool, str]:
    """Append wrapped markdown output to handoff_file (best-effort).

    The response is wrapped with ``build_wrapper`` and appended to the target
    file, creating parent directories as needed. If the file already existed,
    a newline separates the new entry from previous content.

    Returns:
        ``(written, path)``: whether the write succeeded, and the resolved
        path (or the raw user-supplied path if it could not be resolved).
        Errors are logged, never raised.
    """
    try:
        wrapped = build_wrapper(
            agent,
            continuation_id,
            task_note,
            task_index,
            status,
            prompt,
            response_markdown,
        )
        handoff_path = _resolve_handoff_path(handoff_file_path, workspace)
        handoff_path.parent.mkdir(parents=True, exist_ok=True)
        separator = "\n" if handoff_path.exists() else ""
        # Always open in append mode: unlike write_text(), this never truncates
        # a file that appears between the exists() check and the write.
        with handoff_path.open("a", encoding="utf-8") as f:
            f.write(separator + wrapped)
        logger.info("Appended output to: %s", handoff_path)
        return True, str(handoff_path)
    except Exception as e:
        # `handoff_file` is purely best-effort persistence; never fail the tool call on write errors.
        resolved = str(handoff_file_path)
        try:
            resolved = str(_resolve_handoff_path(handoff_file_path, workspace))
        except Exception:
            # Keep the original user-provided path if we cannot resolve it.
            pass
        logger.warning(
            "Failed to save output to %s: %s",
            resolved,
            e,
            exc_info=True,
        )
        return False, resolved
class ClaudeSlashListHandler(ToolHandler):
    """Handler for `claude_slash_list`.

    Lists the slash commands available in a workspace by running a minimal
    Claude Agent SDK session (``ClaudeAgentSDKInvoker.list_slash_commands``)
    and formatting the result. ``handoff_file`` is optional for this tool;
    when given, the formatted output is also appended to it (best-effort).
    """

    @property
    def name(self) -> str:
        return "claude_slash_list"

    @property
    def description(self) -> str:
        from ..tool_schema import TOOL_DESCRIPTIONS

        return TOOL_DESCRIPTIONS.get(self.name, "")

    def get_input_schema(self) -> dict[str, Any]:
        from ..tool_schema import create_claude_slash_list_schema

        return create_claude_slash_list_schema()

    def validate(self, arguments: dict[str, Any]) -> str | None:
        """Return an error message for invalid arguments, or ``None`` if valid."""
        workspace = arguments.get("workspace")
        if not workspace or not str(workspace).strip():
            return "workspace is required"
        return None

    @staticmethod
    def _build_debug_info(
        result: Any,
        ctx: ToolContext,
        handoff_file: str | None,
        handoff_written: bool | None,
    ) -> FormatterDebugInfo:
        """Map the invoker result's debug info (if any) onto the formatter's DebugInfo."""
        dbg = result.debug_info
        return FormatterDebugInfo(
            model=dbg.model if dbg else None,
            duration_sec=dbg.duration_sec if dbg else 0.0,
            message_count=dbg.message_count if dbg else 0,
            tool_call_count=dbg.tool_call_count if dbg else 0,
            input_tokens=dbg.input_tokens if dbg else None,
            output_tokens=dbg.output_tokens if dbg else None,
            cancelled=result.cancelled,
            log_file=ctx.config.log_file if ctx.config.log_debug else None,
            handoff_file=handoff_file or None,
            handoff_file_written=handoff_written,
        )

    async def handle(self, arguments: dict[str, Any], ctx: ToolContext) -> list[TextContent]:
        """Run slash-command discovery and return the formatted response.

        Validation failures and unexpected errors are returned as formatted
        error responses; cancellation is logged and re-raised.
        """
        error = self.validate(arguments)
        if error:
            return format_error_response(error)

        task_note = str(arguments.get("task_note", "") or "")
        debug_enabled = ctx.resolve_debug(arguments)
        workspace = _normalize_workspace(str(arguments["workspace"]))
        handoff_file = str(arguments.get("handoff_file", "") or "").strip()

        # The SDK requires a prompt even for discovery; keep it simple.
        prompt = "Hello"
        ctx.push_user_prompt("claude", prompt, task_note)
        event_callback = ctx.make_event_callback("claude", task_note, 0)
        invoker = ClaudeAgentSDKInvoker(event_callback=event_callback)

        progress_task: asyncio.Task[None] | None = None
        progress_counter = 0

        async def progress_reporter() -> None:
            """Emit a keep-alive progress notification every PROGRESS_REPORT_INTERVAL seconds."""
            nonlocal progress_counter
            try:
                while True:
                    await asyncio.sleep(PROGRESS_REPORT_INTERVAL)
                    progress_counter += 1
                    await ctx.report_progress_safe(
                        progress=progress_counter,
                        message=f"Processing... ({progress_counter * PROGRESS_REPORT_INTERVAL}s)",
                    )
            except (anyio.get_cancelled_exc_class(), asyncio.CancelledError):
                raise
            except Exception as e:
                logger.warning("Progress reporter crashed: %s", e, exc_info=True)

        async def stop_progress_reporter() -> None:
            """Cancel and await the keep-alive task; safe to call repeatedly."""
            nonlocal progress_task
            if not progress_task:
                return
            if not progress_task.done():
                progress_task.cancel()
            try:
                await progress_task
            except (anyio.get_cancelled_exc_class(), asyncio.CancelledError):
                pass
            except Exception as e:
                logger.warning("Progress reporter task failed: %s", e, exc_info=True)
            finally:
                progress_task = None

        async def report_final_progress(message: str) -> None:
            """Stop the keep-alive task, then send the closing progress notification."""
            await stop_progress_reporter()
            # MCP requires progress values to strictly increase. After more than
            # 100 keep-alive ticks, a fixed final value of 100 would go backwards.
            final = max(100, progress_counter + 1)
            await ctx.report_progress_safe(progress=final, total=final, message=message)

        try:
            if ctx.has_progress_token():
                progress_task = asyncio.create_task(progress_reporter())

            result = await invoker.list_slash_commands(
                ClaudeSlashListParams(
                    workspace=workspace,
                    task_note=task_note,
                )
            )

            formatter = get_formatter()
            response_data = ResponseData(
                answer=result.agent_messages,
                message=None,
                session_id=result.session_id or "",
                # Thought steps are only surfaced when the run failed.
                thought_steps=result.thought_steps if not result.success else [],
                debug_info=None,
                success=result.success,
                error=result.error,
            )

            # `handoff_file` is optional persistence only; the full response is always returned.
            handoff_written: bool | None = None
            resolved_handoff: str | None = None
            if handoff_file:
                handoff_written, resolved_handoff = _append_to_handoff_file(
                    agent=self.name,
                    continuation_id=result.session_id or "",
                    task_note=task_note,
                    task_index=0,
                    status="success" if result.success else "error",
                    prompt=prompt,
                    response_markdown=formatter.format_for_file(response_data),
                    handoff_file_path=handoff_file,
                    workspace=workspace,
                )

            if debug_enabled:
                response_data.debug_info = self._build_debug_info(
                    result, ctx, resolved_handoff, handoff_written
                )

            response = formatter.format(response_data, debug=debug_enabled)
            await report_final_progress("Completed" if result.success else "Failed")
            return [TextContent(type="text", text=response)]
        except anyio.get_cancelled_exc_class():
            logger.info("Tool '%s' cancelled", self.name)
            raise
        except asyncio.CancelledError:
            logger.info("Tool '%s' cancelled via asyncio.CancelledError", self.name)
            raise
        except Exception as e:
            logger.error("Tool '%s' error: %s", self.name, e, exc_info=True)
            await report_final_progress("Failed")
            return format_error_response(str(e))
        finally:
            await stop_progress_reporter()
class ClaudeSlashRunHandler(ToolHandler):
    """Handler for `claude_slash_run`.

    Runs one slash command through the Claude Agent SDK. The prompt sent to
    the SDK is ``/<slash_command>`` followed by the optional ``prompt`` text.
    The formatted output is always appended to ``handoff_file`` (required for
    this tool, written best-effort) and also returned to the caller.
    """

    @property
    def name(self) -> str:
        return "claude_slash_run"

    @property
    def description(self) -> str:
        from ..tool_schema import TOOL_DESCRIPTIONS

        return TOOL_DESCRIPTIONS.get(self.name, "")

    def get_input_schema(self) -> dict[str, Any]:
        from ..tool_schema import create_claude_slash_run_schema

        return create_claude_slash_run_schema()

    @staticmethod
    def _parse_permission(arguments: dict[str, Any]) -> Permission:
        """Parse ``permission``; a missing or empty value means ``read-only``.

        Raises:
            ValueError / TypeError: if the value is not a valid ``Permission``.
        """
        return Permission(arguments.get("permission") or "read-only")

    @staticmethod
    def _parse_max_turns(arguments: dict[str, Any]) -> int:
        """Parse ``max_turns``; a missing or falsy value means 200.

        Raises:
            ValueError / TypeError / OverflowError: if the value is not a
                positive integer.
        """
        max_turns = int(arguments.get("max_turns", 200) or 200)
        if max_turns < 1:
            raise ValueError(f"max_turns must be >= 1, got {max_turns}")
        return max_turns

    def validate(self, arguments: dict[str, Any]) -> str | None:
        """Return an error message for invalid arguments, or ``None`` if valid.

        Besides the required string fields, ``permission`` and ``max_turns``
        are parsed here so bad values produce an error response instead of an
        exception escaping from ``handle``.
        """
        for key in ("workspace", "handoff_file", "slash_command"):
            value = arguments.get(key)
            if not value or not str(value).strip():
                return f"{key} is required"
        try:
            self._parse_permission(arguments)
        except (TypeError, ValueError):
            return f"invalid permission: {arguments.get('permission')!r}"
        try:
            self._parse_max_turns(arguments)
        except (OverflowError, TypeError, ValueError):
            return "max_turns must be a positive integer"
        return None

    @staticmethod
    def _build_debug_info(
        result: Any,
        ctx: ToolContext,
        handoff_file: str | None,
        handoff_written: bool | None,
    ) -> FormatterDebugInfo:
        """Map the invoker result's debug info (if any) onto the formatter's DebugInfo."""
        dbg = result.debug_info
        return FormatterDebugInfo(
            model=dbg.model if dbg else None,
            duration_sec=dbg.duration_sec if dbg else 0.0,
            message_count=dbg.message_count if dbg else 0,
            tool_call_count=dbg.tool_call_count if dbg else 0,
            input_tokens=dbg.input_tokens if dbg else None,
            output_tokens=dbg.output_tokens if dbg else None,
            cancelled=result.cancelled,
            log_file=ctx.config.log_file if ctx.config.log_debug else None,
            handoff_file=handoff_file or None,
            handoff_file_written=handoff_written,
        )

    async def handle(self, arguments: dict[str, Any], ctx: ToolContext) -> list[TextContent]:
        """Run the slash command and return the formatted response.

        Validation failures and unexpected errors are returned as formatted
        error responses; cancellation is logged and re-raised.
        """
        error = self.validate(arguments)
        if error:
            return format_error_response(error)

        task_note = str(arguments.get("task_note", "") or "")
        debug_enabled = ctx.resolve_debug(arguments)
        workspace = _normalize_workspace(str(arguments["workspace"]))
        # Strip to match validation (which checks the stripped value) and the
        # list handler; otherwise stray whitespace would end up in the filename.
        handoff_file = str(arguments["handoff_file"]).strip()
        permission = self._parse_permission(arguments)
        max_turns = self._parse_max_turns(arguments)
        model = str(arguments.get("model", "") or "")
        slash_command = str(arguments.get("slash_command", "") or "")
        command_prompt = str(arguments.get("prompt", "") or "")

        # GUI user prompt (actual SDK prompt).
        cmd_name = slash_command.strip()
        sdk_prompt = cmd_name if cmd_name.startswith("/") else f"/{cmd_name}"
        if command_prompt.strip():
            sdk_prompt += f" {command_prompt.strip()}"
        ctx.push_user_prompt("claude", sdk_prompt, task_note)
        event_callback = ctx.make_event_callback("claude", task_note, 0)
        invoker = ClaudeAgentSDKInvoker(event_callback=event_callback)

        progress_task: asyncio.Task[None] | None = None
        progress_counter = 0

        async def progress_reporter() -> None:
            """Emit a keep-alive progress notification every PROGRESS_REPORT_INTERVAL seconds."""
            nonlocal progress_counter
            try:
                while True:
                    await asyncio.sleep(PROGRESS_REPORT_INTERVAL)
                    progress_counter += 1
                    await ctx.report_progress_safe(
                        progress=progress_counter,
                        message=f"Processing... ({progress_counter * PROGRESS_REPORT_INTERVAL}s)",
                    )
            except (anyio.get_cancelled_exc_class(), asyncio.CancelledError):
                raise
            except Exception as e:
                logger.warning("Progress reporter crashed: %s", e, exc_info=True)

        async def stop_progress_reporter() -> None:
            """Cancel and await the keep-alive task; safe to call repeatedly."""
            nonlocal progress_task
            if not progress_task:
                return
            if not progress_task.done():
                progress_task.cancel()
            try:
                await progress_task
            except (anyio.get_cancelled_exc_class(), asyncio.CancelledError):
                pass
            except Exception as e:
                logger.warning("Progress reporter task failed: %s", e, exc_info=True)
            finally:
                progress_task = None

        async def report_final_progress(message: str) -> None:
            """Stop the keep-alive task, then send the closing progress notification."""
            await stop_progress_reporter()
            # MCP requires progress values to strictly increase. After more than
            # 100 keep-alive ticks, a fixed final value of 100 would go backwards.
            final = max(100, progress_counter + 1)
            await ctx.report_progress_safe(progress=final, total=final, message=message)

        try:
            if ctx.has_progress_token():
                progress_task = asyncio.create_task(progress_reporter())

            result = await invoker.run_slash(
                ClaudeSlashRunParams(
                    workspace=workspace,
                    slash_command=slash_command,
                    prompt=command_prompt,
                    max_turns=max_turns,
                    permission=permission,
                    model=model,
                    task_note=task_note,
                )
            )

            formatter = get_formatter()
            response_data = ResponseData(
                answer=result.agent_messages,
                message=None,
                session_id=result.session_id or "",
                # Thought steps are only surfaced when the run failed.
                thought_steps=result.thought_steps if not result.success else [],
                debug_info=None,
                success=result.success,
                error=result.error,
            )

            # handoff_file is required for this tool, so the output is always persisted (best-effort).
            handoff_written, resolved_handoff = _append_to_handoff_file(
                agent=self.name,
                continuation_id=result.session_id or "",
                task_note=task_note,
                task_index=0,
                status="success" if result.success else "error",
                prompt=sdk_prompt,
                response_markdown=formatter.format_for_file(response_data),
                handoff_file_path=handoff_file,
                workspace=workspace,
            )

            if debug_enabled:
                response_data.debug_info = self._build_debug_info(
                    result, ctx, resolved_handoff, handoff_written
                )

            response = formatter.format(response_data, debug=debug_enabled)
            await report_final_progress("Completed" if result.success else "Failed")
            return [TextContent(type="text", text=response)]
        except anyio.get_cancelled_exc_class():
            logger.info("Tool '%s' cancelled", self.name)
            raise
        except asyncio.CancelledError:
            logger.info("Tool '%s' cancelled via asyncio.CancelledError", self.name)
            raise
        except Exception as e:
            logger.error("Tool '%s' error: %s", self.name, e, exc_info=True)
            await report_final_progress("Failed")
            return format_error_response(str(e))
        finally:
            await stop_progress_reporter()