Compare commits
2 Commits
0cf0d01657
...
d845fa45a3
| Author | SHA1 | Date | |
|---|---|---|---|
| d845fa45a3 | |||
| f60c47a85f |
@@ -41,6 +41,13 @@ class SessionContext:
|
|||||||
self._message_count += 1
|
self._message_count += 1
|
||||||
return message
|
return message
|
||||||
|
|
||||||
|
def pop_last_message(self) -> Message | None:
|
||||||
|
"""Remove and return the last message, or None if history is empty."""
|
||||||
|
if self._history:
|
||||||
|
self._message_count -= 1
|
||||||
|
return self._history.pop()
|
||||||
|
return None
|
||||||
|
|
||||||
def get_history(self) -> list[Message]:
|
def get_history(self) -> list[Message]:
|
||||||
"""Return a shallow copy of the conversation history."""
|
"""Return a shallow copy of the conversation history."""
|
||||||
return list(self._history)
|
return list(self._history)
|
||||||
|
|||||||
@@ -23,6 +23,8 @@ from app.utils.logging import get_logger
|
|||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
_MAX_REASONING_RETRIES = 2
|
||||||
|
|
||||||
|
|
||||||
class AgentLoop:
|
class AgentLoop:
|
||||||
"""ReAct-style agent loop that streams LLM responses and executes tool calls.
|
"""ReAct-style agent loop that streams LLM responses and executes tool calls.
|
||||||
@@ -81,6 +83,7 @@ class AgentLoop:
|
|||||||
self._ctx.add_message("user", user_input)
|
self._ctx.add_message("user", user_input)
|
||||||
|
|
||||||
max_iter = self._config.agent.max_iterations
|
max_iter = self._config.agent.max_iterations
|
||||||
|
reasoning_only_streak = 0
|
||||||
for iteration in range(1, max_iter + 1):
|
for iteration in range(1, max_iter + 1):
|
||||||
# Check token budget
|
# Check token budget
|
||||||
if self._ctx.token_counter.is_over_budget():
|
if self._ctx.token_counter.is_over_budget():
|
||||||
@@ -102,6 +105,9 @@ class AgentLoop:
|
|||||||
tool_calls=assistant_msg.tool_calls,
|
tool_calls=assistant_msg.tool_calls,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Detect reasoning-only response (model thought but produced nothing)
|
||||||
|
reasoning_only = self._handler.had_reasoning_only
|
||||||
|
|
||||||
# Record token usage
|
# Record token usage
|
||||||
if self._handler.usage:
|
if self._handler.usage:
|
||||||
self._ctx.token_counter.count_usage(self._handler.usage)
|
self._ctx.token_counter.count_usage(self._handler.usage)
|
||||||
@@ -114,6 +120,29 @@ class AgentLoop:
|
|||||||
|
|
||||||
self._handler.reset()
|
self._handler.reset()
|
||||||
|
|
||||||
|
# Reasoning-only: model produced thinking tokens but no content or tool calls.
|
||||||
|
if reasoning_only:
|
||||||
|
reasoning_only_streak += 1
|
||||||
|
self._ctx.pop_last_message()
|
||||||
|
|
||||||
|
if reasoning_only_streak >= _MAX_REASONING_RETRIES:
|
||||||
|
# Nudge the model by injecting a user hint
|
||||||
|
print_warning(
|
||||||
|
f"Model produced reasoning but no response {reasoning_only_streak} times. "
|
||||||
|
"Nudging model to respond..."
|
||||||
|
)
|
||||||
|
self._ctx.add_message(
|
||||||
|
"user",
|
||||||
|
"Please respond with your answer. Do not just think — provide your actual response.",
|
||||||
|
)
|
||||||
|
reasoning_only_streak = 0
|
||||||
|
else:
|
||||||
|
print_warning("Model produced reasoning but no response. Retrying...")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Successful response — reset streak
|
||||||
|
reasoning_only_streak = 0
|
||||||
|
|
||||||
# No tool calls → task complete (plain text response)
|
# No tool calls → task complete (plain text response)
|
||||||
if not assistant_msg.tool_calls:
|
if not assistant_msg.tool_calls:
|
||||||
break
|
break
|
||||||
|
|||||||
@@ -147,6 +147,11 @@ class StreamHandler:
|
|||||||
"""Token usage reported by the API, if available."""
|
"""Token usage reported by the API, if available."""
|
||||||
return self._usage
|
return self._usage
|
||||||
|
|
||||||
|
@property
|
||||||
|
def had_reasoning_only(self) -> bool:
|
||||||
|
"""True if the model produced reasoning tokens but no content or tool calls."""
|
||||||
|
return bool(self._accumulated_reasoning) and not self._accumulated_content and not self._tool_calls
|
||||||
|
|
||||||
def reset(self) -> None:
|
def reset(self) -> None:
|
||||||
"""Clear all accumulators for the next turn."""
|
"""Clear all accumulators for the next turn."""
|
||||||
self._accumulated_content = ""
|
self._accumulated_content = ""
|
||||||
|
|||||||
208
app/tools/edit.py
Normal file
208
app/tools/edit.py
Normal file
@@ -0,0 +1,208 @@
|
|||||||
|
"""Edit tools: str_replace and patch_apply."""
|
||||||
|
|
||||||
|
import subprocess
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from app.models.tool_call import ToolResult, ToolResultStatus
|
||||||
|
from app.tools.base import BaseTool
|
||||||
|
from app.utils.file_helpers import (
|
||||||
|
FileSizeError,
|
||||||
|
PathSecurityError,
|
||||||
|
resolve_safe_path,
|
||||||
|
safe_read_file,
|
||||||
|
safe_write_file,
|
||||||
|
)
|
||||||
|
|
||||||
|
_PATCH_TIMEOUT = 30
|
||||||
|
|
||||||
|
|
||||||
|
class StrReplaceParams(BaseModel):
|
||||||
|
"""Parameters for the str_replace tool."""
|
||||||
|
|
||||||
|
file_path: str = Field(description="Path to the file to edit (relative to workspace root)")
|
||||||
|
old_str: str = Field(description="The exact string to find and replace (must be unique in file)")
|
||||||
|
new_str: str = Field(description="The replacement string")
|
||||||
|
|
||||||
|
|
||||||
|
class StrReplaceTool(BaseTool):
|
||||||
|
"""Replace a unique string occurrence in a file."""
|
||||||
|
|
||||||
|
name = "str_replace"
|
||||||
|
description = (
|
||||||
|
"Replace exactly one occurrence of old_str with new_str in a file. "
|
||||||
|
"Fails if old_str is not found or appears more than once."
|
||||||
|
)
|
||||||
|
params_model = StrReplaceParams
|
||||||
|
|
||||||
|
def execute(
|
||||||
|
self, *, tool_call_id: str, file_path: str, old_str: str, new_str: str, **kwargs: Any
|
||||||
|
) -> ToolResult:
|
||||||
|
fs_config = self.config.tools.filesystem
|
||||||
|
|
||||||
|
# Read the file
|
||||||
|
try:
|
||||||
|
content = safe_read_file(
|
||||||
|
file_path,
|
||||||
|
self.workspace_root,
|
||||||
|
max_size_bytes=fs_config.max_file_size_bytes,
|
||||||
|
check_binary=fs_config.binary_detection,
|
||||||
|
)
|
||||||
|
except PathSecurityError as exc:
|
||||||
|
return ToolResult(
|
||||||
|
tool_call_id=tool_call_id,
|
||||||
|
tool_name=self.name,
|
||||||
|
status=ToolResultStatus.ERROR,
|
||||||
|
error=str(exc),
|
||||||
|
)
|
||||||
|
except FileNotFoundError as exc:
|
||||||
|
return ToolResult(
|
||||||
|
tool_call_id=tool_call_id,
|
||||||
|
tool_name=self.name,
|
||||||
|
status=ToolResultStatus.ERROR,
|
||||||
|
error=str(exc),
|
||||||
|
)
|
||||||
|
except FileSizeError as exc:
|
||||||
|
return ToolResult(
|
||||||
|
tool_call_id=tool_call_id,
|
||||||
|
tool_name=self.name,
|
||||||
|
status=ToolResultStatus.ERROR,
|
||||||
|
error=str(exc),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Count occurrences
|
||||||
|
count = content.count(old_str)
|
||||||
|
if count == 0:
|
||||||
|
return ToolResult(
|
||||||
|
tool_call_id=tool_call_id,
|
||||||
|
tool_name=self.name,
|
||||||
|
status=ToolResultStatus.ERROR,
|
||||||
|
error=f"old_str not found in {file_path}",
|
||||||
|
)
|
||||||
|
if count > 1:
|
||||||
|
return ToolResult(
|
||||||
|
tool_call_id=tool_call_id,
|
||||||
|
tool_name=self.name,
|
||||||
|
status=ToolResultStatus.ERROR,
|
||||||
|
error=f"old_str appears {count} times in {file_path} (must be unique)",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Perform replacement and write back
|
||||||
|
new_content = content.replace(old_str, new_str, 1)
|
||||||
|
try:
|
||||||
|
safe_write_file(
|
||||||
|
file_path,
|
||||||
|
new_content,
|
||||||
|
self.workspace_root,
|
||||||
|
max_size_bytes=fs_config.max_file_size_bytes,
|
||||||
|
)
|
||||||
|
except PathSecurityError as exc:
|
||||||
|
return ToolResult(
|
||||||
|
tool_call_id=tool_call_id,
|
||||||
|
tool_name=self.name,
|
||||||
|
status=ToolResultStatus.ERROR,
|
||||||
|
error=str(exc),
|
||||||
|
)
|
||||||
|
except FileSizeError as exc:
|
||||||
|
return ToolResult(
|
||||||
|
tool_call_id=tool_call_id,
|
||||||
|
tool_name=self.name,
|
||||||
|
status=ToolResultStatus.ERROR,
|
||||||
|
error=str(exc),
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
safe_path = resolve_safe_path(file_path, self.workspace_root)
|
||||||
|
rel_path = safe_path.relative_to(self.workspace_root)
|
||||||
|
except (PathSecurityError, ValueError):
|
||||||
|
rel_path = Path(file_path)
|
||||||
|
|
||||||
|
return ToolResult(
|
||||||
|
tool_call_id=tool_call_id,
|
||||||
|
tool_name=self.name,
|
||||||
|
status=ToolResultStatus.SUCCESS,
|
||||||
|
output=f"Replaced 1 occurrence in {rel_path}",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class PatchApplyParams(BaseModel):
|
||||||
|
"""Parameters for the patch_apply tool."""
|
||||||
|
|
||||||
|
file_path: str = Field(description="Path to the file to patch (relative to workspace root)")
|
||||||
|
patch: str = Field(description="Unified diff format patch to apply")
|
||||||
|
|
||||||
|
|
||||||
|
class PatchApplyTool(BaseTool):
|
||||||
|
"""Apply a unified diff patch to a file."""
|
||||||
|
|
||||||
|
name = "patch_apply"
|
||||||
|
description = (
|
||||||
|
"Apply a unified diff (patch) to a file. The patch must be in standard "
|
||||||
|
"unified diff format."
|
||||||
|
)
|
||||||
|
params_model = PatchApplyParams
|
||||||
|
|
||||||
|
def execute(self, *, tool_call_id: str, file_path: str, patch: str, **kwargs: Any) -> ToolResult:
|
||||||
|
try:
|
||||||
|
safe_path = resolve_safe_path(file_path, self.workspace_root)
|
||||||
|
except PathSecurityError as exc:
|
||||||
|
return ToolResult(
|
||||||
|
tool_call_id=tool_call_id,
|
||||||
|
tool_name=self.name,
|
||||||
|
status=ToolResultStatus.ERROR,
|
||||||
|
error=str(exc),
|
||||||
|
)
|
||||||
|
|
||||||
|
if not safe_path.exists():
|
||||||
|
return ToolResult(
|
||||||
|
tool_call_id=tool_call_id,
|
||||||
|
tool_name=self.name,
|
||||||
|
status=ToolResultStatus.ERROR,
|
||||||
|
error=f"File not found: {safe_path}",
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = subprocess.run(
|
||||||
|
["patch", "--forward", "--no-backup-if-mismatch", str(safe_path)],
|
||||||
|
input=patch,
|
||||||
|
capture_output=True,
|
||||||
|
text=True,
|
||||||
|
timeout=_PATCH_TIMEOUT,
|
||||||
|
cwd=self.workspace_root,
|
||||||
|
)
|
||||||
|
except subprocess.TimeoutExpired:
|
||||||
|
return ToolResult(
|
||||||
|
tool_call_id=tool_call_id,
|
||||||
|
tool_name=self.name,
|
||||||
|
status=ToolResultStatus.ERROR,
|
||||||
|
error=f"Patch timed out after {_PATCH_TIMEOUT}s",
|
||||||
|
)
|
||||||
|
except FileNotFoundError:
|
||||||
|
return ToolResult(
|
||||||
|
tool_call_id=tool_call_id,
|
||||||
|
tool_name=self.name,
|
||||||
|
status=ToolResultStatus.ERROR,
|
||||||
|
error="'patch' command not found on system",
|
||||||
|
)
|
||||||
|
|
||||||
|
if result.returncode != 0:
|
||||||
|
return ToolResult(
|
||||||
|
tool_call_id=tool_call_id,
|
||||||
|
tool_name=self.name,
|
||||||
|
status=ToolResultStatus.ERROR,
|
||||||
|
error=f"Patch failed (exit {result.returncode}): {result.stderr or result.stdout}",
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
rel_path = safe_path.relative_to(self.workspace_root)
|
||||||
|
except ValueError:
|
||||||
|
rel_path = safe_path
|
||||||
|
|
||||||
|
return ToolResult(
|
||||||
|
tool_call_id=tool_call_id,
|
||||||
|
tool_name=self.name,
|
||||||
|
status=ToolResultStatus.SUCCESS,
|
||||||
|
output=f"Patch applied to {rel_path}",
|
||||||
|
)
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
"""Filesystem tools: read_file and list_dir."""
|
"""Filesystem tools: read_file, list_dir, write_file, make_dir, delete_file."""
|
||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any
|
||||||
@@ -13,6 +13,7 @@ from app.utils.file_helpers import (
|
|||||||
PathSecurityError,
|
PathSecurityError,
|
||||||
resolve_safe_path,
|
resolve_safe_path,
|
||||||
safe_read_file,
|
safe_read_file,
|
||||||
|
safe_write_file,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -147,3 +148,175 @@ class ListDirTool(BaseTool):
|
|||||||
status=ToolResultStatus.SUCCESS,
|
status=ToolResultStatus.SUCCESS,
|
||||||
output="\n".join(lines),
|
output="\n".join(lines),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class WriteFileParams(BaseModel):
|
||||||
|
"""Parameters for the write_file tool."""
|
||||||
|
|
||||||
|
file_path: str = Field(description="Path to the file to write (relative to workspace root)")
|
||||||
|
content: str = Field(description="Content to write to the file")
|
||||||
|
|
||||||
|
|
||||||
|
class WriteFileTool(BaseTool):
|
||||||
|
"""Write content to a file within the workspace."""
|
||||||
|
|
||||||
|
name = "write_file"
|
||||||
|
description = (
|
||||||
|
"Write text content to a file. Creates parent directories if needed. "
|
||||||
|
"Overwrites existing file content."
|
||||||
|
)
|
||||||
|
params_model = WriteFileParams
|
||||||
|
|
||||||
|
def execute(self, *, tool_call_id: str, file_path: str, content: str, **kwargs: Any) -> ToolResult:
|
||||||
|
fs_config = self.config.tools.filesystem
|
||||||
|
try:
|
||||||
|
safe_path = safe_write_file(
|
||||||
|
file_path,
|
||||||
|
content,
|
||||||
|
self.workspace_root,
|
||||||
|
max_size_bytes=fs_config.max_file_size_bytes,
|
||||||
|
)
|
||||||
|
except PathSecurityError as exc:
|
||||||
|
return ToolResult(
|
||||||
|
tool_call_id=tool_call_id,
|
||||||
|
tool_name=self.name,
|
||||||
|
status=ToolResultStatus.ERROR,
|
||||||
|
error=str(exc),
|
||||||
|
)
|
||||||
|
except FileSizeError as exc:
|
||||||
|
return ToolResult(
|
||||||
|
tool_call_id=tool_call_id,
|
||||||
|
tool_name=self.name,
|
||||||
|
status=ToolResultStatus.ERROR,
|
||||||
|
error=str(exc),
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
rel_path = safe_path.relative_to(self.workspace_root)
|
||||||
|
except ValueError:
|
||||||
|
rel_path = safe_path
|
||||||
|
|
||||||
|
return ToolResult(
|
||||||
|
tool_call_id=tool_call_id,
|
||||||
|
tool_name=self.name,
|
||||||
|
status=ToolResultStatus.SUCCESS,
|
||||||
|
output=f"Wrote {len(content)} characters to {rel_path}",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class MakeDirParams(BaseModel):
|
||||||
|
"""Parameters for the make_dir tool."""
|
||||||
|
|
||||||
|
directory_path: str = Field(description="Path to the directory to create (relative to workspace root)")
|
||||||
|
|
||||||
|
|
||||||
|
class MakeDirTool(BaseTool):
|
||||||
|
"""Create a directory (and any missing parents) within the workspace."""
|
||||||
|
|
||||||
|
name = "make_dir"
|
||||||
|
description = "Create a directory and any necessary parent directories."
|
||||||
|
params_model = MakeDirParams
|
||||||
|
|
||||||
|
def execute(self, *, tool_call_id: str, directory_path: str, **kwargs: Any) -> ToolResult:
|
||||||
|
try:
|
||||||
|
safe_path = resolve_safe_path(directory_path, self.workspace_root)
|
||||||
|
except PathSecurityError as exc:
|
||||||
|
return ToolResult(
|
||||||
|
tool_call_id=tool_call_id,
|
||||||
|
tool_name=self.name,
|
||||||
|
status=ToolResultStatus.ERROR,
|
||||||
|
error=str(exc),
|
||||||
|
)
|
||||||
|
|
||||||
|
if safe_path.exists() and not safe_path.is_dir():
|
||||||
|
return ToolResult(
|
||||||
|
tool_call_id=tool_call_id,
|
||||||
|
tool_name=self.name,
|
||||||
|
status=ToolResultStatus.ERROR,
|
||||||
|
error=f"Path exists and is not a directory: {safe_path}",
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
safe_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
except OSError as exc:
|
||||||
|
return ToolResult(
|
||||||
|
tool_call_id=tool_call_id,
|
||||||
|
tool_name=self.name,
|
||||||
|
status=ToolResultStatus.ERROR,
|
||||||
|
error=f"Failed to create directory: {exc}",
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
rel_path = safe_path.relative_to(self.workspace_root)
|
||||||
|
except ValueError:
|
||||||
|
rel_path = safe_path
|
||||||
|
|
||||||
|
return ToolResult(
|
||||||
|
tool_call_id=tool_call_id,
|
||||||
|
tool_name=self.name,
|
||||||
|
status=ToolResultStatus.SUCCESS,
|
||||||
|
output=f"Created directory: {rel_path}",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class DeleteFileParams(BaseModel):
|
||||||
|
"""Parameters for the delete_file tool."""
|
||||||
|
|
||||||
|
file_path: str = Field(description="Path to the file to delete (relative to workspace root)")
|
||||||
|
|
||||||
|
|
||||||
|
class DeleteFileTool(BaseTool):
|
||||||
|
"""Delete a file within the workspace. Refuses to delete directories."""
|
||||||
|
|
||||||
|
name = "delete_file"
|
||||||
|
description = "Delete a single file. Does not delete directories."
|
||||||
|
params_model = DeleteFileParams
|
||||||
|
|
||||||
|
def execute(self, *, tool_call_id: str, file_path: str, **kwargs: Any) -> ToolResult:
|
||||||
|
try:
|
||||||
|
safe_path = resolve_safe_path(file_path, self.workspace_root)
|
||||||
|
except PathSecurityError as exc:
|
||||||
|
return ToolResult(
|
||||||
|
tool_call_id=tool_call_id,
|
||||||
|
tool_name=self.name,
|
||||||
|
status=ToolResultStatus.ERROR,
|
||||||
|
error=str(exc),
|
||||||
|
)
|
||||||
|
|
||||||
|
if safe_path.is_dir():
|
||||||
|
return ToolResult(
|
||||||
|
tool_call_id=tool_call_id,
|
||||||
|
tool_name=self.name,
|
||||||
|
status=ToolResultStatus.ERROR,
|
||||||
|
error=f"Path is a directory, not a file: {safe_path}",
|
||||||
|
)
|
||||||
|
|
||||||
|
if not safe_path.exists():
|
||||||
|
return ToolResult(
|
||||||
|
tool_call_id=tool_call_id,
|
||||||
|
tool_name=self.name,
|
||||||
|
status=ToolResultStatus.ERROR,
|
||||||
|
error=f"File not found: {safe_path}",
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
safe_path.unlink()
|
||||||
|
except OSError as exc:
|
||||||
|
return ToolResult(
|
||||||
|
tool_call_id=tool_call_id,
|
||||||
|
tool_name=self.name,
|
||||||
|
status=ToolResultStatus.ERROR,
|
||||||
|
error=f"Failed to delete file: {exc}",
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
rel_path = safe_path.relative_to(self.workspace_root)
|
||||||
|
except ValueError:
|
||||||
|
rel_path = safe_path
|
||||||
|
|
||||||
|
return ToolResult(
|
||||||
|
tool_call_id=tool_call_id,
|
||||||
|
tool_name=self.name,
|
||||||
|
status=ToolResultStatus.SUCCESS,
|
||||||
|
output=f"Deleted: {rel_path}",
|
||||||
|
)
|
||||||
|
|||||||
@@ -38,14 +38,47 @@ class ToolRegistry:
|
|||||||
|
|
||||||
def create_default_registry(workspace_root: Path, config: AppConfig) -> ToolRegistry:
|
def create_default_registry(workspace_root: Path, config: AppConfig) -> ToolRegistry:
|
||||||
"""Create a ToolRegistry populated with all built-in tools."""
|
"""Create a ToolRegistry populated with all built-in tools."""
|
||||||
|
# Read tools
|
||||||
from app.tools.filesystem import ListDirTool, ReadFileTool
|
from app.tools.filesystem import ListDirTool, ReadFileTool
|
||||||
|
|
||||||
|
# Write tools
|
||||||
|
from app.tools.filesystem import DeleteFileTool, MakeDirTool, WriteFileTool
|
||||||
|
|
||||||
|
# Edit tools
|
||||||
|
from app.tools.edit import PatchApplyTool, StrReplaceTool
|
||||||
|
|
||||||
|
# Shell tools
|
||||||
|
from app.tools.shell import RunCommandTool
|
||||||
|
|
||||||
|
# Control flow
|
||||||
from app.tools.finish import FinishTool
|
from app.tools.finish import FinishTool
|
||||||
|
|
||||||
|
# Search tools
|
||||||
from app.tools.search import FindFilesTool, GrepFilesTool
|
from app.tools.search import FindFilesTool, GrepFilesTool
|
||||||
|
|
||||||
registry = ToolRegistry()
|
registry = ToolRegistry()
|
||||||
|
|
||||||
|
# Read
|
||||||
registry.register(ReadFileTool(workspace_root, config))
|
registry.register(ReadFileTool(workspace_root, config))
|
||||||
registry.register(ListDirTool(workspace_root, config))
|
registry.register(ListDirTool(workspace_root, config))
|
||||||
|
|
||||||
|
# Search
|
||||||
registry.register(GrepFilesTool(workspace_root, config))
|
registry.register(GrepFilesTool(workspace_root, config))
|
||||||
registry.register(FindFilesTool(workspace_root, config))
|
registry.register(FindFilesTool(workspace_root, config))
|
||||||
|
|
||||||
|
# Write
|
||||||
|
registry.register(WriteFileTool(workspace_root, config))
|
||||||
|
registry.register(MakeDirTool(workspace_root, config))
|
||||||
|
registry.register(DeleteFileTool(workspace_root, config))
|
||||||
|
|
||||||
|
# Edit
|
||||||
|
registry.register(StrReplaceTool(workspace_root, config))
|
||||||
|
registry.register(PatchApplyTool(workspace_root, config))
|
||||||
|
|
||||||
|
# Shell
|
||||||
|
registry.register(RunCommandTool(workspace_root, config))
|
||||||
|
|
||||||
|
# Control flow
|
||||||
registry.register(FinishTool(workspace_root, config))
|
registry.register(FinishTool(workspace_root, config))
|
||||||
|
|
||||||
return registry
|
return registry
|
||||||
|
|||||||
113
app/tools/shell.py
Normal file
113
app/tools/shell.py
Normal file
@@ -0,0 +1,113 @@
|
|||||||
|
"""Shell tool: run_command."""
|
||||||
|
|
||||||
|
import shlex
|
||||||
|
import subprocess
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from app.models.tool_call import ToolResult, ToolResultStatus
|
||||||
|
from app.tools.base import BaseTool
|
||||||
|
|
||||||
|
_DEFAULT_TIMEOUT = 30
|
||||||
|
|
||||||
|
|
||||||
|
class RunCommandParams(BaseModel):
|
||||||
|
"""Parameters for the run_command tool."""
|
||||||
|
|
||||||
|
command: str = Field(description="Shell command to execute")
|
||||||
|
timeout: int | None = Field(default=None, description="Timeout in seconds (default: 30)")
|
||||||
|
|
||||||
|
|
||||||
|
class RunCommandTool(BaseTool):
|
||||||
|
"""Execute a shell command within the workspace."""
|
||||||
|
|
||||||
|
name = "run_command"
|
||||||
|
description = (
|
||||||
|
"Run a shell command in the workspace directory. "
|
||||||
|
"Only allowed commands may be executed; dangerous commands are blocked."
|
||||||
|
)
|
||||||
|
params_model = RunCommandParams
|
||||||
|
|
||||||
|
def execute(self, *, tool_call_id: str, command: str, timeout: int | None = None, **kwargs: Any) -> ToolResult:
|
||||||
|
shell_config = self.config.tools.shell
|
||||||
|
effective_timeout = timeout if timeout is not None else _DEFAULT_TIMEOUT
|
||||||
|
|
||||||
|
# Deny check: prefix match against full command string
|
||||||
|
for denied in shell_config.denied_commands:
|
||||||
|
if command.startswith(denied):
|
||||||
|
return ToolResult(
|
||||||
|
tool_call_id=tool_call_id,
|
||||||
|
tool_name=self.name,
|
||||||
|
status=ToolResultStatus.ERROR,
|
||||||
|
error=f"Command denied: matches blocked prefix '{denied}'",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Allow check: first token must be in allowed_commands
|
||||||
|
try:
|
||||||
|
tokens = shlex.split(command)
|
||||||
|
except ValueError as exc:
|
||||||
|
return ToolResult(
|
||||||
|
tool_call_id=tool_call_id,
|
||||||
|
tool_name=self.name,
|
||||||
|
status=ToolResultStatus.ERROR,
|
||||||
|
error=f"Failed to parse command: {exc}",
|
||||||
|
)
|
||||||
|
|
||||||
|
if not tokens:
|
||||||
|
return ToolResult(
|
||||||
|
tool_call_id=tool_call_id,
|
||||||
|
tool_name=self.name,
|
||||||
|
status=ToolResultStatus.ERROR,
|
||||||
|
error="Empty command",
|
||||||
|
)
|
||||||
|
|
||||||
|
base_cmd = tokens[0]
|
||||||
|
if shell_config.allowed_commands and base_cmd not in shell_config.allowed_commands:
|
||||||
|
return ToolResult(
|
||||||
|
tool_call_id=tool_call_id,
|
||||||
|
tool_name=self.name,
|
||||||
|
status=ToolResultStatus.ERROR,
|
||||||
|
error=f"Command '{base_cmd}' is not in the allowed commands list",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Execute
|
||||||
|
try:
|
||||||
|
result = subprocess.run(
|
||||||
|
command,
|
||||||
|
shell=True,
|
||||||
|
capture_output=True,
|
||||||
|
text=True,
|
||||||
|
timeout=effective_timeout,
|
||||||
|
cwd=self.workspace_root,
|
||||||
|
)
|
||||||
|
except subprocess.TimeoutExpired:
|
||||||
|
return ToolResult(
|
||||||
|
tool_call_id=tool_call_id,
|
||||||
|
tool_name=self.name,
|
||||||
|
status=ToolResultStatus.ERROR,
|
||||||
|
error=f"Command timed out after {effective_timeout}s",
|
||||||
|
)
|
||||||
|
except OSError as exc:
|
||||||
|
return ToolResult(
|
||||||
|
tool_call_id=tool_call_id,
|
||||||
|
tool_name=self.name,
|
||||||
|
status=ToolResultStatus.ERROR,
|
||||||
|
error=f"Failed to execute command: {exc}",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Combine output and truncate
|
||||||
|
output = result.stdout + result.stderr
|
||||||
|
max_bytes = shell_config.max_output_bytes
|
||||||
|
if len(output.encode("utf-8", errors="replace")) > max_bytes:
|
||||||
|
output = output[:max_bytes] + "\n... (output truncated)"
|
||||||
|
|
||||||
|
if result.returncode != 0:
|
||||||
|
output = f"Exit code: {result.returncode}\n{output}"
|
||||||
|
|
||||||
|
return ToolResult(
|
||||||
|
tool_call_id=tool_call_id,
|
||||||
|
tool_name=self.name,
|
||||||
|
status=ToolResultStatus.SUCCESS,
|
||||||
|
output=output,
|
||||||
|
)
|
||||||
@@ -60,6 +60,7 @@ def client() -> MagicMock:
|
|||||||
def handler() -> MagicMock:
|
def handler() -> MagicMock:
|
||||||
mock = MagicMock(spec=StreamHandler)
|
mock = MagicMock(spec=StreamHandler)
|
||||||
mock.usage = None
|
mock.usage = None
|
||||||
|
mock.had_reasoning_only = False
|
||||||
mock.reset = MagicMock()
|
mock.reset = MagicMock()
|
||||||
return mock
|
return mock
|
||||||
|
|
||||||
|
|||||||
145
tests/unit/test_edit.py
Normal file
145
tests/unit/test_edit.py
Normal file
@@ -0,0 +1,145 @@
|
|||||||
|
"""Tests for edit tools: str_replace and patch_apply (Phase 6)."""
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest.mock import patch as mock_patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.models.config import AppConfig, load_config
|
||||||
|
from app.models.tool_call import ToolResultStatus
|
||||||
|
from app.tools.edit import PatchApplyTool, StrReplaceTool
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def config() -> AppConfig:
|
||||||
|
return load_config()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def tmp_workspace(tmp_path: Path, config: AppConfig) -> tuple[Path, AppConfig]:
|
||||||
|
"""Create a temporary workspace for edit tests."""
|
||||||
|
config.agent.workspace_root = tmp_path
|
||||||
|
return tmp_path, config
|
||||||
|
|
||||||
|
|
||||||
|
# --- StrReplaceTool ---
|
||||||
|
|
||||||
|
|
||||||
|
class TestStrReplaceTool:
|
||||||
|
def test_replace_unique_match(self, tmp_workspace: tuple[Path, AppConfig]) -> None:
|
||||||
|
ws, cfg = tmp_workspace
|
||||||
|
(ws / "test.py").write_text("def hello():\n return 'hello'\n")
|
||||||
|
tool = StrReplaceTool(ws, cfg)
|
||||||
|
result = tool.run(
|
||||||
|
"tc-1",
|
||||||
|
{"file_path": "test.py", "old_str": "return 'hello'", "new_str": "return 'world'"},
|
||||||
|
)
|
||||||
|
assert result.status == ToolResultStatus.SUCCESS
|
||||||
|
assert "1 occurrence" in result.output
|
||||||
|
assert (ws / "test.py").read_text() == "def hello():\n return 'world'\n"
|
||||||
|
|
||||||
|
def test_replace_no_match(self, tmp_workspace: tuple[Path, AppConfig]) -> None:
|
||||||
|
ws, cfg = tmp_workspace
|
||||||
|
(ws / "test.py").write_text("some content")
|
||||||
|
tool = StrReplaceTool(ws, cfg)
|
||||||
|
result = tool.run(
|
||||||
|
"tc-2",
|
||||||
|
{"file_path": "test.py", "old_str": "nonexistent", "new_str": "replacement"},
|
||||||
|
)
|
||||||
|
assert result.status == ToolResultStatus.ERROR
|
||||||
|
assert "not found" in (result.error or "").lower()
|
||||||
|
|
||||||
|
def test_replace_multiple_matches_fails(self, tmp_workspace: tuple[Path, AppConfig]) -> None:
|
||||||
|
ws, cfg = tmp_workspace
|
||||||
|
(ws / "test.py").write_text("foo bar foo baz foo")
|
||||||
|
tool = StrReplaceTool(ws, cfg)
|
||||||
|
result = tool.run(
|
||||||
|
"tc-3",
|
||||||
|
{"file_path": "test.py", "old_str": "foo", "new_str": "qux"},
|
||||||
|
)
|
||||||
|
assert result.status == ToolResultStatus.ERROR
|
||||||
|
assert "3 times" in (result.error or "")
|
||||||
|
|
||||||
|
def test_replace_nonexistent_file(self, tmp_workspace: tuple[Path, AppConfig]) -> None:
|
||||||
|
ws, cfg = tmp_workspace
|
||||||
|
tool = StrReplaceTool(ws, cfg)
|
||||||
|
result = tool.run(
|
||||||
|
"tc-4",
|
||||||
|
{"file_path": "missing.py", "old_str": "a", "new_str": "b"},
|
||||||
|
)
|
||||||
|
assert result.status == ToolResultStatus.ERROR
|
||||||
|
assert "not found" in (result.error or "").lower()
|
||||||
|
|
||||||
|
def test_replace_path_traversal_blocked(self, tmp_workspace: tuple[Path, AppConfig]) -> None:
|
||||||
|
ws, cfg = tmp_workspace
|
||||||
|
tool = StrReplaceTool(ws, cfg)
|
||||||
|
result = tool.run(
|
||||||
|
"tc-5",
|
||||||
|
{"file_path": "../../etc/passwd", "old_str": "a", "new_str": "b"},
|
||||||
|
)
|
||||||
|
assert result.status == ToolResultStatus.ERROR
|
||||||
|
assert "outside" in (result.error or "").lower()
|
||||||
|
|
||||||
|
|
||||||
|
# --- PatchApplyTool ---
|
||||||
|
|
||||||
|
|
||||||
|
class TestPatchApplyTool:
|
||||||
|
def test_apply_valid_patch(self, tmp_workspace: tuple[Path, AppConfig]) -> None:
|
||||||
|
ws, cfg = tmp_workspace
|
||||||
|
(ws / "target.txt").write_text("line1\nline2\nline3\n")
|
||||||
|
patch_text = (
|
||||||
|
"--- a/target.txt\n"
|
||||||
|
"+++ b/target.txt\n"
|
||||||
|
"@@ -1,3 +1,3 @@\n"
|
||||||
|
" line1\n"
|
||||||
|
"-line2\n"
|
||||||
|
"+line2_modified\n"
|
||||||
|
" line3\n"
|
||||||
|
)
|
||||||
|
tool = PatchApplyTool(ws, cfg)
|
||||||
|
result = tool.run("tc-1", {"file_path": "target.txt", "patch": patch_text})
|
||||||
|
assert result.status == ToolResultStatus.SUCCESS
|
||||||
|
assert "line2_modified" in (ws / "target.txt").read_text()
|
||||||
|
|
||||||
|
def test_apply_bad_patch(self, tmp_workspace: tuple[Path, AppConfig]) -> None:
|
||||||
|
ws, cfg = tmp_workspace
|
||||||
|
(ws / "target.txt").write_text("line1\nline2\n")
|
||||||
|
tool = PatchApplyTool(ws, cfg)
|
||||||
|
result = tool.run(
|
||||||
|
"tc-2",
|
||||||
|
{"file_path": "target.txt", "patch": "this is not a valid patch"},
|
||||||
|
)
|
||||||
|
assert result.status == ToolResultStatus.ERROR
|
||||||
|
|
||||||
|
def test_apply_nonexistent_file(self, tmp_workspace: tuple[Path, AppConfig]) -> None:
|
||||||
|
ws, cfg = tmp_workspace
|
||||||
|
tool = PatchApplyTool(ws, cfg)
|
||||||
|
result = tool.run(
|
||||||
|
"tc-3",
|
||||||
|
{"file_path": "missing.txt", "patch": "--- a\n+++ b\n"},
|
||||||
|
)
|
||||||
|
assert result.status == ToolResultStatus.ERROR
|
||||||
|
assert "not found" in (result.error or "").lower()
|
||||||
|
|
||||||
|
def test_apply_path_traversal_blocked(self, tmp_workspace: tuple[Path, AppConfig]) -> None:
|
||||||
|
ws, cfg = tmp_workspace
|
||||||
|
tool = PatchApplyTool(ws, cfg)
|
||||||
|
result = tool.run(
|
||||||
|
"tc-4",
|
||||||
|
{"file_path": "../../etc/passwd", "patch": "--- a\n+++ b\n"},
|
||||||
|
)
|
||||||
|
assert result.status == ToolResultStatus.ERROR
|
||||||
|
assert "outside" in (result.error or "").lower()
|
||||||
|
|
||||||
|
def test_apply_timeout(self, tmp_workspace: tuple[Path, AppConfig]) -> None:
|
||||||
|
ws, cfg = tmp_workspace
|
||||||
|
(ws / "target.txt").write_text("content\n")
|
||||||
|
tool = PatchApplyTool(ws, cfg)
|
||||||
|
with mock_patch("app.tools.edit.subprocess.run", side_effect=__import__("subprocess").TimeoutExpired("patch", 30)):
|
||||||
|
result = tool.run(
|
||||||
|
"tc-5",
|
||||||
|
{"file_path": "target.txt", "patch": "--- a\n+++ b\n"},
|
||||||
|
)
|
||||||
|
assert result.status == ToolResultStatus.ERROR
|
||||||
|
assert "timed out" in (result.error or "").lower()
|
||||||
138
tests/unit/test_filesystem_write.py
Normal file
138
tests/unit/test_filesystem_write.py
Normal file
@@ -0,0 +1,138 @@
|
|||||||
|
"""Tests for write/mkdir/delete filesystem tools (Phase 6)."""
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.models.config import AppConfig, load_config
|
||||||
|
from app.models.tool_call import ToolResultStatus
|
||||||
|
from app.tools.filesystem import DeleteFileTool, MakeDirTool, WriteFileTool
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def config() -> AppConfig:
|
||||||
|
return load_config()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def workspace(config: AppConfig) -> Path:
|
||||||
|
return config.agent.workspace_root
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def tmp_workspace(tmp_path: Path, config: AppConfig) -> tuple[Path, AppConfig]:
|
||||||
|
"""Create a temporary workspace for write tests."""
|
||||||
|
# Override workspace_root to tmp_path for isolation
|
||||||
|
config.agent.workspace_root = tmp_path
|
||||||
|
return tmp_path, config
|
||||||
|
|
||||||
|
|
||||||
|
# --- WriteFileTool ---
|
||||||
|
|
||||||
|
|
||||||
|
class TestWriteFileTool:
|
||||||
|
def test_write_new_file(self, tmp_workspace: tuple[Path, AppConfig]) -> None:
|
||||||
|
ws, cfg = tmp_workspace
|
||||||
|
tool = WriteFileTool(ws, cfg)
|
||||||
|
result = tool.run("tc-1", {"file_path": "hello.txt", "content": "Hello, world!"})
|
||||||
|
assert result.status == ToolResultStatus.SUCCESS
|
||||||
|
assert "13 characters" in result.output
|
||||||
|
assert (ws / "hello.txt").read_text() == "Hello, world!"
|
||||||
|
|
||||||
|
def test_write_creates_parent_dirs(self, tmp_workspace: tuple[Path, AppConfig]) -> None:
|
||||||
|
ws, cfg = tmp_workspace
|
||||||
|
tool = WriteFileTool(ws, cfg)
|
||||||
|
result = tool.run("tc-2", {"file_path": "a/b/c/deep.txt", "content": "deep"})
|
||||||
|
assert result.status == ToolResultStatus.SUCCESS
|
||||||
|
assert (ws / "a" / "b" / "c" / "deep.txt").read_text() == "deep"
|
||||||
|
|
||||||
|
def test_write_overwrites_existing(self, tmp_workspace: tuple[Path, AppConfig]) -> None:
|
||||||
|
ws, cfg = tmp_workspace
|
||||||
|
(ws / "existing.txt").write_text("old content")
|
||||||
|
tool = WriteFileTool(ws, cfg)
|
||||||
|
result = tool.run("tc-3", {"file_path": "existing.txt", "content": "new content"})
|
||||||
|
assert result.status == ToolResultStatus.SUCCESS
|
||||||
|
assert (ws / "existing.txt").read_text() == "new content"
|
||||||
|
|
||||||
|
def test_write_path_traversal_blocked(self, tmp_workspace: tuple[Path, AppConfig]) -> None:
|
||||||
|
ws, cfg = tmp_workspace
|
||||||
|
tool = WriteFileTool(ws, cfg)
|
||||||
|
result = tool.run("tc-4", {"file_path": "../../etc/evil.txt", "content": "bad"})
|
||||||
|
assert result.status == ToolResultStatus.ERROR
|
||||||
|
assert "outside" in (result.error or "").lower()
|
||||||
|
|
||||||
|
|
||||||
|
# --- MakeDirTool ---
|
||||||
|
|
||||||
|
|
||||||
|
class TestMakeDirTool:
|
||||||
|
def test_make_new_dir(self, tmp_workspace: tuple[Path, AppConfig]) -> None:
|
||||||
|
ws, cfg = tmp_workspace
|
||||||
|
tool = MakeDirTool(ws, cfg)
|
||||||
|
result = tool.run("tc-1", {"directory_path": "newdir"})
|
||||||
|
assert result.status == ToolResultStatus.SUCCESS
|
||||||
|
assert (ws / "newdir").is_dir()
|
||||||
|
|
||||||
|
def test_make_nested_dirs(self, tmp_workspace: tuple[Path, AppConfig]) -> None:
|
||||||
|
ws, cfg = tmp_workspace
|
||||||
|
tool = MakeDirTool(ws, cfg)
|
||||||
|
result = tool.run("tc-2", {"directory_path": "a/b/c"})
|
||||||
|
assert result.status == ToolResultStatus.SUCCESS
|
||||||
|
assert (ws / "a" / "b" / "c").is_dir()
|
||||||
|
|
||||||
|
def test_make_existing_dir_ok(self, tmp_workspace: tuple[Path, AppConfig]) -> None:
|
||||||
|
ws, cfg = tmp_workspace
|
||||||
|
(ws / "existing_dir").mkdir()
|
||||||
|
tool = MakeDirTool(ws, cfg)
|
||||||
|
result = tool.run("tc-3", {"directory_path": "existing_dir"})
|
||||||
|
assert result.status == ToolResultStatus.SUCCESS
|
||||||
|
|
||||||
|
def test_make_dir_over_file_fails(self, tmp_workspace: tuple[Path, AppConfig]) -> None:
|
||||||
|
ws, cfg = tmp_workspace
|
||||||
|
(ws / "afile").write_text("content")
|
||||||
|
tool = MakeDirTool(ws, cfg)
|
||||||
|
result = tool.run("tc-4", {"directory_path": "afile"})
|
||||||
|
assert result.status == ToolResultStatus.ERROR
|
||||||
|
assert "not a directory" in (result.error or "").lower()
|
||||||
|
|
||||||
|
def test_make_dir_path_traversal_blocked(self, tmp_workspace: tuple[Path, AppConfig]) -> None:
|
||||||
|
ws, cfg = tmp_workspace
|
||||||
|
tool = MakeDirTool(ws, cfg)
|
||||||
|
result = tool.run("tc-5", {"directory_path": "../../outside"})
|
||||||
|
assert result.status == ToolResultStatus.ERROR
|
||||||
|
assert "outside" in (result.error or "").lower()
|
||||||
|
|
||||||
|
|
||||||
|
# --- DeleteFileTool ---
|
||||||
|
|
||||||
|
|
||||||
|
class TestDeleteFileTool:
|
||||||
|
def test_delete_existing_file(self, tmp_workspace: tuple[Path, AppConfig]) -> None:
|
||||||
|
ws, cfg = tmp_workspace
|
||||||
|
(ws / "doomed.txt").write_text("bye")
|
||||||
|
tool = DeleteFileTool(ws, cfg)
|
||||||
|
result = tool.run("tc-1", {"file_path": "doomed.txt"})
|
||||||
|
assert result.status == ToolResultStatus.SUCCESS
|
||||||
|
assert not (ws / "doomed.txt").exists()
|
||||||
|
|
||||||
|
def test_delete_nonexistent_file(self, tmp_workspace: tuple[Path, AppConfig]) -> None:
|
||||||
|
ws, cfg = tmp_workspace
|
||||||
|
tool = DeleteFileTool(ws, cfg)
|
||||||
|
result = tool.run("tc-2", {"file_path": "ghost.txt"})
|
||||||
|
assert result.status == ToolResultStatus.ERROR
|
||||||
|
assert "not found" in (result.error or "").lower()
|
||||||
|
|
||||||
|
def test_delete_directory_refused(self, tmp_workspace: tuple[Path, AppConfig]) -> None:
|
||||||
|
ws, cfg = tmp_workspace
|
||||||
|
(ws / "adir").mkdir()
|
||||||
|
tool = DeleteFileTool(ws, cfg)
|
||||||
|
result = tool.run("tc-3", {"file_path": "adir"})
|
||||||
|
assert result.status == ToolResultStatus.ERROR
|
||||||
|
assert "directory" in (result.error or "").lower()
|
||||||
|
|
||||||
|
def test_delete_path_traversal_blocked(self, tmp_workspace: tuple[Path, AppConfig]) -> None:
|
||||||
|
ws, cfg = tmp_workspace
|
||||||
|
tool = DeleteFileTool(ws, cfg)
|
||||||
|
result = tool.run("tc-4", {"file_path": "../../etc/passwd"})
|
||||||
|
assert result.status == ToolResultStatus.ERROR
|
||||||
|
assert "outside" in (result.error or "").lower()
|
||||||
95
tests/unit/test_shell.py
Normal file
95
tests/unit/test_shell.py
Normal file
@@ -0,0 +1,95 @@
|
|||||||
|
"""Tests for the run_command shell tool (Phase 6)."""
|
||||||
|
|
||||||
|
import subprocess
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest.mock import patch as mock_patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.models.config import AppConfig, load_config
|
||||||
|
from app.models.tool_call import ToolResultStatus
|
||||||
|
from app.tools.shell import RunCommandTool
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def config() -> AppConfig:
|
||||||
|
return load_config()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def tmp_workspace(tmp_path: Path, config: AppConfig) -> tuple[Path, AppConfig]:
|
||||||
|
"""Create a temporary workspace for shell tests."""
|
||||||
|
config.agent.workspace_root = tmp_path
|
||||||
|
return tmp_path, config
|
||||||
|
|
||||||
|
|
||||||
|
class TestRunCommandTool:
|
||||||
|
def test_allowed_command_runs(self, tmp_workspace: tuple[Path, AppConfig]) -> None:
|
||||||
|
ws, cfg = tmp_workspace
|
||||||
|
tool = RunCommandTool(ws, cfg)
|
||||||
|
result = tool.run("tc-1", {"command": "echo hello"})
|
||||||
|
assert result.status == ToolResultStatus.SUCCESS
|
||||||
|
assert "hello" in result.output
|
||||||
|
|
||||||
|
def test_denied_command_blocked(self, tmp_workspace: tuple[Path, AppConfig]) -> None:
|
||||||
|
ws, cfg = tmp_workspace
|
||||||
|
tool = RunCommandTool(ws, cfg)
|
||||||
|
result = tool.run("tc-2", {"command": "sudo rm -rf /"})
|
||||||
|
assert result.status == ToolResultStatus.ERROR
|
||||||
|
assert "denied" in (result.error or "").lower()
|
||||||
|
|
||||||
|
def test_disallowed_command_blocked(self, tmp_workspace: tuple[Path, AppConfig]) -> None:
|
||||||
|
ws, cfg = tmp_workspace
|
||||||
|
tool = RunCommandTool(ws, cfg)
|
||||||
|
result = tool.run("tc-3", {"command": "curl http://example.com"})
|
||||||
|
assert result.status == ToolResultStatus.ERROR
|
||||||
|
assert "denied" in (result.error or "").lower()
|
||||||
|
|
||||||
|
def test_unlisted_command_blocked(self, tmp_workspace: tuple[Path, AppConfig]) -> None:
|
||||||
|
ws, cfg = tmp_workspace
|
||||||
|
tool = RunCommandTool(ws, cfg)
|
||||||
|
result = tool.run("tc-4", {"command": "nc -l 1234"})
|
||||||
|
assert result.status == ToolResultStatus.ERROR
|
||||||
|
assert "not in the allowed" in (result.error or "").lower()
|
||||||
|
|
||||||
|
def test_nonzero_exit_code_reported(self, tmp_workspace: tuple[Path, AppConfig]) -> None:
|
||||||
|
ws, cfg = tmp_workspace
|
||||||
|
tool = RunCommandTool(ws, cfg)
|
||||||
|
# ls on a nonexistent path returns non-zero
|
||||||
|
result = tool.run("tc-5", {"command": "ls /nonexistent_path_xyz"})
|
||||||
|
assert result.status == ToolResultStatus.SUCCESS
|
||||||
|
assert "Exit code:" in result.output
|
||||||
|
|
||||||
|
def test_timeout(self, tmp_workspace: tuple[Path, AppConfig]) -> None:
|
||||||
|
ws, cfg = tmp_workspace
|
||||||
|
tool = RunCommandTool(ws, cfg)
|
||||||
|
with mock_patch(
|
||||||
|
"app.tools.shell.subprocess.run",
|
||||||
|
side_effect=subprocess.TimeoutExpired("cmd", 1),
|
||||||
|
):
|
||||||
|
result = tool.run("tc-6", {"command": "echo test", "timeout": 1})
|
||||||
|
assert result.status == ToolResultStatus.ERROR
|
||||||
|
assert "timed out" in (result.error or "").lower()
|
||||||
|
|
||||||
|
def test_empty_command(self, tmp_workspace: tuple[Path, AppConfig]) -> None:
|
||||||
|
ws, cfg = tmp_workspace
|
||||||
|
tool = RunCommandTool(ws, cfg)
|
||||||
|
result = tool.run("tc-7", {"command": ""})
|
||||||
|
assert result.status == ToolResultStatus.ERROR
|
||||||
|
assert "empty" in (result.error or "").lower()
|
||||||
|
|
||||||
|
def test_custom_timeout(self, tmp_workspace: tuple[Path, AppConfig]) -> None:
|
||||||
|
ws, cfg = tmp_workspace
|
||||||
|
tool = RunCommandTool(ws, cfg)
|
||||||
|
result = tool.run("tc-8", {"command": "echo fast", "timeout": 60})
|
||||||
|
assert result.status == ToolResultStatus.SUCCESS
|
||||||
|
assert "fast" in result.output
|
||||||
|
|
||||||
|
def test_runs_in_workspace_dir(self, tmp_workspace: tuple[Path, AppConfig]) -> None:
|
||||||
|
ws, cfg = tmp_workspace
|
||||||
|
# Create a file in the workspace to verify cwd
|
||||||
|
(ws / "marker.txt").write_text("found")
|
||||||
|
tool = RunCommandTool(ws, cfg)
|
||||||
|
result = tool.run("tc-9", {"command": "cat marker.txt"})
|
||||||
|
assert result.status == ToolResultStatus.SUCCESS
|
||||||
|
assert "found" in result.output
|
||||||
@@ -96,12 +96,18 @@ class TestToolRegistry:
|
|||||||
def test_create_default_registry(self, workspace: Path, config: AppConfig) -> None:
|
def test_create_default_registry(self, workspace: Path, config: AppConfig) -> None:
|
||||||
registry = create_default_registry(workspace, config)
|
registry = create_default_registry(workspace, config)
|
||||||
names = set(registry.get_all().keys())
|
names = set(registry.get_all().keys())
|
||||||
assert names == {"read_file", "list_dir", "grep_files", "find_files", "finish"}
|
assert names == {
|
||||||
|
"read_file", "list_dir", "grep_files", "find_files",
|
||||||
|
"write_file", "make_dir", "delete_file",
|
||||||
|
"str_replace", "patch_apply",
|
||||||
|
"run_command",
|
||||||
|
"finish",
|
||||||
|
}
|
||||||
|
|
||||||
def test_schema_export(self, workspace: Path, config: AppConfig) -> None:
|
def test_schema_export(self, workspace: Path, config: AppConfig) -> None:
|
||||||
registry = create_default_registry(workspace, config)
|
registry = create_default_registry(workspace, config)
|
||||||
schemas = registry.get_openai_tools_schema()
|
schemas = registry.get_openai_tools_schema()
|
||||||
assert len(schemas) == 5
|
assert len(schemas) == 11
|
||||||
assert all(s["type"] == "function" for s in schemas)
|
assert all(s["type"] == "function" for s in schemas)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user