Compare commits
2 Commits
501bf5c45b
...
4623489564
| Author | SHA1 | Date | |
|---|---|---|---|
| 4623489564 | |||
| 91187a0728 |
230
app/agent/loop.py
Normal file
230
app/agent/loop.py
Normal file
@@ -0,0 +1,230 @@
|
|||||||
|
"""AgentLoop — ReAct-style tool-call loop for autonomous task execution."""
|
||||||
|
|
||||||
|
import json
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from app.agent.context import SessionContext
|
||||||
|
from app.models.config import AppConfig
|
||||||
|
from app.models.message import Message
|
||||||
|
from app.models.tool_call import ToolCall, ToolResult, ToolResultStatus
|
||||||
|
from app.services.llm import LLMClient, LLMConnectionError, LLMError
|
||||||
|
from app.services.permissions import PermissionsService
|
||||||
|
from app.services.streaming import StreamHandler
|
||||||
|
from app.tools.registry import ToolRegistry
|
||||||
|
from app.utils.display import (
|
||||||
|
print_error,
|
||||||
|
print_iteration_header,
|
||||||
|
print_tool_call,
|
||||||
|
print_tool_result,
|
||||||
|
print_token_usage,
|
||||||
|
print_warning,
|
||||||
|
)
|
||||||
|
from app.utils.logging import get_logger
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class AgentLoop:
    """ReAct-style agent loop that streams LLM responses and executes tool calls.

    The loop sends conversation history to the LLM, parses tool calls from the
    response, executes them with permission checks, feeds results back, and
    repeats until the LLM produces a plain-text response or calls ``finish``.
    """

    def __init__(
        self,
        config: AppConfig,
        ctx: SessionContext,
        client: LLMClient,
        handler: StreamHandler,
        registry: ToolRegistry,
        permissions: PermissionsService,
    ) -> None:
        """Store collaborators and precompute the tool schema and system prompt.

        Args:
            config: Application configuration (agent limits, display flags).
            ctx: Session context holding conversation history and token counter.
            client: LLM client used to open the streaming chat request.
            handler: Stream handler that accumulates chunks into a Message.
            registry: Registry used to look up tools by name.
            permissions: Service consulted before each tool execution.
        """
        self._config = config
        self._ctx = ctx
        self._client = client
        self._handler = handler
        self._registry = registry
        self._permissions = permissions
        # The schema must exist before the prompt is built: the prompt lists tool names.
        self._tools_schema = registry.get_openai_tools_schema()
        self._system_prompt = self._build_system_prompt()

    def _build_system_prompt(self) -> str:
        """Build the system prompt: agent role, workspace root, tool names, finish rule."""
        tool_names = [t["function"]["name"] for t in self._tools_schema]
        return (
            "You are SneakyCode, a local AI coding agent. "
            "You help users with software engineering tasks by reading files, "
            "searching code, and answering questions about their project.\n\n"
            f"Workspace root: {self._config.agent.workspace_root}\n\n"
            "Available tools: " + ", ".join(tool_names) + "\n\n"
            "When you have fully completed the user's request, call the `finish` tool "
            "with a brief summary. If you can answer directly without tools, just respond "
            "with text (no tool call needed)."
        )

    def _get_messages_with_system_prompt(self) -> list[Message]:
        """Return the conversation history with the system prompt prepended."""
        system_msg = Message(role="system", content=self._system_prompt)
        return [system_msg] + self._ctx.get_history()

    async def run_turn(self, user_input: str) -> None:
        """Execute one full agent turn: add user message, loop until done.

        The turn ends when the token budget is exceeded, the LLM step fails,
        the model answers without tool calls, the ``finish`` tool is called,
        or ``max_iterations`` is reached (which prints a warning).

        Args:
            user_input: The user's message text.
        """
        self._ctx.add_message("user", user_input)

        max_iter = self._config.agent.max_iterations
        for iteration in range(1, max_iter + 1):
            # Check token budget before spending another LLM call.
            if self._ctx.token_counter.is_over_budget():
                print_warning("Token budget exceeded. Stopping agent loop.")
                break

            if iteration > 1:
                print_iteration_header(iteration, max_iter)

            # Stream LLM response; None means the step already reported an error.
            assistant_msg = await self._llm_step()
            if assistant_msg is None:
                break

            # An empty reply would otherwise end the turn with no visible
            # output at all, so tell the user what happened.
            if assistant_msg.content is None and assistant_msg.tool_calls is None:
                print_warning("Received empty response from model.")

            # Record assistant message
            self._ctx.add_message(
                "assistant",
                assistant_msg.content,
                tool_calls=assistant_msg.tool_calls,
            )

            # Record token usage reported by the stream, if any.
            if self._handler.usage:
                self._ctx.token_counter.count_usage(self._handler.usage)

            if self._config.display.show_token_usage:
                # Fall back to the context's estimate when no usage was counted.
                total = self._ctx.token_counter.cumulative_usage.total_tokens
                if total == 0:
                    total = self._ctx.estimated_tokens
                print_token_usage(total, self._ctx.token_counter.budget)

            self._handler.reset()

            # No tool calls → task complete (plain text response)
            if not assistant_msg.tool_calls:
                break

            results = self._execute_tool_calls(assistant_msg.tool_calls)

            # Feed every tool result back into the conversation.
            for result in results:
                if result.status == ToolResultStatus.SUCCESS:
                    content = result.output
                else:
                    content = result.error or "Unknown error"
                self._ctx.add_message(
                    "tool",
                    content,
                    tool_call_id=result.tool_call_id,
                    name=result.tool_name,
                )

            # Check if finish tool was called
            if any(r.tool_name == "finish" for r in results):
                break
        else:
            # Only reached when the loop ran out of iterations without a break.
            print_warning(f"Agent reached maximum iterations ({max_iter}). Stopping.")

    async def _llm_step(self) -> Message | None:
        """Stream one LLM response and return the accumulated Message.

        Returns:
            The assistant Message, or None if the stream was interrupted or an
            LLM error occurred (the problem has already been printed).
        """
        messages = self._get_messages_with_system_prompt()
        try:
            chunk_iter = self._client.stream_chat(messages, tools=self._tools_schema)
            return await self._handler.process_stream(chunk_iter)
        except KeyboardInterrupt:
            print_warning("Response interrupted.")
        except LLMConnectionError as e:
            print_error(f"Connection error: {e}")
        except LLMError as e:
            print_error(f"LLM error: {e}")

        # Every failure path resets the handler, not just the interrupt path,
        # so state from a failed stream is not carried into the next call.
        self._handler.reset()
        return None

    def _error_result(self, tool_call_id: str, tool_name: str, message: str) -> ToolResult:
        """Build an ERROR ToolResult and print it when tool display is enabled."""
        result = ToolResult(
            tool_call_id=tool_call_id,
            tool_name=tool_name,
            status=ToolResultStatus.ERROR,
            error=message,
        )
        if self._config.display.show_tool_calls:
            print_tool_result(tool_name, result.error or "", is_error=True)
        return result

    def _execute_tool_calls(self, tool_calls: list[ToolCall]) -> list[ToolResult]:
        """Execute a list of tool calls with permission checks.

        A call that cannot run (bad JSON, non-object arguments, unknown tool,
        permission denied) yields an error result instead of raising, so the
        model sees the failure and can react to it.

        Args:
            tool_calls: Tool calls from the LLM response.

        Returns:
            List of ToolResult objects (one per tool call, in call order).
        """
        results: list[ToolResult] = []
        available_names = list(self._registry.get_all().keys())

        for tc in tool_calls:
            name = tc.function.name
            tc_id = tc.id
            # Normalise once so display, parsing and the permission prompt all
            # see a string even when the model sent no arguments.
            raw_args = tc.function.arguments or ""

            # Display the tool call
            if self._config.display.show_tool_calls:
                print_tool_call(name, raw_args)

            # Parse arguments
            try:
                parsed_args: Any = json.loads(raw_args) if raw_args else {}
            except json.JSONDecodeError as e:
                results.append(self._error_result(tc_id, name, f"Invalid JSON in arguments: {e}"))
                continue

            # JSON ``null`` is treated as "no arguments"; any other non-object
            # payload cannot be mapped onto named tool parameters.
            if parsed_args is None:
                parsed_args = {}
            if not isinstance(parsed_args, dict):
                results.append(
                    self._error_result(
                        tc_id,
                        name,
                        f"Tool arguments must be a JSON object, got {type(parsed_args).__name__}",
                    )
                )
                continue

            # Look up tool
            tool = self._registry.get(name)
            if tool is None:
                results.append(
                    self._error_result(
                        tc_id, name, f"Unknown tool '{name}'. Available: {available_names}"
                    )
                )
                continue

            # Check permissions (truncate args for display in prompt)
            desc = raw_args[:120] + "..." if len(raw_args) > 120 else raw_args
            if not self._permissions.check(name, description=desc):
                results.append(
                    self._error_result(tc_id, name, f"Permission denied for tool '{name}'")
                )
                continue

            # Execute tool (BaseTool.run never raises)
            result = tool.run(tc_id, parsed_args)
            results.append(result)

            if self._config.display.show_tool_calls:
                is_error = result.status == ToolResultStatus.ERROR
                output = result.error if is_error else result.output
                print_tool_result(name, output or "", is_error=is_error)

        return results
