"""
Tools for the semantic search agent.
Available tools:
- ListDir: List directory contents with file statistics
- RunTerminalCommand: Execute shell commands (ripgrep, grep, find, etc.)
- SubmitAnswer: Submit final search results
"""
from __future__ import annotations
import os
import subprocess
from dataclasses import dataclass
from pathlib import Path
from typing import List, Literal, Optional, Union
from pydantic import BaseModel, Field
# ============================================================================
# Tool Schemas (for LLM structured output)
# ============================================================================
class ListDirInput(BaseModel):
    """Input for ListDir tool."""

    # NOTE: pydantic copies the class docstring and every Field description
    # into the generated JSON schema, so that text is kept verbatim here.

    # Directory to list; absolute, or relative to the repository root.
    path: str = Field(
        ...,
        description="Path to directory to list (relative to repo root or absolute)",
    )

    # Number of levels to descend; the default of 1 lists direct children only.
    max_depth: int = Field(
        1,
        description="Maximum depth to recurse (1 = immediate children only)",
    )
class RunTerminalCommandInput(BaseModel):
    """Input for RunTerminalCommand tool."""

    # NOTE: pydantic copies the class docstring and every Field description
    # into the generated JSON schema, so that text is kept verbatim here.

    # The full command line, exactly as it should be handed to the shell.
    command: str = Field(
        ...,
        description="Shell command to execute. Use ripgrep (rg) for searching code.",
    )

    # Optional directory to run in; None means the repository root.
    working_dir: Optional[str] = Field(
        None,
        description="Working directory for command (default: repo root)",
    )
class SearchResultItem(BaseModel):
    """A single search result item."""

    # NOTE: pydantic copies the class docstring and every Field description
    # into the generated JSON schema, so that text is kept verbatim here.

    # Where the snippet lives, as a path relative to the repository root.
    file_path: str = Field(..., description="Relative path to the file from repo root")

    # The snippet text itself.
    content: str = Field(..., description="The relevant code snippet")

    # Optional line span of the snippet; either bound may be omitted.
    line_start: Optional[int] = Field(None, description="Starting line number of the snippet")
    line_end: Optional[int] = Field(None, description="Ending line number of the snippet")
class SubmitAnswerInput(BaseModel):
    """Input for SubmitAnswer tool - final response."""

    # NOTE: pydantic copies the class docstring and every Field description
    # into the generated JSON schema, so that text is kept verbatim here.

    # The snippets being submitted. The "max 10" in the description is
    # advisory only: no length constraint is declared on this field.
    items: List[SearchResultItem] = Field(
        ...,
        description="List of relevant code snippets found (max 10)",
    )

    # Short justification accompanying the submitted items.
    reasoning: str = Field(
        ...,
        description="Brief explanation of why these results are relevant",
    )
class ToolCall(BaseModel):
    """A tool call from the agent."""

    # NOTE: pydantic copies the class docstring and every Field description
    # into the generated JSON schema, so that text is kept verbatim here.

    # Which of the three tools is being invoked.
    tool: Literal["ListDir", "RunTerminalCommand", "SubmitAnswer"] = Field(
        ...,
        description="Name of the tool to call",
    )

    # Parameters for that tool. Nothing in this model ties the payload type to
    # the value of `tool`; the Union member order is kept as originally declared.
    input: Union[ListDirInput, RunTerminalCommandInput, SubmitAnswerInput] = Field(
        ...,
        description="Input parameters for the tool",
    )
class AgentResponse(BaseModel):
    """Response from the agent containing a tool call."""

    # NOTE: pydantic copies the class docstring and every Field description
    # into the generated JSON schema, so that text is kept verbatim here.

    # Free-text reasoning that precedes the action.
    thought: str = Field(..., description="Agent's reasoning about what to do next")

    # The single tool invocation chosen for this step.
    tool_call: ToolCall = Field(..., description="The tool to call with its parameters")
