Compare commits
2 Commits
5aff2183d6
...
501bf5c45b
| Author | SHA1 | Date | |
|---|---|---|---|
| 501bf5c45b | |||
| adbb442ce5 |
18
app/main.py
18
app/main.py
@@ -55,6 +55,12 @@ def parse_args() -> argparse.Namespace:
|
|||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
async def _preflight(config: AppConfig) -> None:
|
||||||
|
"""Check that Ollama is reachable and the configured model is available."""
|
||||||
|
async with LLMClient(config.llm) as client:
|
||||||
|
await client.preflight_check()
|
||||||
|
|
||||||
|
|
||||||
async def _run_repl(
|
async def _run_repl(
|
||||||
ctx: SessionContext,
|
ctx: SessionContext,
|
||||||
config: AppConfig,
|
config: AppConfig,
|
||||||
@@ -171,6 +177,18 @@ def main() -> None:
|
|||||||
if args.verbose:
|
if args.verbose:
|
||||||
print_info("Verbose mode enabled")
|
print_info("Verbose mode enabled")
|
||||||
|
|
||||||
|
# Preflight: check Ollama is reachable and model exists
|
||||||
|
try:
|
||||||
|
asyncio.run(_preflight(config))
|
||||||
|
except LLMConnectionError as e:
|
||||||
|
print_error(str(e))
|
||||||
|
sys.exit(1)
|
||||||
|
except LLMError as e:
|
||||||
|
print_error(str(e))
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
print_success("Ollama connected, model ready.")
|
||||||
|
|
||||||
# Create session and start REPL
|
# Create session and start REPL
|
||||||
ctx = SessionContext(config)
|
ctx = SessionContext(config)
|
||||||
logger.info("startup_complete")
|
logger.info("startup_complete")
|
||||||
|
|||||||
@@ -59,6 +59,48 @@ class LLMClient:
|
|||||||
timeout=httpx.Timeout(config.timeout, connect=10.0),
|
timeout=httpx.Timeout(config.timeout, connect=10.0),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def preflight_check(self) -> None:
|
||||||
|
"""Verify the endpoint is reachable and the configured model is available.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
LLMConnectionError: If the endpoint is unreachable.
|
||||||
|
LLMResponseError: If the model is not found or the endpoint returns an error.
|
||||||
|
"""
|
||||||
|
# Check endpoint is reachable
|
||||||
|
try:
|
||||||
|
response = await self._client.get("/api/tags")
|
||||||
|
except (httpx.ConnectError, httpx.HTTPError, OSError) as e:
|
||||||
|
raise LLMConnectionError(
|
||||||
|
f"Cannot reach Ollama at {self._config.endpoint}. Is Ollama running?"
|
||||||
|
) from e
|
||||||
|
except httpx.TimeoutException as e:
|
||||||
|
raise LLMConnectionError(
|
||||||
|
f"Timed out connecting to {self._config.endpoint}."
|
||||||
|
) from e
|
||||||
|
|
||||||
|
if response.status_code != 200:
|
||||||
|
raise LLMResponseError(
|
||||||
|
f"Ollama returned {response.status_code} from /api/tags.",
|
||||||
|
status_code=response.status_code,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check model is available
|
||||||
|
try:
|
||||||
|
data = response.json()
|
||||||
|
except (ValueError, KeyError):
|
||||||
|
logger.warning("preflight_parse_error", msg="Could not parse /api/tags response")
|
||||||
|
return
|
||||||
|
|
||||||
|
available = [m.get("name", "") for m in data.get("models", [])]
|
||||||
|
model = self._config.model
|
||||||
|
|
||||||
|
# Match with or without tag suffix (e.g. "qwen3.5" matches "qwen3.5:latest")
|
||||||
|
if not any(model == name or model == name.split(":")[0] for name in available):
|
||||||
|
available_str = ", ".join(available) if available else "(none)"
|
||||||
|
raise LLMResponseError(
|
||||||
|
f"Model '{model}' not found. Available models: {available_str}"
|
||||||
|
)
|
||||||
|
|
||||||
async def stream_chat(self, messages: list[Message]) -> AsyncIterator[dict]:
|
async def stream_chat(self, messages: list[Message]) -> AsyncIterator[dict]:
|
||||||
"""Stream a chat completion request, yielding parsed SSE chunks.
|
"""Stream a chat completion request, yielding parsed SSE chunks.
|
||||||
|
|
||||||
|
|||||||
46
app/services/permissions.py
Normal file
46
app/services/permissions.py
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
"""Permission gating for tool execution."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from rich.prompt import Confirm
|
||||||
|
|
||||||
|
from app.models.config import PermissionsConfig
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class PermissionDenied(Exception):
|
||||||
|
"""Raised when a tool is denied execution by permissions policy."""
|
||||||
|
|
||||||
|
|
||||||
|
class PermissionsService:
|
||||||
|
"""Check whether a tool is allowed to execute based on config tiers."""
|
||||||
|
|
||||||
|
def __init__(self, config: PermissionsConfig) -> None:
|
||||||
|
self.config = config
|
||||||
|
|
||||||
|
def check(self, tool_name: str, description: str = "") -> bool:
|
||||||
|
"""Check if a tool is permitted to run.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if permitted, False if denied.
|
||||||
|
"""
|
||||||
|
if tool_name in self.config.deny:
|
||||||
|
logger.info("Tool '%s' is in deny list — blocked", tool_name)
|
||||||
|
return False
|
||||||
|
|
||||||
|
if tool_name in self.config.auto_approve:
|
||||||
|
logger.debug("Tool '%s' is auto-approved", tool_name)
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Explicit prompt_user list or unlisted tools both trigger a prompt
|
||||||
|
return self._prompt_user(tool_name, description)
|
||||||
|
|
||||||
|
def _prompt_user(self, tool_name: str, description: str) -> bool:
|
||||||
|
"""Prompt the user for approval via the terminal."""
|
||||||
|
prompt_text = f"Allow tool [bold]{tool_name}[/bold]"
|
||||||
|
if description:
|
||||||
|
prompt_text += f" — {description}"
|
||||||
|
prompt_text += "?"
|
||||||
|
|
||||||
|
return Confirm.ask(prompt_text, default=False)
|
||||||
@@ -4,11 +4,11 @@ from collections.abc import AsyncIterator
|
|||||||
|
|
||||||
from rich.live import Live
|
from rich.live import Live
|
||||||
from rich.markdown import Markdown
|
from rich.markdown import Markdown
|
||||||
|
from rich.panel import Panel
|
||||||
|
|
||||||
from app.models.config import DisplayConfig
|
from app.models.config import DisplayConfig
|
||||||
from app.models.message import Message
|
from app.models.message import Message
|
||||||
from app.models.tool_call import ToolCall, ToolCallFunction
|
from app.models.tool_call import ToolCall, ToolCallFunction
|
||||||
from app.utils.display import print_assistant_message
|
|
||||||
from app.utils.logging import console, get_logger
|
from app.utils.logging import console, get_logger
|
||||||
from app.utils.token_counter import TokenUsage
|
from app.utils.token_counter import TokenUsage
|
||||||
|
|
||||||
@@ -50,14 +50,19 @@ class StreamHandler:
|
|||||||
# Show reasoning while waiting for content
|
# Show reasoning while waiting for content
|
||||||
display_text = self._accumulated_content
|
display_text = self._accumulated_content
|
||||||
if not display_text and self._accumulated_reasoning:
|
if not display_text and self._accumulated_reasoning:
|
||||||
display_text = f"*thinking...*"
|
display_text = "*thinking...*"
|
||||||
|
|
||||||
if display_text and self._display_config.stream_output:
|
if display_text and self._display_config.stream_output:
|
||||||
live.update(Markdown(display_text))
|
# Render inside the same Assistant panel used for final output
|
||||||
|
# so the live display and final frame are visually consistent
|
||||||
# Final static render
|
live.update(
|
||||||
if self._accumulated_content:
|
Panel(
|
||||||
print_assistant_message(self._accumulated_content)
|
Markdown(display_text),
|
||||||
|
title="Assistant",
|
||||||
|
border_style="green",
|
||||||
|
expand=True,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
tool_calls = self._build_tool_calls() or None
|
tool_calls = self._build_tool_calls() or None
|
||||||
return Message(
|
return Message(
|
||||||
|
|||||||
@@ -0,0 +1,6 @@
|
|||||||
|
"""Tool framework: base class, registry, and built-in tools."""
|
||||||
|
|
||||||
|
from app.tools.base import BaseTool
|
||||||
|
from app.tools.registry import ToolRegistry, create_default_registry
|
||||||
|
|
||||||
|
__all__ = ["BaseTool", "ToolRegistry", "create_default_registry"]
|
||||||
|
|||||||
81
app/tools/base.py
Normal file
81
app/tools/base.py
Normal file
@@ -0,0 +1,81 @@
|
|||||||
|
"""BaseTool ABC — foundation for all agent-callable tools."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, ClassVar
|
||||||
|
|
||||||
|
from pydantic import BaseModel, ValidationError
|
||||||
|
|
||||||
|
from app.models.config import AppConfig
|
||||||
|
from app.models.tool_call import ToolResult, ToolResultStatus
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class BaseTool(ABC):
|
||||||
|
"""Abstract base class for all agent tools.
|
||||||
|
|
||||||
|
Subclasses must set the class-level ``name``, ``description``, and
|
||||||
|
``params_model`` attributes and implement ``execute``.
|
||||||
|
"""
|
||||||
|
|
||||||
|
name: ClassVar[str]
|
||||||
|
description: ClassVar[str]
|
||||||
|
params_model: ClassVar[type[BaseModel]]
|
||||||
|
|
||||||
|
def __init__(self, workspace_root: Path, config: AppConfig) -> None:
|
||||||
|
self.workspace_root = workspace_root
|
||||||
|
self.config = config
|
||||||
|
self.logger = logging.getLogger(f"{__name__}.{self.name}")
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def execute(self, *, tool_call_id: str, **kwargs: Any) -> ToolResult:
|
||||||
|
"""Execute the tool with validated parameters.
|
||||||
|
|
||||||
|
Subclasses implement the actual tool logic here.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def run(self, tool_call_id: str, arguments: dict[str, Any]) -> ToolResult:
|
||||||
|
"""Public entry point: validate arguments, execute, guarantee a ToolResult.
|
||||||
|
|
||||||
|
Never raises — all exceptions are caught and returned as error results.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
validated = self.params_model(**arguments)
|
||||||
|
except ValidationError as exc:
|
||||||
|
self.logger.warning("Validation error for %s: %s", self.name, exc)
|
||||||
|
return ToolResult(
|
||||||
|
tool_call_id=tool_call_id,
|
||||||
|
tool_name=self.name,
|
||||||
|
status=ToolResultStatus.ERROR,
|
||||||
|
error=f"Invalid arguments: {exc}",
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
return self.execute(tool_call_id=tool_call_id, **validated.model_dump())
|
||||||
|
except Exception as exc:
|
||||||
|
self.logger.exception("Unexpected error in tool %s", self.name)
|
||||||
|
return ToolResult(
|
||||||
|
tool_call_id=tool_call_id,
|
||||||
|
tool_name=self.name,
|
||||||
|
status=ToolResultStatus.ERROR,
|
||||||
|
error=f"Tool execution failed: {exc}",
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_openai_schema(self) -> dict[str, Any]:
|
||||||
|
"""Return the OpenAI function-calling schema for this tool."""
|
||||||
|
schema = self.params_model.model_json_schema()
|
||||||
|
# Remove the top-level title/description that Pydantic adds —
|
||||||
|
# those belong on the function object, not the parameters.
|
||||||
|
schema.pop("title", None)
|
||||||
|
schema.pop("description", None)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": self.name,
|
||||||
|
"description": self.description,
|
||||||
|
"parameters": schema,
|
||||||
|
},
|
||||||
|
}
|
||||||
149
app/tools/filesystem.py
Normal file
149
app/tools/filesystem.py
Normal file
@@ -0,0 +1,149 @@
|
|||||||
|
"""Filesystem tools: read_file and list_dir."""
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from app.models.tool_call import ToolResult, ToolResultStatus
|
||||||
|
from app.tools.base import BaseTool
|
||||||
|
from app.utils.file_helpers import (
|
||||||
|
BinaryFileError,
|
||||||
|
FileSizeError,
|
||||||
|
PathSecurityError,
|
||||||
|
resolve_safe_path,
|
||||||
|
safe_read_file,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ReadFileParams(BaseModel):
|
||||||
|
"""Parameters for the read_file tool."""
|
||||||
|
|
||||||
|
file_path: str = Field(description="Path to the file to read (relative to workspace root)")
|
||||||
|
|
||||||
|
|
||||||
|
class ReadFileTool(BaseTool):
|
||||||
|
"""Read the contents of a file within the workspace."""
|
||||||
|
|
||||||
|
name = "read_file"
|
||||||
|
description = "Read the full contents of a text file. Returns the file content as a string."
|
||||||
|
params_model = ReadFileParams
|
||||||
|
|
||||||
|
def execute(self, *, tool_call_id: str, file_path: str, **kwargs: Any) -> ToolResult:
|
||||||
|
fs_config = self.config.tools.filesystem
|
||||||
|
try:
|
||||||
|
content = safe_read_file(
|
||||||
|
file_path,
|
||||||
|
self.workspace_root,
|
||||||
|
max_size_bytes=fs_config.max_file_size_bytes,
|
||||||
|
check_binary=fs_config.binary_detection,
|
||||||
|
)
|
||||||
|
except PathSecurityError as exc:
|
||||||
|
return ToolResult(
|
||||||
|
tool_call_id=tool_call_id,
|
||||||
|
tool_name=self.name,
|
||||||
|
status=ToolResultStatus.ERROR,
|
||||||
|
error=str(exc),
|
||||||
|
)
|
||||||
|
except FileNotFoundError as exc:
|
||||||
|
return ToolResult(
|
||||||
|
tool_call_id=tool_call_id,
|
||||||
|
tool_name=self.name,
|
||||||
|
status=ToolResultStatus.ERROR,
|
||||||
|
error=str(exc),
|
||||||
|
)
|
||||||
|
except FileSizeError as exc:
|
||||||
|
return ToolResult(
|
||||||
|
tool_call_id=tool_call_id,
|
||||||
|
tool_name=self.name,
|
||||||
|
status=ToolResultStatus.ERROR,
|
||||||
|
error=str(exc),
|
||||||
|
)
|
||||||
|
except BinaryFileError as exc:
|
||||||
|
return ToolResult(
|
||||||
|
tool_call_id=tool_call_id,
|
||||||
|
tool_name=self.name,
|
||||||
|
status=ToolResultStatus.ERROR,
|
||||||
|
error=str(exc),
|
||||||
|
)
|
||||||
|
|
||||||
|
return ToolResult(
|
||||||
|
tool_call_id=tool_call_id,
|
||||||
|
tool_name=self.name,
|
||||||
|
status=ToolResultStatus.SUCCESS,
|
||||||
|
output=content,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ListDirParams(BaseModel):
|
||||||
|
"""Parameters for the list_dir tool."""
|
||||||
|
|
||||||
|
directory_path: str = Field(
|
||||||
|
default=".", description="Path to the directory to list (relative to workspace root)"
|
||||||
|
)
|
||||||
|
recursive: bool = Field(default=False, description="If true, list entries recursively")
|
||||||
|
|
||||||
|
|
||||||
|
_MAX_ENTRIES = 500
|
||||||
|
|
||||||
|
|
||||||
|
class ListDirTool(BaseTool):
|
||||||
|
"""List files and directories within the workspace."""
|
||||||
|
|
||||||
|
name = "list_dir"
|
||||||
|
description = (
|
||||||
|
"List the contents of a directory. Directories are suffixed with '/'. "
|
||||||
|
"Results are sorted with directories first, then files."
|
||||||
|
)
|
||||||
|
params_model = ListDirParams
|
||||||
|
|
||||||
|
def execute(
|
||||||
|
self, *, tool_call_id: str, directory_path: str = ".", recursive: bool = False, **kwargs: Any
|
||||||
|
) -> ToolResult:
|
||||||
|
try:
|
||||||
|
safe_path = resolve_safe_path(directory_path, self.workspace_root)
|
||||||
|
except PathSecurityError as exc:
|
||||||
|
return ToolResult(
|
||||||
|
tool_call_id=tool_call_id,
|
||||||
|
tool_name=self.name,
|
||||||
|
status=ToolResultStatus.ERROR,
|
||||||
|
error=str(exc),
|
||||||
|
)
|
||||||
|
|
||||||
|
if not safe_path.is_dir():
|
||||||
|
return ToolResult(
|
||||||
|
tool_call_id=tool_call_id,
|
||||||
|
tool_name=self.name,
|
||||||
|
status=ToolResultStatus.ERROR,
|
||||||
|
error=f"Not a directory: {safe_path}",
|
||||||
|
)
|
||||||
|
|
||||||
|
entries: list[Path] = []
|
||||||
|
if recursive:
|
||||||
|
entries = list(safe_path.rglob("*"))
|
||||||
|
else:
|
||||||
|
entries = list(safe_path.iterdir())
|
||||||
|
|
||||||
|
# Sort: directories first, then files, alphabetical within each group
|
||||||
|
dirs = sorted([e for e in entries if e.is_dir()])
|
||||||
|
files = sorted([e for e in entries if e.is_file()])
|
||||||
|
sorted_entries = dirs + files
|
||||||
|
|
||||||
|
lines: list[str] = []
|
||||||
|
for entry in sorted_entries[:_MAX_ENTRIES]:
|
||||||
|
try:
|
||||||
|
rel = entry.relative_to(self.workspace_root)
|
||||||
|
except ValueError:
|
||||||
|
rel = entry
|
||||||
|
suffix = "/" if entry.is_dir() else ""
|
||||||
|
lines.append(f"{rel}{suffix}")
|
||||||
|
|
||||||
|
if len(sorted_entries) > _MAX_ENTRIES:
|
||||||
|
lines.append(f"\n... truncated ({len(sorted_entries)} total entries, showing {_MAX_ENTRIES})")
|
||||||
|
|
||||||
|
return ToolResult(
|
||||||
|
tool_call_id=tool_call_id,
|
||||||
|
tool_name=self.name,
|
||||||
|
status=ToolResultStatus.SUCCESS,
|
||||||
|
output="\n".join(lines),
|
||||||
|
)
|
||||||
49
app/tools/registry.py
Normal file
49
app/tools/registry.py
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
"""Tool registration and schema export."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from app.models.config import AppConfig
|
||||||
|
from app.tools.base import BaseTool
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ToolRegistry:
|
||||||
|
"""Registry of available tools, keyed by name."""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self._tools: dict[str, BaseTool] = {}
|
||||||
|
|
||||||
|
def register(self, tool: BaseTool) -> None:
|
||||||
|
"""Register a tool instance. Raises ValueError on duplicate name."""
|
||||||
|
if tool.name in self._tools:
|
||||||
|
raise ValueError(f"Duplicate tool name: '{tool.name}'")
|
||||||
|
self._tools[tool.name] = tool
|
||||||
|
logger.debug("Registered tool: %s", tool.name)
|
||||||
|
|
||||||
|
def get(self, name: str) -> BaseTool | None:
|
||||||
|
"""Look up a tool by name."""
|
||||||
|
return self._tools.get(name)
|
||||||
|
|
||||||
|
def get_all(self) -> dict[str, BaseTool]:
|
||||||
|
"""Return all registered tools."""
|
||||||
|
return dict(self._tools)
|
||||||
|
|
||||||
|
def get_openai_tools_schema(self) -> list[dict[str, Any]]:
|
||||||
|
"""Return OpenAI function-calling schemas for all registered tools."""
|
||||||
|
return [tool.get_openai_schema() for tool in self._tools.values()]
|
||||||
|
|
||||||
|
|
||||||
|
def create_default_registry(workspace_root: Path, config: AppConfig) -> ToolRegistry:
|
||||||
|
"""Create a ToolRegistry populated with all built-in tools."""
|
||||||
|
from app.tools.filesystem import ListDirTool, ReadFileTool
|
||||||
|
from app.tools.search import FindFilesTool, GrepFilesTool
|
||||||
|
|
||||||
|
registry = ToolRegistry()
|
||||||
|
registry.register(ReadFileTool(workspace_root, config))
|
||||||
|
registry.register(ListDirTool(workspace_root, config))
|
||||||
|
registry.register(GrepFilesTool(workspace_root, config))
|
||||||
|
registry.register(FindFilesTool(workspace_root, config))
|
||||||
|
return registry
|
||||||
184
app/tools/search.py
Normal file
184
app/tools/search.py
Normal file
@@ -0,0 +1,184 @@
|
|||||||
|
"""Search tools: grep_files and find_files."""
|
||||||
|
|
||||||
|
import subprocess
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from app.models.tool_call import ToolResult, ToolResultStatus
|
||||||
|
from app.tools.base import BaseTool
|
||||||
|
from app.utils.file_helpers import PathSecurityError, resolve_safe_path
|
||||||
|
|
||||||
|
_GREP_MAX_MATCHES = 100
|
||||||
|
_FIND_MAX_RESULTS = 200
|
||||||
|
|
||||||
|
|
||||||
|
class GrepFilesParams(BaseModel):
|
||||||
|
"""Parameters for the grep_files tool."""
|
||||||
|
|
||||||
|
pattern: str = Field(description="Regular expression pattern to search for")
|
||||||
|
path: str = Field(default=".", description="Directory or file to search in (relative to workspace root)")
|
||||||
|
file_pattern: str | None = Field(
|
||||||
|
default=None, description="Glob pattern to filter files (e.g. '*.py')"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class GrepFilesTool(BaseTool):
|
||||||
|
"""Search file contents using grep."""
|
||||||
|
|
||||||
|
name = "grep_files"
|
||||||
|
description = (
|
||||||
|
"Search for a regex pattern in file contents. Returns matching lines with "
|
||||||
|
"file paths and line numbers."
|
||||||
|
)
|
||||||
|
params_model = GrepFilesParams
|
||||||
|
|
||||||
|
def execute(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
tool_call_id: str,
|
||||||
|
pattern: str,
|
||||||
|
path: str = ".",
|
||||||
|
file_pattern: str | None = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> ToolResult:
|
||||||
|
try:
|
||||||
|
safe_path = resolve_safe_path(path, self.workspace_root)
|
||||||
|
except PathSecurityError as exc:
|
||||||
|
return ToolResult(
|
||||||
|
tool_call_id=tool_call_id,
|
||||||
|
tool_name=self.name,
|
||||||
|
status=ToolResultStatus.ERROR,
|
||||||
|
error=str(exc),
|
||||||
|
)
|
||||||
|
|
||||||
|
cmd = ["grep", "-rn", pattern, str(safe_path)]
|
||||||
|
if file_pattern:
|
||||||
|
cmd.insert(3, f"--include={file_pattern}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = subprocess.run(
|
||||||
|
cmd,
|
||||||
|
capture_output=True,
|
||||||
|
text=True,
|
||||||
|
timeout=30,
|
||||||
|
)
|
||||||
|
except subprocess.TimeoutExpired:
|
||||||
|
return ToolResult(
|
||||||
|
tool_call_id=tool_call_id,
|
||||||
|
tool_name=self.name,
|
||||||
|
status=ToolResultStatus.ERROR,
|
||||||
|
error="grep timed out after 30 seconds",
|
||||||
|
)
|
||||||
|
|
||||||
|
if result.returncode == 1:
|
||||||
|
# No matches — not an error
|
||||||
|
return ToolResult(
|
||||||
|
tool_call_id=tool_call_id,
|
||||||
|
tool_name=self.name,
|
||||||
|
status=ToolResultStatus.SUCCESS,
|
||||||
|
output="No matches found.",
|
||||||
|
)
|
||||||
|
|
||||||
|
if result.returncode >= 2:
|
||||||
|
return ToolResult(
|
||||||
|
tool_call_id=tool_call_id,
|
||||||
|
tool_name=self.name,
|
||||||
|
status=ToolResultStatus.ERROR,
|
||||||
|
error=result.stderr.strip() or f"grep exited with code {result.returncode}",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Truncate to max matches
|
||||||
|
lines = result.stdout.splitlines()
|
||||||
|
output_lines = lines[:_GREP_MAX_MATCHES]
|
||||||
|
if len(lines) > _GREP_MAX_MATCHES:
|
||||||
|
output_lines.append(f"\n... truncated ({len(lines)} matches, showing {_GREP_MAX_MATCHES})")
|
||||||
|
|
||||||
|
return ToolResult(
|
||||||
|
tool_call_id=tool_call_id,
|
||||||
|
tool_name=self.name,
|
||||||
|
status=ToolResultStatus.SUCCESS,
|
||||||
|
output="\n".join(output_lines),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class FindFilesParams(BaseModel):
|
||||||
|
"""Parameters for the find_files tool."""
|
||||||
|
|
||||||
|
pattern: str = Field(description="File name pattern to search for (e.g. '*.py', 'config.yaml')")
|
||||||
|
path: str = Field(default=".", description="Directory to search in (relative to workspace root)")
|
||||||
|
|
||||||
|
|
||||||
|
class FindFilesTool(BaseTool):
|
||||||
|
"""Find files by name pattern."""
|
||||||
|
|
||||||
|
name = "find_files"
|
||||||
|
description = "Search for files matching a name pattern. Returns relative file paths."
|
||||||
|
params_model = FindFilesParams
|
||||||
|
|
||||||
|
def execute(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
tool_call_id: str,
|
||||||
|
pattern: str,
|
||||||
|
path: str = ".",
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> ToolResult:
|
||||||
|
try:
|
||||||
|
safe_path = resolve_safe_path(path, self.workspace_root)
|
||||||
|
except PathSecurityError as exc:
|
||||||
|
return ToolResult(
|
||||||
|
tool_call_id=tool_call_id,
|
||||||
|
tool_name=self.name,
|
||||||
|
status=ToolResultStatus.ERROR,
|
||||||
|
error=str(exc),
|
||||||
|
)
|
||||||
|
|
||||||
|
cmd = ["find", str(safe_path), "-name", pattern, "-type", "f"]
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = subprocess.run(
|
||||||
|
cmd,
|
||||||
|
capture_output=True,
|
||||||
|
text=True,
|
||||||
|
timeout=30,
|
||||||
|
)
|
||||||
|
except subprocess.TimeoutExpired:
|
||||||
|
return ToolResult(
|
||||||
|
tool_call_id=tool_call_id,
|
||||||
|
tool_name=self.name,
|
||||||
|
status=ToolResultStatus.ERROR,
|
||||||
|
error="find timed out after 30 seconds",
|
||||||
|
)
|
||||||
|
|
||||||
|
if result.returncode != 0:
|
||||||
|
return ToolResult(
|
||||||
|
tool_call_id=tool_call_id,
|
||||||
|
tool_name=self.name,
|
||||||
|
status=ToolResultStatus.ERROR,
|
||||||
|
error=result.stderr.strip() or f"find exited with code {result.returncode}",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Make paths relative to workspace root and truncate
|
||||||
|
lines = result.stdout.strip().splitlines() if result.stdout.strip() else []
|
||||||
|
relative_lines: list[str] = []
|
||||||
|
for line in lines[:_FIND_MAX_RESULTS]:
|
||||||
|
try:
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
rel = Path(line).relative_to(self.workspace_root)
|
||||||
|
relative_lines.append(str(rel))
|
||||||
|
except ValueError:
|
||||||
|
relative_lines.append(line)
|
||||||
|
|
||||||
|
if len(lines) > _FIND_MAX_RESULTS:
|
||||||
|
relative_lines.append(f"\n... truncated ({len(lines)} results, showing {_FIND_MAX_RESULTS})")
|
||||||
|
|
||||||
|
output = "\n".join(relative_lines) if relative_lines else "No files found."
|
||||||
|
|
||||||
|
return ToolResult(
|
||||||
|
tool_call_id=tool_call_id,
|
||||||
|
tool_name=self.name,
|
||||||
|
status=ToolResultStatus.SUCCESS,
|
||||||
|
output=output,
|
||||||
|
)
|
||||||
203
tests/unit/test_tools.py
Normal file
203
tests/unit/test_tools.py
Normal file
@@ -0,0 +1,203 @@
|
|||||||
|
"""Tests for the tool framework and core tools (Phase 4)."""
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.models.config import AppConfig, PermissionsConfig, load_config
|
||||||
|
from app.models.tool_call import ToolResultStatus
|
||||||
|
from app.services.permissions import PermissionsService
|
||||||
|
from app.tools.base import BaseTool
|
||||||
|
from app.tools.filesystem import ListDirTool, ReadFileTool
|
||||||
|
from app.tools.registry import ToolRegistry, create_default_registry
|
||||||
|
from app.tools.search import FindFilesTool, GrepFilesTool
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def config() -> AppConfig:
|
||||||
|
return load_config()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def workspace(config: AppConfig) -> Path:
|
||||||
|
return config.agent.workspace_root
|
||||||
|
|
||||||
|
|
||||||
|
# --- BaseTool ---
|
||||||
|
|
||||||
|
|
||||||
|
class TestBaseTool:
|
||||||
|
def test_run_with_invalid_args_returns_error(self, workspace: Path, config: AppConfig) -> None:
|
||||||
|
tool = ReadFileTool(workspace, config)
|
||||||
|
# missing required 'file_path'
|
||||||
|
result = tool.run("tc-1", {})
|
||||||
|
assert result.status == ToolResultStatus.ERROR
|
||||||
|
assert "Invalid arguments" in (result.error or "")
|
||||||
|
|
||||||
|
def test_get_openai_schema_structure(self, workspace: Path, config: AppConfig) -> None:
|
||||||
|
tool = ReadFileTool(workspace, config)
|
||||||
|
schema = tool.get_openai_schema()
|
||||||
|
assert schema["type"] == "function"
|
||||||
|
assert schema["function"]["name"] == "read_file"
|
||||||
|
assert "parameters" in schema["function"]
|
||||||
|
assert schema["function"]["parameters"]["type"] == "object"
|
||||||
|
|
||||||
|
|
||||||
|
# --- PermissionsService ---
|
||||||
|
|
||||||
|
|
||||||
|
class TestPermissionsService:
|
||||||
|
def test_deny_list_blocks(self) -> None:
|
||||||
|
svc = PermissionsService(PermissionsConfig(deny=["dangerous_tool"]))
|
||||||
|
assert svc.check("dangerous_tool") is False
|
||||||
|
|
||||||
|
def test_auto_approve_allows(self) -> None:
|
||||||
|
svc = PermissionsService(PermissionsConfig(auto_approve=["read_file"]))
|
||||||
|
assert svc.check("read_file") is True
|
||||||
|
|
||||||
|
@patch("app.services.permissions.Confirm.ask", return_value=True)
|
||||||
|
def test_prompt_user_approved(self, mock_ask: object) -> None:
|
||||||
|
svc = PermissionsService(PermissionsConfig(prompt_user=["write_file"]))
|
||||||
|
assert svc.check("write_file") is True
|
||||||
|
|
||||||
|
@patch("app.services.permissions.Confirm.ask", return_value=False)
|
||||||
|
def test_prompt_user_denied(self, mock_ask: object) -> None:
|
||||||
|
svc = PermissionsService(PermissionsConfig(prompt_user=["write_file"]))
|
||||||
|
assert svc.check("write_file") is False
|
||||||
|
|
||||||
|
@patch("app.services.permissions.Confirm.ask", return_value=False)
|
||||||
|
def test_unlisted_tool_prompts(self, mock_ask: object) -> None:
|
||||||
|
svc = PermissionsService(PermissionsConfig())
|
||||||
|
assert svc.check("unknown_tool") is False
|
||||||
|
|
||||||
|
|
||||||
|
# --- ToolRegistry ---
|
||||||
|
|
||||||
|
|
||||||
|
class TestToolRegistry:
|
||||||
|
def test_register_and_get(self, workspace: Path, config: AppConfig) -> None:
|
||||||
|
registry = ToolRegistry()
|
||||||
|
tool = ReadFileTool(workspace, config)
|
||||||
|
registry.register(tool)
|
||||||
|
assert registry.get("read_file") is tool
|
||||||
|
|
||||||
|
def test_duplicate_raises(self, workspace: Path, config: AppConfig) -> None:
|
||||||
|
registry = ToolRegistry()
|
||||||
|
tool = ReadFileTool(workspace, config)
|
||||||
|
registry.register(tool)
|
||||||
|
with pytest.raises(ValueError, match="Duplicate"):
|
||||||
|
registry.register(tool)
|
||||||
|
|
||||||
|
def test_get_missing_returns_none(self) -> None:
|
||||||
|
registry = ToolRegistry()
|
||||||
|
assert registry.get("nonexistent") is None
|
||||||
|
|
||||||
|
def test_create_default_registry(self, workspace: Path, config: AppConfig) -> None:
|
||||||
|
registry = create_default_registry(workspace, config)
|
||||||
|
names = set(registry.get_all().keys())
|
||||||
|
assert names == {"read_file", "list_dir", "grep_files", "find_files"}
|
||||||
|
|
||||||
|
def test_schema_export(self, workspace: Path, config: AppConfig) -> None:
|
||||||
|
registry = create_default_registry(workspace, config)
|
||||||
|
schemas = registry.get_openai_tools_schema()
|
||||||
|
assert len(schemas) == 4
|
||||||
|
assert all(s["type"] == "function" for s in schemas)
|
||||||
|
|
||||||
|
|
||||||
|
# --- ReadFileTool ---
|
||||||
|
|
||||||
|
|
||||||
|
class TestReadFileTool:
|
||||||
|
def test_read_existing_file(self, workspace: Path, config: AppConfig) -> None:
|
||||||
|
tool = ReadFileTool(workspace, config)
|
||||||
|
result = tool.run("tc-1", {"file_path": "config/config.yaml"})
|
||||||
|
assert result.status == ToolResultStatus.SUCCESS
|
||||||
|
assert "llm:" in result.output
|
||||||
|
|
||||||
|
def test_read_nonexistent_file(self, workspace: Path, config: AppConfig) -> None:
|
||||||
|
tool = ReadFileTool(workspace, config)
|
||||||
|
result = tool.run("tc-2", {"file_path": "nonexistent.txt"})
|
||||||
|
assert result.status == ToolResultStatus.ERROR
|
||||||
|
assert "not found" in (result.error or "").lower()
|
||||||
|
|
||||||
|
def test_path_traversal_blocked(self, workspace: Path, config: AppConfig) -> None:
|
||||||
|
tool = ReadFileTool(workspace, config)
|
||||||
|
result = tool.run("tc-3", {"file_path": "../../etc/passwd"})
|
||||||
|
assert result.status == ToolResultStatus.ERROR
|
||||||
|
assert "outside" in (result.error or "").lower()
|
||||||
|
|
||||||
|
|
||||||
|
# --- ListDirTool ---
|
||||||
|
|
||||||
|
|
||||||
|
class TestListDirTool:
|
||||||
|
def test_list_root(self, workspace: Path, config: AppConfig) -> None:
|
||||||
|
tool = ListDirTool(workspace, config)
|
||||||
|
result = tool.run("tc-1", {})
|
||||||
|
assert result.status == ToolResultStatus.SUCCESS
|
||||||
|
assert "app/" in result.output
|
||||||
|
|
||||||
|
def test_list_subdir(self, workspace: Path, config: AppConfig) -> None:
|
||||||
|
tool = ListDirTool(workspace, config)
|
||||||
|
result = tool.run("tc-2", {"directory_path": "app/tools"})
|
||||||
|
assert result.status == ToolResultStatus.SUCCESS
|
||||||
|
assert "base.py" in result.output
|
||||||
|
|
||||||
|
def test_list_nonexistent_dir(self, workspace: Path, config: AppConfig) -> None:
|
||||||
|
tool = ListDirTool(workspace, config)
|
||||||
|
result = tool.run("tc-3", {"directory_path": "nonexistent_dir"})
|
||||||
|
assert result.status == ToolResultStatus.ERROR
|
||||||
|
|
||||||
|
def test_list_recursive(self, workspace: Path, config: AppConfig) -> None:
|
||||||
|
tool = ListDirTool(workspace, config)
|
||||||
|
result = tool.run("tc-4", {"directory_path": "config", "recursive": True})
|
||||||
|
assert result.status == ToolResultStatus.SUCCESS
|
||||||
|
assert "config.yaml" in result.output
|
||||||
|
|
||||||
|
|
||||||
|
# --- GrepFilesTool ---
|
||||||
|
|
||||||
|
|
||||||
|
class TestGrepFilesTool:
|
||||||
|
def test_grep_finds_matches(self, workspace: Path, config: AppConfig) -> None:
|
||||||
|
tool = GrepFilesTool(workspace, config)
|
||||||
|
result = tool.run("tc-1", {"pattern": "BaseTool", "path": "app/tools", "file_pattern": "*.py"})
|
||||||
|
assert result.status == ToolResultStatus.SUCCESS
|
||||||
|
assert "BaseTool" in result.output
|
||||||
|
|
||||||
|
def test_grep_no_matches(self, workspace: Path, config: AppConfig) -> None:
|
||||||
|
tool = GrepFilesTool(workspace, config)
|
||||||
|
# Search only in config/ to avoid matching the test file itself
|
||||||
|
result = tool.run("tc-2", {"pattern": "zzz_will_never_match_zzz", "path": "config"})
|
||||||
|
assert result.status == ToolResultStatus.SUCCESS
|
||||||
|
assert "No matches" in result.output
|
||||||
|
|
||||||
|
def test_grep_path_traversal_blocked(self, workspace: Path, config: AppConfig) -> None:
|
||||||
|
tool = GrepFilesTool(workspace, config)
|
||||||
|
result = tool.run("tc-3", {"pattern": "root", "path": "../../etc"})
|
||||||
|
assert result.status == ToolResultStatus.ERROR
|
||||||
|
|
||||||
|
|
||||||
|
# --- FindFilesTool ---
|
||||||
|
|
||||||
|
|
||||||
|
class TestFindFilesTool:
|
||||||
|
def test_find_files(self, workspace: Path, config: AppConfig) -> None:
|
||||||
|
tool = FindFilesTool(workspace, config)
|
||||||
|
result = tool.run("tc-1", {"pattern": "*.py", "path": "app/tools"})
|
||||||
|
assert result.status == ToolResultStatus.SUCCESS
|
||||||
|
assert "base.py" in result.output
|
||||||
|
|
||||||
|
def test_find_no_results(self, workspace: Path, config: AppConfig) -> None:
|
||||||
|
tool = FindFilesTool(workspace, config)
|
||||||
|
result = tool.run("tc-2", {"pattern": "*.nonexistent_ext"})
|
||||||
|
assert result.status == ToolResultStatus.SUCCESS
|
||||||
|
assert "No files found" in result.output
|
||||||
|
|
||||||
|
def test_find_relative_paths(self, workspace: Path, config: AppConfig) -> None:
|
||||||
|
tool = FindFilesTool(workspace, config)
|
||||||
|
result = tool.run("tc-3", {"pattern": "config.yaml"})
|
||||||
|
assert result.status == ToolResultStatus.SUCCESS
|
||||||
|
# Should be relative, not absolute
|
||||||
|
assert not result.output.startswith("/")
|
||||||
Reference in New Issue
Block a user