|
||||||
53
app/main.py
53
app/main.py
@@ -8,16 +8,18 @@ from pathlib import Path
|
|||||||
import structlog
|
import structlog
|
||||||
|
|
||||||
from app.agent.context import SessionContext
|
from app.agent.context import SessionContext
|
||||||
|
from app.agent.loop import AgentLoop
|
||||||
from app.models.config import AppConfig, load_config
|
from app.models.config import AppConfig, load_config
|
||||||
from app.services.llm import LLMClient, LLMConnectionError, LLMError
|
from app.services.llm import LLMClient, LLMConnectionError, LLMError
|
||||||
|
from app.services.permissions import PermissionsService
|
||||||
from app.services.streaming import StreamHandler
|
from app.services.streaming import StreamHandler
|
||||||
|
from app.tools.registry import create_default_registry
|
||||||
from app.utils.display import (
|
from app.utils.display import (
|
||||||
print_banner,
|
print_banner,
|
||||||
print_error,
|
print_error,
|
||||||
print_history,
|
print_history,
|
||||||
print_info,
|
print_info,
|
||||||
print_success,
|
print_success,
|
||||||
print_token_usage,
|
|
||||||
print_user_message,
|
print_user_message,
|
||||||
print_warning,
|
print_warning,
|
||||||
)
|
)
|
||||||
@@ -73,8 +75,12 @@ async def _run_repl(
|
|||||||
config: Application configuration.
|
config: Application configuration.
|
||||||
logger: Structured logger instance.
|
logger: Structured logger instance.
|
||||||
"""
|
"""
|
||||||
|
registry = create_default_registry(config.agent.workspace_root, config)
|
||||||
|
permissions = PermissionsService(config.permissions)
|
||||||
|
|
||||||
async with LLMClient(config.llm) as client:
|
async with LLMClient(config.llm) as client:
|
||||||
handler = StreamHandler(config.display)
|
handler = StreamHandler(config.display)
|
||||||
|
agent = AgentLoop(config, ctx, client, handler, registry, permissions)
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
@@ -102,50 +108,9 @@ async def _run_repl(
|
|||||||
print_warning(f"Unknown command: {user_input}")
|
print_warning(f"Unknown command: {user_input}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Add user message and display it
|
|
||||||
ctx.add_message("user", user_input)
|
|
||||||
print_user_message(user_input)
|
print_user_message(user_input)
|
||||||
|
await agent.run_turn(user_input)
|
||||||
# Stream LLM response
|
logger.debug("turn_complete", message_count=ctx.message_count)
|
||||||
try:
|
|
||||||
chunk_iter = client.stream_chat(ctx.get_history())
|
|
||||||
assistant_msg = await handler.process_stream(chunk_iter)
|
|
||||||
except KeyboardInterrupt:
|
|
||||||
print_warning("Response interrupted.")
|
|
||||||
handler.reset()
|
|
||||||
continue
|
|
||||||
except LLMConnectionError as e:
|
|
||||||
print_error(f"Connection error: {e}")
|
|
||||||
continue
|
|
||||||
except LLMError as e:
|
|
||||||
print_error(f"LLM error: {e}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Handle empty response
|
|
||||||
if assistant_msg.content is None and assistant_msg.tool_calls is None:
|
|
||||||
print_warning("Received empty response from model.")
|
|
||||||
|
|
||||||
# Record assistant message in history
|
|
||||||
ctx.add_message(
|
|
||||||
"assistant",
|
|
||||||
assistant_msg.content,
|
|
||||||
tool_calls=assistant_msg.tool_calls,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Record API token usage if available, fall back to heuristic
|
|
||||||
if handler.usage:
|
|
||||||
ctx.token_counter.count_usage(handler.usage)
|
|
||||||
|
|
||||||
# Show token usage if configured
|
|
||||||
if config.display.show_token_usage:
|
|
||||||
print_token_usage(
|
|
||||||
ctx.token_counter.cumulative_usage.total_tokens
|
|
||||||
or ctx.estimated_tokens,
|
|
||||||
ctx.token_counter.budget,
|
|
||||||
)
|
|
||||||
|
|
||||||
handler.reset()
|
|
||||||
logger.debug("message_exchanged", message_count=ctx.message_count)
|
|
||||||
|
|
||||||
|
|
||||||
def main() -> None:
|
def main() -> None:
|
||||||
|
|||||||
@@ -35,8 +35,9 @@ class Message(BaseModel):
|
|||||||
"""
|
"""
|
||||||
data: dict[str, Any] = {"role": self.role}
|
data: dict[str, Any] = {"role": self.role}
|
||||||
|
|
||||||
if self.content is not None:
|
# Ollama requires content to be a string, never null/missing —
|
||||||
data["content"] = self.content
|
# even on assistant messages that only contain tool_calls.
|
||||||
|
data["content"] = self.content or ""
|
||||||
|
|
||||||
if self.tool_calls is not None:
|
if self.tool_calls is not None:
|
||||||
data["tool_calls"] = [tc.model_dump() for tc in self.tool_calls]
|
data["tool_calls"] = [tc.model_dump() for tc in self.tool_calls]
|
||||||
|
|||||||
@@ -2,8 +2,7 @@
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
from collections.abc import AsyncIterator
|
from collections.abc import AsyncIterator
|
||||||
from contextlib import asynccontextmanager
|
from typing import Any, Self
|
||||||
from typing import Self
|
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
@@ -101,11 +100,16 @@ class LLMClient:
|
|||||||
f"Model '{model}' not found. Available models: {available_str}"
|
f"Model '{model}' not found. Available models: {available_str}"
|
||||||
)
|
)
|
||||||
|
|
||||||
async def stream_chat(self, messages: list[Message]) -> AsyncIterator[dict]:
|
async def stream_chat(
|
||||||
|
self,
|
||||||
|
messages: list[Message],
|
||||||
|
tools: list[dict[str, Any]] | None = None,
|
||||||
|
) -> AsyncIterator[dict]:
|
||||||
"""Stream a chat completion request, yielding parsed SSE chunks.
|
"""Stream a chat completion request, yielding parsed SSE chunks.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
messages: Conversation history to send to the model.
|
messages: Conversation history to send to the model.
|
||||||
|
tools: Optional OpenAI function-calling tool schemas.
|
||||||
|
|
||||||
Yields:
|
Yields:
|
||||||
Parsed JSON dicts from each SSE data line.
|
Parsed JSON dicts from each SSE data line.
|
||||||
@@ -115,7 +119,7 @@ class LLMClient:
|
|||||||
LLMResponseError: On non-2xx HTTP status.
|
LLMResponseError: On non-2xx HTTP status.
|
||||||
LLMStreamError: On malformed SSE data (only if every line fails).
|
LLMStreamError: On malformed SSE data (only if every line fails).
|
||||||
"""
|
"""
|
||||||
payload = {
|
payload: dict[str, Any] = {
|
||||||
"model": self._config.model,
|
"model": self._config.model,
|
||||||
"messages": [m.to_api_dict() for m in messages],
|
"messages": [m.to_api_dict() for m in messages],
|
||||||
"stream": True,
|
"stream": True,
|
||||||
@@ -123,6 +127,9 @@ class LLMClient:
|
|||||||
"max_tokens": self._config.max_tokens,
|
"max_tokens": self._config.max_tokens,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if tools:
|
||||||
|
payload["tools"] = tools
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async with self._client.stream(
|
async with self._client.stream(
|
||||||
"POST", self._config.api_path, json=payload
|
"POST", self._config.api_path, json=payload
|
||||||
|
|||||||
35
app/tools/finish.py
Normal file
35
app/tools/finish.py
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
"""FinishTool — signals agent loop termination."""
|
||||||
|
|
||||||
|
from typing import Any, ClassVar
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from app.models.config import AppConfig
|
||||||
|
from app.models.tool_call import ToolResult, ToolResultStatus
|
||||||
|
from app.tools.base import BaseTool
|
||||||
|
|
||||||
|
|
||||||
|
class FinishParams(BaseModel):
    """Arguments accepted by the ``finish`` tool."""

    # Has a default, so ``finish`` may be called with no arguments.
    message: str = Field(
        default="Task complete.",
        description="Final message to the user",
    )
|
||||||
|
|
||||||
|
|
||||||
|
class FinishTool(BaseTool):
    """Tool the model calls to report that its task is complete.

    Executing it only echoes the supplied message back as a successful result.
    """

    name: ClassVar[str] = "finish"
    description: ClassVar[str] = (
        "Call this tool when you have completed the user's task. Provide a brief summary message."
    )
    params_model: ClassVar[type[BaseModel]] = FinishParams

    def execute(self, *, tool_call_id: str, **kwargs: Any) -> ToolResult:
        """Return a SUCCESS result carrying the final message.

        Args:
            tool_call_id: Identifier of the tool call being answered.
            **kwargs: Tool arguments; ``message`` is used when present.

        Returns:
            A ToolResult whose output is the message, or "Task complete."
            when no message was passed.
        """
        summary = kwargs["message"] if "message" in kwargs else "Task complete."
        outcome = ToolResult(
            tool_call_id=tool_call_id,
            tool_name=self.name,
            status=ToolResultStatus.SUCCESS,
            output=summary,
        )
        return outcome
|
||||||
@@ -39,6 +39,7 @@ class ToolRegistry:
|
|||||||
def create_default_registry(workspace_root: Path, config: AppConfig) -> ToolRegistry:
|
def create_default_registry(workspace_root: Path, config: AppConfig) -> ToolRegistry:
|
||||||
"""Create a ToolRegistry populated with all built-in tools."""
|
"""Create a ToolRegistry populated with all built-in tools."""
|
||||||
from app.tools.filesystem import ListDirTool, ReadFileTool
|
from app.tools.filesystem import ListDirTool, ReadFileTool
|
||||||
|
from app.tools.finish import FinishTool
|
||||||
from app.tools.search import FindFilesTool, GrepFilesTool
|
from app.tools.search import FindFilesTool, GrepFilesTool
|
||||||
|
|
||||||
registry = ToolRegistry()
|
registry = ToolRegistry()
|
||||||
@@ -46,4 +47,5 @@ def create_default_registry(workspace_root: Path, config: AppConfig) -> ToolRegi
|
|||||||
registry.register(ListDirTool(workspace_root, config))
|
registry.register(ListDirTool(workspace_root, config))
|
||||||
registry.register(GrepFilesTool(workspace_root, config))
|
registry.register(GrepFilesTool(workspace_root, config))
|
||||||
registry.register(FindFilesTool(workspace_root, config))
|
registry.register(FindFilesTool(workspace_root, config))
|
||||||
|
registry.register(FinishTool(workspace_root, config))
|
||||||
return registry
|
return registry
|
||||||
|
|||||||
@@ -61,8 +61,33 @@ def print_assistant_message(content: str) -> None:
|
|||||||
|
|
||||||
|
|
||||||
def print_tool_call(name: str, args: str) -> None:
|
def print_tool_call(name: str, args: str) -> None:
|
||||||
"""Print a tool call summary (stub for Phase 4)."""
|
"""Print a compact tool call line — tool name + truncated key args."""
|
||||||
console.print(f"[tool]Tool: {name}[/tool] [dim]{args}[/dim]")
|
truncated_args = args[:80] + "..." if len(args) > 80 else args
|
||||||
|
console.print(f" [tool]{name}[/tool] [dim]{truncated_args}[/dim]")
|
||||||
|
|
||||||
|
|
||||||
|
def print_tool_result(name: str, output: str, is_error: bool = False) -> None:
    """Print a one-line summary of a tool result.

    Successful results show only a line/character count; errors show the
    message itself, cut to 200 characters.

    Args:
        name: Tool name.
        output: Tool output or error message.
        is_error: Whether this is an error result.
    """
    if not is_error:
        # Success: size summary only, never the payload itself.
        line_count = 0 if not output else output.count("\n") + 1
        char_count = len(output)
        console.print(f" [dim]{name} — {line_count} lines, {char_count} chars[/dim]")
        return

    # Errors are shown (truncated) so the user can see what went wrong.
    detail = output if len(output) <= 200 else output[:200] + "..."
    console.print(f" [error]{name}: {detail}[/error]")
|
||||||
|
|
||||||
|
|
||||||
|
def print_iteration_header(iteration: int, max_iterations: int) -> None:
    """Print a dim separator showing the current iteration out of the maximum."""
    progress = f"{iteration}/{max_iterations}"
    console.print(f"[dim]── iteration {progress} ──[/dim]")
|
||||||
|
|
||||||
|
|
||||||
def print_token_usage(usage_tokens: int, budget: int) -> None:
|
def print_token_usage(usage_tokens: int, budget: int) -> None:
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ permissions:
|
|||||||
- list_dir
|
- list_dir
|
||||||
- grep_files
|
- grep_files
|
||||||
- find_files
|
- find_files
|
||||||
|
- finish
|
||||||
prompt_user:
|
prompt_user:
|
||||||
- write_file
|
- write_file
|
||||||
- delete_file
|
- delete_file
|
||||||
|
|||||||
289
tests/unit/test_agent_loop.py
Normal file
289
tests/unit/test_agent_loop.py
Normal file
@@ -0,0 +1,289 @@
|
|||||||
|
"""Unit tests for the AgentLoop ReAct-style tool-call loop."""
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.agent.context import SessionContext
|
||||||
|
from app.agent.loop import AgentLoop
|
||||||
|
from app.models.config import (
|
||||||
|
AgentConfig,
|
||||||
|
AppConfig,
|
||||||
|
DisplayConfig,
|
||||||
|
LLMConfig,
|
||||||
|
PermissionsConfig,
|
||||||
|
ToolsConfig,
|
||||||
|
)
|
||||||
|
from app.models.message import Message
|
||||||
|
from app.models.tool_call import ToolCall, ToolCallFunction, ToolResult, ToolResultStatus
|
||||||
|
from app.services.llm import LLMClient
|
||||||
|
from app.services.permissions import PermissionsService
|
||||||
|
from app.services.streaming import StreamHandler
|
||||||
|
from app.tools.registry import ToolRegistry, create_default_registry
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def config() -> AppConfig:
    """Build the AppConfig used by the tests: 5 iterations max, all built-in tools auto-approved."""
    llm_cfg = LLMConfig(
        model="test-model",
        endpoint="http://localhost:11434",
    )
    agent_cfg = AgentConfig(
        max_iterations=5,
        max_conversation_tokens=32000,
        workspace_root=Path("/tmp/test-workspace"),
    )
    permissions_cfg = PermissionsConfig(
        auto_approve=["read_file", "list_dir", "grep_files", "find_files", "finish"],
    )
    display_cfg = DisplayConfig(
        show_tool_calls=True,
        show_token_usage=False,
        stream_output=False,
    )
    return AppConfig(
        llm=llm_cfg,
        agent=agent_cfg,
        permissions=permissions_cfg,
        display=display_cfg,
    )
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def ctx(config: AppConfig) -> SessionContext:
    """Provide a fresh SessionContext built from the test config."""
    session = SessionContext(config)
    return session
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def client() -> MagicMock:
    """Provide a mock constrained to the LLMClient interface."""
    llm_client = MagicMock(spec=LLMClient)
    return llm_client
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def handler() -> MagicMock:
    """Provide a StreamHandler mock with no usage and a mocked ``reset``."""
    stream_handler = MagicMock(spec=StreamHandler)
    stream_handler.reset = MagicMock()
    # No usage reported, so the loop skips its usage-counting branch.
    stream_handler.usage = None
    return stream_handler
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def registry(config: AppConfig) -> ToolRegistry:
    """Provide the default tool registry rooted at the configured workspace."""
    workspace = config.agent.workspace_root
    return create_default_registry(workspace, config)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def permissions(config: AppConfig) -> PermissionsService:
    """Provide a PermissionsService built from the test permission config."""
    service = PermissionsService(config.permissions)
    return service
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def agent(
    config: AppConfig,
    ctx: SessionContext,
    client: MagicMock,
    handler: MagicMock,
    registry: ToolRegistry,
    permissions: PermissionsService,
) -> AgentLoop:
    """Assemble the AgentLoop under test from the other fixtures."""
    loop = AgentLoop(config, ctx, client, handler, registry, permissions)
    return loop
|
||||||
|
|
||||||
|
|
||||||
|
def _make_text_message(content: str) -> Message:
    """Build an assistant message that carries text and no tool calls."""
    reply = Message(role="assistant", content=content, tool_calls=None)
    return reply
|
||||||
|
|
||||||
|
|
||||||
|
def _make_tool_call_message(
    tool_name: str,
    arguments: str,
    tc_id: str = "call_001",
    content: str | None = None,
) -> Message:
    """Build an assistant message that requests exactly one tool call.

    Args:
        tool_name: Name of the tool being called.
        arguments: Raw argument string, passed through unparsed.
        tc_id: Identifier given to the tool call.
        content: Optional text accompanying the tool call.
    """
    requested = ToolCallFunction(name=tool_name, arguments=arguments)
    call = ToolCall(id=tc_id, type="function", function=requested)
    return Message(role="assistant", content=content, tool_calls=[call])
|
||||||
|
|
||||||
|
|
||||||
|
class TestAgentLoop:
    """Tests for ``AgentLoop.run_turn`` driven by a scripted stream handler."""

    @pytest.mark.asyncio
    async def test_plain_text_response(self, agent: AgentLoop, handler: MagicMock, ctx: SessionContext) -> None:
        """A text-only reply ends the turn after one LLM call."""
        handler.process_stream = AsyncMock(return_value=_make_text_message("Hello!"))

        await agent.run_turn("Hi there")

        assert handler.process_stream.call_count == 1
        # Exactly two entries: the user message and the assistant reply.
        messages = ctx.get_history()
        assert [m.role for m in messages] == ["user", "assistant"]
        assert messages[1].content == "Hello!"

    @pytest.mark.asyncio
    async def test_single_tool_call(self, agent: AgentLoop, handler: MagicMock, ctx: SessionContext) -> None:
        """One tool call followed by a text reply takes two LLM calls."""
        scripted = [
            _make_tool_call_message("list_dir", '{"directory_path": "."}'),
            _make_text_message("Here are the files."),
        ]
        handler.process_stream = AsyncMock(side_effect=scripted)

        await agent.run_turn("List files")

        assert handler.process_stream.call_count == 2
        messages = ctx.get_history()
        # user, assistant (tool_call), tool (result), assistant (text)
        assert [m.role for m in messages] == ["user", "assistant", "tool", "assistant"]
        assert messages[1].tool_calls is not None
        assert messages[3].content == "Here are the files."

    @pytest.mark.asyncio
    async def test_finish_tool_breaks_loop(self, agent: AgentLoop, handler: MagicMock, ctx: SessionContext) -> None:
        """A ``finish`` call stops the loop without another LLM call."""
        finish_call = _make_tool_call_message("finish", '{"message": "All done!"}')
        handler.process_stream = AsyncMock(return_value=finish_call)

        await agent.run_turn("Do something")

        assert handler.process_stream.call_count == 1
        messages = ctx.get_history()
        # user, assistant (finish call), tool (finish result)
        assert len(messages) == 3
        finish_result = messages[2]
        assert finish_result.role == "tool"
        assert "All done!" in (finish_result.content or "")

    @pytest.mark.asyncio
    async def test_max_iterations(self, agent: AgentLoop, handler: MagicMock, config: AppConfig) -> None:
        """A model that never stops calling tools is cut off at max_iterations."""
        endless_call = _make_tool_call_message("list_dir", '{"directory_path": "."}')
        handler.process_stream = AsyncMock(return_value=endless_call)

        await agent.run_turn("Keep going")

        # One LLM call per allowed iteration, no more.
        expected_calls = config.agent.max_iterations
        assert handler.process_stream.call_count == expected_calls

    @pytest.mark.asyncio
    async def test_bad_json_arguments(self, agent: AgentLoop, handler: MagicMock, ctx: SessionContext) -> None:
        """Malformed JSON arguments become an error tool message, not an exception."""
        scripted = [
            _make_tool_call_message("list_dir", "not valid json{{{"),
            _make_text_message("Sorry about that."),
        ]
        handler.process_stream = AsyncMock(side_effect=scripted)

        await agent.run_turn("Bad args")

        messages = ctx.get_history()
        # user, assistant (bad call), tool (error), assistant (apology)
        assert len(messages) == 4
        error_msg = messages[2]
        assert error_msg.role == "tool"
        assert "Invalid JSON" in (error_msg.content or "")

    @pytest.mark.asyncio
    async def test_unknown_tool(self, agent: AgentLoop, handler: MagicMock, ctx: SessionContext) -> None:
        """An unregistered tool name yields an error message naming the tool."""
        scripted = [
            _make_tool_call_message("nonexistent_tool", "{}"),
            _make_text_message("I'll try something else."),
        ]
        handler.process_stream = AsyncMock(side_effect=scripted)

        await agent.run_turn("Use fake tool")

        error_msg = ctx.get_history()[2]
        error_text = error_msg.content or ""
        assert error_msg.role == "tool"
        assert "Unknown tool" in error_text
        assert "nonexistent_tool" in error_text

    @pytest.mark.asyncio
    async def test_permission_denied(self, agent: AgentLoop, handler: MagicMock, ctx: SessionContext, config: AppConfig) -> None:
        """A tool on the deny list produces a permission error result."""
        # Deny list_dir, then rebuild the service from the updated config.
        config.permissions.deny.append("list_dir")
        agent._permissions = PermissionsService(config.permissions)

        scripted = [
            _make_tool_call_message("list_dir", '{"directory_path": "."}'),
            _make_text_message("Permission was denied."),
        ]
        handler.process_stream = AsyncMock(side_effect=scripted)

        await agent.run_turn("List files")

        denied_msg = ctx.get_history()[2]
        assert denied_msg.role == "tool"
        assert "Permission denied" in (denied_msg.content or "")

    @pytest.mark.asyncio
    async def test_llm_connection_error_stops_loop(self, agent: AgentLoop, handler: MagicMock, ctx: SessionContext) -> None:
        """A connection error ends the turn without recording an assistant message."""
        from app.services.llm import LLMConnectionError

        failure = LLMConnectionError("Connection refused")
        handler.process_stream = AsyncMock(side_effect=failure)

        await agent.run_turn("Hello")

        # Only the user message should be in history (no assistant message added)
        messages = ctx.get_history()
        assert len(messages) == 1
        assert messages[0].role == "user"

    @pytest.mark.asyncio
    async def test_multiple_tool_calls_in_single_response(self, agent: AgentLoop, handler: MagicMock, ctx: SessionContext) -> None:
        """Every tool call in one response is executed, in order."""
        requested = (
            ("call_001", "list_dir", '{"directory_path": "."}'),
            ("call_002", "find_files", '{"pattern": "*.py"}'),
        )
        calls = [
            ToolCall(
                id=call_id,
                type="function",
                function=ToolCallFunction(name=tool, arguments=raw_args),
            )
            for call_id, tool, raw_args in requested
        ]
        multi_call_msg = Message(role="assistant", content=None, tool_calls=calls)
        handler.process_stream = AsyncMock(
            side_effect=[multi_call_msg, _make_text_message("Found everything.")]
        )

        await agent.run_turn("List and find files")

        messages = ctx.get_history()
        # user, assistant (2 tool calls), tool (result 1), tool (result 2), assistant (text)
        assert len(messages) == 5
        first_result, second_result = messages[2], messages[3]
        assert first_result.role == "tool"
        assert first_result.tool_call_id == "call_001"
        assert second_result.role == "tool"
        assert second_result.tool_call_id == "call_002"
        assert messages[4].content == "Found everything."
|
||||||
@@ -96,12 +96,12 @@ class TestToolRegistry:
|
|||||||
def test_create_default_registry(self, workspace: Path, config: AppConfig) -> None:
|
def test_create_default_registry(self, workspace: Path, config: AppConfig) -> None:
|
||||||
registry = create_default_registry(workspace, config)
|
registry = create_default_registry(workspace, config)
|
||||||
names = set(registry.get_all().keys())
|
names = set(registry.get_all().keys())
|
||||||
assert names == {"read_file", "list_dir", "grep_files", "find_files"}
|
assert names == {"read_file", "list_dir", "grep_files", "find_files", "finish"}
|
||||||
|
|
||||||
def test_schema_export(self, workspace: Path, config: AppConfig) -> None:
|
def test_schema_export(self, workspace: Path, config: AppConfig) -> None:
|
||||||
registry = create_default_registry(workspace, config)
|
registry = create_default_registry(workspace, config)
|
||||||
schemas = registry.get_openai_tools_schema()
|
schemas = registry.get_openai_tools_schema()
|
||||||
assert len(schemas) == 4
|
assert len(schemas) == 5
|
||||||
assert all(s["type"] == "function" for s in schemas)
|
assert all(s["type"] == "function" for s in schemas)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user