# ============================================================================
# Tool Implementations
# ============================================================================
@dataclass
class ListDirResult:
    """Result of ListDir tool.

    Attributes:
        path: Directory that was listed, as it should be displayed.
        items: One pre-formatted display line per listed entry.
        stats: Summary counters. The keys read here are 'total_files',
            'total_dirs' and 'by_extension' (extension -> file count);
            any of them may be absent.
        error: Failure message; when set, it is the only thing rendered.
    """

    path: str
    items: List[str]
    stats: dict
    error: Optional[str] = None

    def to_string(self) -> str:
        """Render the listing as plain text.

        Returns:
            "Error: <message>" when `error` is set. Otherwise a header with
            the path and file/dir counts, an optional line with the ten most
            common extensions, and at most the first 50 entries followed by
            a count of how many were left out.
        """
        if self.error:
            return f"Error: {self.error}"

        # Read counters with .get() so a result built with empty or partial
        # stats renders zero counts instead of raising KeyError.
        file_count = self.stats.get('total_files', 0)
        dir_count = self.stats.get('total_dirs', 0)
        lines = [
            f"Directory: {self.path}",
            f"Stats: {file_count} files, {dir_count} dirs",
        ]

        by_extension = self.stats.get('by_extension')
        if by_extension:
            # Most frequent first; the stable sort keeps ties in insertion order.
            busiest = sorted(
                by_extension.items(),
                key=lambda pair: pair[1],
                reverse=True,
            )[:10]
            summary = ", ".join(f"{ext}: {count}" for ext, count in busiest)
            lines.append(f"Extensions: {summary}")

        lines.append("\nContents:")
        # Cap the listing so one huge directory cannot flood the output.
        lines.extend(f" {item}" for item in self.items[:50])
        omitted = len(self.items) - 50
        if omitted > 0:
            lines.append(f" ... and {omitted} more")
        return "\n".join(lines)
@dataclass
class RunCommandResult:
    """Result of RunTerminalCommand tool.

    Attributes:
        command: The command line that was requested.
        stdout: Captured standard output.
        stderr: Captured standard error.
        return_code: Exit status of the command.
        error: Failure message; when set, it is the only thing rendered.
    """

    command: str
    stdout: str
    stderr: str
    return_code: int
    error: Optional[str] = None

    def to_string(self) -> str:
        """Render the command outcome as plain text.

        Returns:
            "Error: <message>" when `error` is set. Otherwise the command
            and exit code, followed by stdout (first 100 lines, with a note
            of how many were dropped) and the first 500 characters of
            stderr, each section only when non-empty.
        """
        if self.error:
            return f"Error: {self.error}"

        lines = [f"Command: {self.command}", f"Exit code: {self.return_code}"]

        if self.stdout:
            stdout_lines = self.stdout.split('\n')
            # Output ending in a newline splits into one extra empty element.
            # It is not a real line, so it must not count towards the limit
            # or the "more lines" figure.
            if stdout_lines[-1] == '':
                stdout_lines.pop()
            if len(stdout_lines) > 100:
                shown = '\n'.join(stdout_lines[:100])
                shown += f"\n... ({len(stdout_lines) - 100} more lines)"
            else:
                shown = self.stdout
            lines.append(f"\nOutput:\n{shown}")

        if self.stderr:
            lines.append(f"\nStderr:\n{self.stderr[:500]}")
        return "\n".join(lines)
def execute_list_dir(input: ListDirInput, repo_path: str) -> ListDirResult:
    """Execute ListDir tool.

    Walks the requested directory down to `input.max_depth` levels and
    collects one display line per entry plus file/dir/extension counters.
    Directories are listed before files, each group sorted by
    case-insensitive name. Entries whose name starts with '.' are omitted
    entirely; well-known dependency/build directories are shown with a
    "[skipped]" marker and neither counted nor descended into.

    Args:
        input: Path to list (absolute, or relative to `repo_path`) and the
            maximum depth to descend.
        repo_path: Repository root, used to resolve relative paths and to
            shorten the displayed path.

    Returns:
        A ListDirResult. Failures (missing path, not a directory, or any
        unexpected exception) are reported through its `error` field
        rather than raised.
    """
    # Directories that are shown but never expanded. Dot-directories such as
    # .git need no entry here: hidden names are dropped before this check.
    skipped_dirs = frozenset({'node_modules', '__pycache__', 'venv', 'dist', 'build'})

    def failure(shown_path: str, message: str) -> ListDirResult:
        return ListDirResult(path=shown_path, items=[], stats={}, error=message)

    try:
        # Resolve path
        if os.path.isabs(input.path):
            target_path = Path(input.path)
        else:
            target_path = Path(repo_path) / input.path

        if not target_path.exists():
            return failure(str(target_path), f"Path does not exist: {target_path}")
        if not target_path.is_dir():
            return failure(str(target_path), f"Not a directory: {target_path}")

        items: List[str] = []
        stats = {
            'total_files': 0,
            'total_dirs': 0,
            'by_extension': {},
        }

        def walk(directory: Path, depth: int, indent: str = "") -> None:
            if depth > input.max_depth:
                return
            try:
                entries = sorted(
                    directory.iterdir(),
                    key=lambda p: (not p.is_dir(), p.name.lower()),
                )
            except PermissionError:
                # Unreadable directory: keep its own line, list nothing below it.
                return

            for entry in entries:
                name = entry.name
                if name.startswith('.'):
                    continue

                if entry.is_dir():
                    # The skip list applies to directories only, so a regular
                    # file that happens to be called "build" or "dist" is
                    # listed and counted like any other file.
                    if name in skipped_dirs:
                        items.append(f"{indent}{name}/ [skipped]")
                        continue
                    stats['total_dirs'] += 1
                    items.append(f"{indent}{name}/")
                    if depth < input.max_depth:
                        walk(entry, depth + 1, indent + " ")
                else:
                    stats['total_files'] += 1
                    ext = entry.suffix.lower() or '[no ext]'
                    stats['by_extension'][ext] = stats['by_extension'].get(ext, 0) + 1
                    items.append(f"{indent}{name}")

        walk(target_path, 1)

        # Make path relative for display; fall back to the full path when the
        # target lies outside the repository.
        try:
            display_path = str(target_path.relative_to(repo_path))
        except ValueError:
            display_path = str(target_path)

        return ListDirResult(path=display_path, items=items, stats=stats)
    except Exception as e:
        # Tool boundary: report the failure to the agent instead of raising.
        return failure(input.path, str(e))
# Redirections that cannot create or modify a file. Any other unquoted '>'
# is refused by _split_simple_commands.
_HARMLESS_REDIRECTS = (">/dev/null", "> /dev/null", ">>/dev/null", ">&1", ">&2")


def _split_simple_commands(command: str) -> Optional[List[str]]:
    """Split a shell command line into the simple commands it would run.

    The line is cut at every unquoted ';', '|', '&', '&&' and '||' so that
    each piece can be checked against the allow-list separately. Quoting
    follows POSIX sh rules: single quotes protect everything, a backslash
    protects the next character, double quotes protect everything except
    command substitution.

    Returns:
        The pieces (possibly empty or whitespace-only strings), or None
        when the line uses a construct that cannot be vetted this way:
        a newline, backticks or $(...), $'...' quoting, unquoted
        parentheses, '<', a '>' redirection other than the harmless ones
        listed in _HARMLESS_REDIRECTS, or an unterminated quote.
    """
    # A newline starts a new command, and combined with '#' comments it lets
    # the shell and this scanner disagree about what is quoted. Refuse it
    # outright rather than model comment syntax.
    if "\n" in command:
        return None

    segments: List[str] = []
    current: List[str] = []
    quote = ""  # the quote character we are inside of, or "" when unquoted
    pos, end = 0, len(command)

    while pos < end:
        ch = command[pos]

        if quote == "'":
            # Nothing is special inside single quotes except the closing one.
            if ch == "'":
                quote = ""
            current.append(ch)
            pos += 1
            continue

        if ch == "\\":
            # Escaped character: copy it through without interpreting it.
            current.append(command[pos:pos + 2])
            pos += 2
            continue

        # Command substitution runs arbitrary programs, even in double quotes.
        if ch == "`" or command.startswith("$(", pos):
            return None

        if quote == '"':
            if ch == '"':
                quote = ""
            current.append(ch)
            pos += 1
            continue

        # ---- unquoted from here on ----
        if ch in "'\"":
            # $'...' has different escape rules than '...'; refuse it rather
            # than risk disagreeing with the shell about where it ends.
            if current and current[-1] == "$" and ch == "'":
                return None
            quote = ch
            current.append(ch)
            pos += 1
            continue

        if ch in "()<":
            return None

        if ch == ">":
            safe = next(
                (r for r in _HARMLESS_REDIRECTS if command.startswith(r, pos)),
                None,
            )
            if safe is None:
                return None
            after = pos + len(safe)
            # The redirect target must end here; ">/dev/nullx" is a real file.
            if after < end and command[after] not in " \t;|&":
                return None
            current.append(safe)
            pos = after
            continue

        if ch in ";|&":
            segments.append("".join(current))
            current = []
            doubled = ch in "|&" and command.startswith(ch, pos + 1)
            pos += 2 if doubled else 1
            continue

        current.append(ch)
        pos += 1

    if quote:
        return None
    segments.append("".join(current))
    return segments


def execute_run_command(
    input: RunTerminalCommandInput,
    repo_path: str,
    allowed_commands: Optional[List[str]] = None,
    timeout: float = 30.0,
) -> RunCommandResult:
    """
    Execute RunTerminalCommand tool.

    Security: the command is run through the shell, so every simple command
    in the line (each stage of a pipeline, each side of ';', '&&', '||')
    must start with an allowed prefix, compared case-insensitively. Lines
    containing newlines, command substitution, subshells, or redirection
    to/from files are refused. A bare command name with no arguments (e.g.
    "ls") is accepted when "<name> " is an allowed prefix.

    NOTE(review): this vets shell syntax only, and assumes a POSIX-style
    shell. Options of the allowed programs that themselves execute or
    delete things (for example find's -exec / -delete) are not filtered;
    treat this as a guard rail, not a sandbox.

    Args:
        input: Command input
        repo_path: Repository root path
        allowed_commands: List of allowed command prefixes
        timeout: Command timeout in seconds (default: 30)

    Returns:
        A RunCommandResult. Rejections, timeouts and unexpected exceptions
        are reported through its `error` field with return_code -1.
    """
    def failure(message: str) -> RunCommandResult:
        return RunCommandResult(
            command=input.command,
            stdout="",
            stderr="",
            return_code=-1,
            error=message,
        )

    # Default allowed command prefixes
    if allowed_commands is None:
        allowed_commands = [
            'rg ', 'rg\t',  # ripgrep
            'grep ', 'grep\t',
            'find ',
            'head ', 'tail ',
            'cat ',
            'wc ',
            'ls ',
            'file ',
        ]

    # Security check
    segments = _split_simple_commands(input.command)
    if segments is None:
        return failure(
            "Command not allowed: newlines, command substitution, subshells, and "
            "redirection to or from files are not permitted. Pipes and ;/&&/|| "
            "between allowed commands are fine, as are 2>/dev/null and 2>&1."
        )

    simple_commands = [seg.strip().lower() for seg in segments if seg.strip()]
    # Appending a space lets a bare "ls" match the prefix "ls " while leaving
    # every command that matched before still matching.
    is_allowed = bool(simple_commands) and all(
        any((simple + " ").startswith(prefix) for prefix in allowed_commands)
        for simple in simple_commands
    )
    if not is_allowed:
        return failure(
            f"Command not allowed. Allowed prefixes: {', '.join(allowed_commands)}"
        )

    # Resolve working directory
    if not input.working_dir:
        cwd = repo_path
    elif os.path.isabs(input.working_dir):
        cwd = input.working_dir
    else:
        cwd = str(Path(repo_path) / input.working_dir)

    try:
        result = subprocess.run(
            input.command,
            shell=True,  # pipelines are supported; the line was vetted above
            capture_output=True,
            text=True,
            timeout=timeout,
            cwd=cwd,
        )
    except subprocess.TimeoutExpired:
        return failure(f"Command timed out after {timeout:.0f} seconds")
    except Exception as e:
        # Tool boundary: e.g. a missing working directory is reported, not raised.
        return failure(str(e))

    return RunCommandResult(
        command=input.command,
        stdout=result.stdout,
        stderr=result.stderr,
        return_code=result.returncode,
    )