Compare commits

...

2 Commits

Author SHA1 Message Date
4623489564 Merge branch 'feature/phase-5-agent-loop' 2026-03-11 08:39:04 -05:00
91187a0728 Add Phase 5: ReAct-style agent loop with tool execution
Implement the core autonomy layer — AgentLoop streams LLM responses,
parses tool calls, executes them with permission checks, feeds results
back, and repeats until the task completes or finish is called.

- Add FinishTool for explicit loop termination
- Add tools parameter to LLMClient.stream_chat() for function calling
- Add compact tool result display (status line, not full output)
- Refactor REPL to delegate to AgentLoop.run_turn()
- Fix Ollama null content rejection (always send content as string)
- Add finish to auto_approve permissions
- 9 unit tests for agent loop (34 total, zero regressions)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-11 08:37:22 -05:00
10 changed files with 609 additions and 54 deletions

230
app/agent/loop.py Normal file
View File

@@ -0,0 +1,230 @@
"""AgentLoop — ReAct-style tool-call loop for autonomous task execution."""
import json
from typing import Any
from app.agent.context import SessionContext
from app.models.config import AppConfig
from app.models.message import Message
from app.models.tool_call import ToolCall, ToolResult, ToolResultStatus
from app.services.llm import LLMClient, LLMConnectionError, LLMError
from app.services.permissions import PermissionsService
from app.services.streaming import StreamHandler
from app.tools.registry import ToolRegistry
from app.utils.display import (
print_error,
print_iteration_header,
print_tool_call,
print_tool_result,
print_token_usage,
print_warning,
)
from app.utils.logging import get_logger
# Module-level logger, named after this module via the project's logging helper.
logger = get_logger(__name__)
class AgentLoop:
    """ReAct-style agent loop that streams LLM responses and executes tool calls.

    The loop sends conversation history to the LLM, parses tool calls from the
    response, executes them with permission checks, feeds results back, and
    repeats until the LLM produces a plain-text response or a ``finish`` tool
    call succeeds.
    """

    def __init__(
        self,
        config: AppConfig,
        ctx: SessionContext,
        client: LLMClient,
        handler: StreamHandler,
        registry: ToolRegistry,
        permissions: PermissionsService,
    ) -> None:
        """Store collaborators and precompute the tool schema and system prompt.

        Args:
            config: Application configuration (agent limits, display flags).
            ctx: Session context holding conversation history and token counter.
            client: LLM client used to stream chat completions.
            handler: Stream handler that accumulates chunks into a Message.
            registry: Registry used to look up tools and export their schemas.
            permissions: Service consulted before each tool execution.
        """
        self._config = config
        self._ctx = ctx
        self._client = client
        self._handler = handler
        self._registry = registry
        self._permissions = permissions
        # The schema is captured once; tools registered later are not picked up.
        self._tools_schema = registry.get_openai_tools_schema()
        self._system_prompt = self._build_system_prompt()

    def _build_system_prompt(self) -> str:
        """Build the system prompt: agent role, workspace root and tool names."""
        tool_names = [t["function"]["name"] for t in self._tools_schema]
        return (
            "You are SneakyCode, a local AI coding agent. "
            "You help users with software engineering tasks by reading files, "
            "searching code, and answering questions about their project.\n\n"
            f"Workspace root: {self._config.agent.workspace_root}\n\n"
            "Available tools: " + ", ".join(tool_names) + "\n\n"
            "When you have fully completed the user's request, call the `finish` tool "
            "with a brief summary. If you can answer directly without tools, just respond "
            "with text (no tool call needed)."
        )

    def _get_messages_with_system_prompt(self) -> list[Message]:
        """Return the conversation history with the system prompt prepended."""
        system_msg = Message(role="system", content=self._system_prompt)
        return [system_msg] + self._ctx.get_history()

    async def run_turn(self, user_input: str) -> None:
        """Execute one full agent turn: add user message, loop until done.

        The turn ends when the token budget is exceeded, the LLM step fails,
        the model replies without tool calls, a ``finish`` tool call succeeds,
        or ``max_iterations`` is reached (a warning is printed in that case).

        Args:
            user_input: The user's message text.
        """
        self._ctx.add_message("user", user_input)
        max_iter = self._config.agent.max_iterations

        for iteration in range(1, max_iter + 1):
            if self._ctx.token_counter.is_over_budget():
                print_warning("Token budget exceeded. Stopping agent loop.")
                break

            if iteration > 1:
                print_iteration_header(iteration, max_iter)

            assistant_msg = await self._llm_step()
            if assistant_msg is None:
                # _llm_step already reported the error or interruption.
                break

            if not assistant_msg.content and not assistant_msg.tool_calls:
                print_warning("Received empty response from model.")

            self._ctx.add_message(
                "assistant",
                assistant_msg.content,
                tool_calls=assistant_msg.tool_calls,
            )
            self._record_token_usage()

            # No tool calls -> the plain-text response completes the turn.
            if not assistant_msg.tool_calls:
                break

            results = self._execute_tool_calls(assistant_msg.tool_calls)
            for result in results:
                if result.status == ToolResultStatus.SUCCESS:
                    content = result.output
                else:
                    content = result.error or "Unknown error"
                self._ctx.add_message(
                    "tool",
                    content,
                    tool_call_id=result.tool_call_id,
                    name=result.tool_name,
                )

            # Only a finish call that actually succeeded ends the turn; a failed
            # one (bad arguments, denied permission) is fed back so the model
            # can retry instead of the turn ending on an error.
            if any(
                r.tool_name == "finish" and r.status == ToolResultStatus.SUCCESS
                for r in results
            ):
                break
        else:
            print_warning(f"Agent reached maximum iterations ({max_iter}). Stopping.")

    def _record_token_usage(self) -> None:
        """Record handler-reported usage, optionally display it, then reset the handler."""
        if self._handler.usage:
            self._ctx.token_counter.count_usage(self._handler.usage)
        if self._config.display.show_token_usage:
            # Fall back to the context's estimate when no usage has been reported.
            total = (
                self._ctx.token_counter.cumulative_usage.total_tokens
                or self._ctx.estimated_tokens
            )
            print_token_usage(total, self._ctx.token_counter.budget)
        self._handler.reset()

    async def _llm_step(self) -> Message | None:
        """Stream one LLM response and return the accumulated Message.

        Returns:
            The assistant Message, or None if the request was interrupted or
            an LLM error occurred (the problem is printed before returning).
        """
        messages = self._get_messages_with_system_prompt()
        try:
            chunk_iter = self._client.stream_chat(messages, tools=self._tools_schema)
            return await self._handler.process_stream(chunk_iter)
        except KeyboardInterrupt:
            print_warning("Response interrupted.")
            self._handler.reset()
            return None
        except LLMConnectionError as e:
            print_error(f"Connection error: {e}")
            return None
        except LLMError as e:
            print_error(f"LLM error: {e}")
            return None

    def _execute_tool_calls(self, tool_calls: list[ToolCall]) -> list[ToolResult]:
        """Execute a list of tool calls with permission checks.

        Args:
            tool_calls: Tool calls from the LLM response.

        Returns:
            List of ToolResult objects (one per tool call, in call order).
        """
        available_names = list(self._registry.get_all())
        results: list[ToolResult] = []
        for tc in tool_calls:
            results.append(self._execute_single_tool_call(tc, available_names))
        return results

    def _execute_single_tool_call(
        self, tc: ToolCall, available_names: list[str]
    ) -> ToolResult:
        """Validate, authorize and run one tool call, displaying it when enabled."""
        name = tc.function.name
        tc_id = tc.id
        # Normalize once so display, parsing and truncation all see a string.
        raw_args = tc.function.arguments or ""
        show = self._config.display.show_tool_calls

        if show:
            print_tool_call(name, raw_args)

        try:
            parsed_args: Any = json.loads(raw_args) if raw_args else {}
        except json.JSONDecodeError as e:
            return self._error_result(tc_id, name, f"Invalid JSON in arguments: {e}")

        # json.loads accepts any JSON value; tools take keyword-style arguments,
        # so anything other than an object is reported back to the model.
        if not isinstance(parsed_args, dict):
            return self._error_result(
                tc_id,
                name,
                f"Tool arguments must be a JSON object, got {type(parsed_args).__name__}",
            )

        tool = self._registry.get(name)
        if tool is None:
            return self._error_result(
                tc_id, name, f"Unknown tool '{name}'. Available: {available_names}"
            )

        # Truncate the arguments shown in the permission prompt.
        desc = raw_args[:120] + "..." if len(raw_args) > 120 else raw_args
        if not self._permissions.check(name, description=desc):
            return self._error_result(
                tc_id, name, f"Permission denied for tool '{name}'"
            )

        # Per the original author's note, BaseTool.run never raises.
        result = tool.run(tc_id, parsed_args)
        if show:
            is_error = result.status == ToolResultStatus.ERROR
            output = result.error if is_error else result.output
            print_tool_result(name, output or "", is_error=is_error)
        return result

    def _error_result(self, tool_call_id: str, name: str, error: str) -> ToolResult:
        """Build an ERROR ToolResult and display it when tool-call display is enabled."""
        result = ToolResult(
            tool_call_id=tool_call_id,
            tool_name=name,
            status=ToolResultStatus.ERROR,
            error=error,
        )
        if self._config.display.show_tool_calls:
            print_tool_result(name, result.error or "", is_error=True)
        return result

View File

@@ -8,16 +8,18 @@ from pathlib import Path
import structlog import structlog
from app.agent.context import SessionContext from app.agent.context import SessionContext
from app.agent.loop import AgentLoop
from app.models.config import AppConfig, load_config from app.models.config import AppConfig, load_config
from app.services.llm import LLMClient, LLMConnectionError, LLMError from app.services.llm import LLMClient, LLMConnectionError, LLMError
from app.services.permissions import PermissionsService
from app.services.streaming import StreamHandler from app.services.streaming import StreamHandler
from app.tools.registry import create_default_registry
from app.utils.display import ( from app.utils.display import (
print_banner, print_banner,
print_error, print_error,
print_history, print_history,
print_info, print_info,
print_success, print_success,
print_token_usage,
print_user_message, print_user_message,
print_warning, print_warning,
) )
@@ -73,8 +75,12 @@ async def _run_repl(
config: Application configuration. config: Application configuration.
logger: Structured logger instance. logger: Structured logger instance.
""" """
registry = create_default_registry(config.agent.workspace_root, config)
permissions = PermissionsService(config.permissions)
async with LLMClient(config.llm) as client: async with LLMClient(config.llm) as client:
handler = StreamHandler(config.display) handler = StreamHandler(config.display)
agent = AgentLoop(config, ctx, client, handler, registry, permissions)
while True: while True:
try: try:
@@ -102,50 +108,9 @@ async def _run_repl(
print_warning(f"Unknown command: {user_input}") print_warning(f"Unknown command: {user_input}")
continue continue
# Add user message and display it
ctx.add_message("user", user_input)
print_user_message(user_input) print_user_message(user_input)
await agent.run_turn(user_input)
# Stream LLM response logger.debug("turn_complete", message_count=ctx.message_count)
try:
chunk_iter = client.stream_chat(ctx.get_history())
assistant_msg = await handler.process_stream(chunk_iter)
except KeyboardInterrupt:
print_warning("Response interrupted.")
handler.reset()
continue
except LLMConnectionError as e:
print_error(f"Connection error: {e}")
continue
except LLMError as e:
print_error(f"LLM error: {e}")
continue
# Handle empty response
if assistant_msg.content is None and assistant_msg.tool_calls is None:
print_warning("Received empty response from model.")
# Record assistant message in history
ctx.add_message(
"assistant",
assistant_msg.content,
tool_calls=assistant_msg.tool_calls,
)
# Record API token usage if available, fall back to heuristic
if handler.usage:
ctx.token_counter.count_usage(handler.usage)
# Show token usage if configured
if config.display.show_token_usage:
print_token_usage(
ctx.token_counter.cumulative_usage.total_tokens
or ctx.estimated_tokens,
ctx.token_counter.budget,
)
handler.reset()
logger.debug("message_exchanged", message_count=ctx.message_count)
def main() -> None: def main() -> None:

View File

@@ -35,8 +35,9 @@ class Message(BaseModel):
""" """
data: dict[str, Any] = {"role": self.role} data: dict[str, Any] = {"role": self.role}
if self.content is not None: # Ollama requires content to be a string, never null/missing —
data["content"] = self.content # even on assistant messages that only contain tool_calls.
data["content"] = self.content or ""
if self.tool_calls is not None: if self.tool_calls is not None:
data["tool_calls"] = [tc.model_dump() for tc in self.tool_calls] data["tool_calls"] = [tc.model_dump() for tc in self.tool_calls]

View File

@@ -2,8 +2,7 @@
import json import json
from collections.abc import AsyncIterator from collections.abc import AsyncIterator
from contextlib import asynccontextmanager from typing import Any, Self
from typing import Self
import httpx import httpx
@@ -101,11 +100,16 @@ class LLMClient:
f"Model '{model}' not found. Available models: {available_str}" f"Model '{model}' not found. Available models: {available_str}"
) )
async def stream_chat(self, messages: list[Message]) -> AsyncIterator[dict]: async def stream_chat(
self,
messages: list[Message],
tools: list[dict[str, Any]] | None = None,
) -> AsyncIterator[dict]:
"""Stream a chat completion request, yielding parsed SSE chunks. """Stream a chat completion request, yielding parsed SSE chunks.
Args: Args:
messages: Conversation history to send to the model. messages: Conversation history to send to the model.
tools: Optional OpenAI function-calling tool schemas.
Yields: Yields:
Parsed JSON dicts from each SSE data line. Parsed JSON dicts from each SSE data line.
@@ -115,7 +119,7 @@ class LLMClient:
LLMResponseError: On non-2xx HTTP status. LLMResponseError: On non-2xx HTTP status.
LLMStreamError: On malformed SSE data (only if every line fails). LLMStreamError: On malformed SSE data (only if every line fails).
""" """
payload = { payload: dict[str, Any] = {
"model": self._config.model, "model": self._config.model,
"messages": [m.to_api_dict() for m in messages], "messages": [m.to_api_dict() for m in messages],
"stream": True, "stream": True,
@@ -123,6 +127,9 @@ class LLMClient:
"max_tokens": self._config.max_tokens, "max_tokens": self._config.max_tokens,
} }
if tools:
payload["tools"] = tools
try: try:
async with self._client.stream( async with self._client.stream(
"POST", self._config.api_path, json=payload "POST", self._config.api_path, json=payload

35
app/tools/finish.py Normal file
View File

@@ -0,0 +1,35 @@
"""FinishTool — signals agent loop termination."""
from typing import Any, ClassVar
from pydantic import BaseModel, Field
from app.models.config import AppConfig
from app.models.tool_call import ToolResult, ToolResultStatus
from app.tools.base import BaseTool
class FinishParams(BaseModel):
    """Arguments accepted by the ``finish`` tool."""

    # Closing message shown to the user; defaults to a generic completion note.
    message: str = Field(
        default="Task complete.",
        description="Final message to the user",
    )
class FinishTool(BaseTool):
    """Tool the agent calls to signal that its task is complete."""

    name: ClassVar[str] = "finish"
    description: ClassVar[str] = (
        "Call this tool when you have completed the user's task. "
        "Provide a brief summary message."
    )
    params_model: ClassVar[type[BaseModel]] = FinishParams

    def execute(self, *, tool_call_id: str, **kwargs: Any) -> ToolResult:
        """Return a SUCCESS result carrying the final message.

        Args:
            tool_call_id: Identifier of the tool call being answered.
            **kwargs: Tool arguments; ``message`` is used when present,
                otherwise ``"Task complete."``.

        Returns:
            A ToolResult with SUCCESS status whose output is the message.
        """
        final_message = kwargs.get("message", "Task complete.")
        outcome = ToolResult(
            tool_call_id=tool_call_id,
            tool_name=self.name,
            status=ToolResultStatus.SUCCESS,
            output=final_message,
        )
        return outcome

View File

@@ -39,6 +39,7 @@ class ToolRegistry:
def create_default_registry(workspace_root: Path, config: AppConfig) -> ToolRegistry: def create_default_registry(workspace_root: Path, config: AppConfig) -> ToolRegistry:
"""Create a ToolRegistry populated with all built-in tools.""" """Create a ToolRegistry populated with all built-in tools."""
from app.tools.filesystem import ListDirTool, ReadFileTool from app.tools.filesystem import ListDirTool, ReadFileTool
from app.tools.finish import FinishTool
from app.tools.search import FindFilesTool, GrepFilesTool from app.tools.search import FindFilesTool, GrepFilesTool
registry = ToolRegistry() registry = ToolRegistry()
@@ -46,4 +47,5 @@ def create_default_registry(workspace_root: Path, config: AppConfig) -> ToolRegi
registry.register(ListDirTool(workspace_root, config)) registry.register(ListDirTool(workspace_root, config))
registry.register(GrepFilesTool(workspace_root, config)) registry.register(GrepFilesTool(workspace_root, config))
registry.register(FindFilesTool(workspace_root, config)) registry.register(FindFilesTool(workspace_root, config))
registry.register(FinishTool(workspace_root, config))
return registry return registry

View File

@@ -61,8 +61,33 @@ def print_assistant_message(content: str) -> None:
def print_tool_call(name: str, args: str) -> None: def print_tool_call(name: str, args: str) -> None:
"""Print a tool call summary (stub for Phase 4).""" """Print a compact tool call line — tool name + truncated key args."""
console.print(f"[tool]Tool: {name}[/tool] [dim]{args}[/dim]") truncated_args = args[:80] + "..." if len(args) > 80 else args
console.print(f" [tool]{name}[/tool] [dim]{truncated_args}[/dim]")
def print_tool_result(name: str, output: str, is_error: bool = False) -> None:
    """Print a compact tool result: a summary line on success, detail on error.

    Args:
        name: Tool name.
        output: Tool output or error message.
        is_error: Whether this is an error result.
    """
    if is_error:
        # Errors are shown prominently (truncated to 200 characters) so the
        # user knows something went wrong.
        truncated = output[:200] + "..." if len(output) > 200 else output
        console.print(f" [error]{name}: {truncated}[/error]")
        return

    # Success: show only a line/character summary. splitlines() does not count
    # a phantom extra line when the output ends with a newline.
    lines = len(output.splitlines())
    chars = len(output)
    # Separate the name from the counts, matching the "name: ..." error format.
    console.print(f" [dim]{name}: {lines} lines, {chars} chars[/dim]")
def print_iteration_header(iteration: int, max_iterations: int) -> None:
    """Print a dimmed separator showing the loop iteration out of the maximum.

    Args:
        iteration: Current iteration number.
        max_iterations: Configured iteration limit.
    """
    header = f"[dim]── iteration {iteration}/{max_iterations} ──[/dim]"
    console.print(header)
def print_token_usage(usage_tokens: int, budget: int) -> None: def print_token_usage(usage_tokens: int, budget: int) -> None:

View File

@@ -19,6 +19,7 @@ permissions:
- list_dir - list_dir
- grep_files - grep_files
- find_files - find_files
- finish
prompt_user: prompt_user:
- write_file - write_file
- delete_file - delete_file

View File

@@ -0,0 +1,289 @@
"""Unit tests for the AgentLoop ReAct-style tool-call loop."""
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from app.agent.context import SessionContext
from app.agent.loop import AgentLoop
from app.models.config import (
AgentConfig,
AppConfig,
DisplayConfig,
LLMConfig,
PermissionsConfig,
ToolsConfig,
)
from app.models.message import Message
from app.models.tool_call import ToolCall, ToolCallFunction, ToolResult, ToolResultStatus
from app.services.llm import LLMClient
from app.services.permissions import PermissionsService
from app.services.streaming import StreamHandler
from app.tools.registry import ToolRegistry, create_default_registry
@pytest.fixture
def config() -> AppConfig:
    """Provide an AppConfig with a 5-iteration limit and read-only tools auto-approved."""
    llm_cfg = LLMConfig(
        model="test-model",
        endpoint="http://localhost:11434",
    )
    agent_cfg = AgentConfig(
        max_iterations=5,
        max_conversation_tokens=32000,
        workspace_root=Path("/tmp/test-workspace"),
    )
    permissions_cfg = PermissionsConfig(
        auto_approve=["read_file", "list_dir", "grep_files", "find_files", "finish"],
    )
    display_cfg = DisplayConfig(
        show_tool_calls=True,
        show_token_usage=False,
        stream_output=False,
    )
    return AppConfig(
        llm=llm_cfg,
        agent=agent_cfg,
        permissions=permissions_cfg,
        display=display_cfg,
    )
@pytest.fixture
def ctx(config: AppConfig) -> SessionContext:
    """Provide a fresh SessionContext built from the test config."""
    session = SessionContext(config)
    return session
@pytest.fixture
def client() -> MagicMock:
    """Provide a MagicMock constrained to the LLMClient interface."""
    llm_client = MagicMock(spec=LLMClient)
    return llm_client
@pytest.fixture
def handler() -> MagicMock:
    """Provide a StreamHandler mock with no recorded usage and a mock ``reset``."""
    stream_handler = MagicMock(spec=StreamHandler)
    # No usage reported, so the loop skips token accounting.
    stream_handler.usage = None
    stream_handler.reset = MagicMock()
    return stream_handler
@pytest.fixture
def registry(config: AppConfig) -> ToolRegistry:
    """Provide the default tool registry rooted at the configured workspace."""
    workspace_root = config.agent.workspace_root
    return create_default_registry(workspace_root, config)
@pytest.fixture
def permissions(config: AppConfig) -> PermissionsService:
    """Provide a PermissionsService built from the test permissions config."""
    service = PermissionsService(config.permissions)
    return service
@pytest.fixture
def agent(
    config: AppConfig,
    ctx: SessionContext,
    client: MagicMock,
    handler: MagicMock,
    registry: ToolRegistry,
    permissions: PermissionsService,
) -> AgentLoop:
    """Provide an AgentLoop wired to the mocked client/handler and real registry."""
    loop = AgentLoop(config, ctx, client, handler, registry, permissions)
    return loop
def _make_text_message(content: str) -> Message:
    """Build an assistant message that carries text and no tool calls."""
    text_only = Message(role="assistant", content=content, tool_calls=None)
    return text_only
def _make_tool_call_message(
    tool_name: str,
    arguments: str,
    tc_id: str = "call_001",
    content: str | None = None,
) -> Message:
    """Build an assistant message containing exactly one function tool call.

    Args:
        tool_name: Name of the tool being called.
        arguments: Raw argument string passed through to the tool call.
        tc_id: Identifier assigned to the tool call.
        content: Optional text content accompanying the call.
    """
    function = ToolCallFunction(name=tool_name, arguments=arguments)
    call = ToolCall(id=tc_id, type="function", function=function)
    return Message(role="assistant", content=content, tool_calls=[call])
class TestAgentLoop:
    """Behavioral tests for AgentLoop.run_turn driven by a mocked stream handler."""

    @pytest.mark.asyncio
    async def test_plain_text_response(
        self, agent: AgentLoop, handler: MagicMock, ctx: SessionContext
    ) -> None:
        """A text-only reply ends the loop after a single LLM call."""
        handler.process_stream = AsyncMock(return_value=_make_text_message("Hello!"))

        await agent.run_turn("Hi there")

        assert handler.process_stream.call_count == 1
        messages = ctx.get_history()
        # Expected history: user + assistant
        assert len(messages) == 2
        user_msg, assistant_msg = messages
        assert user_msg.role == "user"
        assert assistant_msg.role == "assistant"
        assert assistant_msg.content == "Hello!"

    @pytest.mark.asyncio
    async def test_single_tool_call(
        self, agent: AgentLoop, handler: MagicMock, ctx: SessionContext
    ) -> None:
        """One tool call followed by a text reply takes two LLM calls."""
        scripted_replies = [
            _make_tool_call_message("list_dir", '{"directory_path": "."}'),
            _make_text_message("Here are the files."),
        ]
        handler.process_stream = AsyncMock(side_effect=scripted_replies)

        await agent.run_turn("List files")

        assert handler.process_stream.call_count == 2
        messages = ctx.get_history()
        # user, assistant (tool_call), tool (result), assistant (text)
        assert len(messages) == 4
        assert [m.role for m in messages] == ["user", "assistant", "tool", "assistant"]
        assert messages[1].tool_calls is not None
        assert messages[3].content == "Here are the files."

    @pytest.mark.asyncio
    async def test_finish_tool_breaks_loop(
        self, agent: AgentLoop, handler: MagicMock, ctx: SessionContext
    ) -> None:
        """A finish tool call stops the loop without another LLM call."""
        finish_call = _make_tool_call_message("finish", '{"message": "All done!"}')
        handler.process_stream = AsyncMock(return_value=finish_call)

        await agent.run_turn("Do something")

        assert handler.process_stream.call_count == 1
        messages = ctx.get_history()
        # user, assistant (finish call), tool (finish result)
        assert len(messages) == 3
        finish_result = messages[2]
        assert finish_result.role == "tool"
        assert "All done!" in (finish_result.content or "")

    @pytest.mark.asyncio
    async def test_max_iterations(
        self, agent: AgentLoop, handler: MagicMock, config: AppConfig
    ) -> None:
        """An LLM that never stops calling tools is cut off at max_iterations."""
        endless_call = _make_tool_call_message("list_dir", '{"directory_path": "."}')
        handler.process_stream = AsyncMock(return_value=endless_call)

        await agent.run_turn("Keep going")

        # One LLM call per allowed iteration.
        expected_calls = config.agent.max_iterations
        assert handler.process_stream.call_count == expected_calls

    @pytest.mark.asyncio
    async def test_bad_json_arguments(
        self, agent: AgentLoop, handler: MagicMock, ctx: SessionContext
    ) -> None:
        """Malformed JSON arguments yield an error tool result, not an exception."""
        scripted_replies = [
            _make_tool_call_message("list_dir", "not valid json{{{"),
            _make_text_message("Sorry about that."),
        ]
        handler.process_stream = AsyncMock(side_effect=scripted_replies)

        await agent.run_turn("Bad args")

        messages = ctx.get_history()
        # user, assistant (bad call), tool (error), assistant (apology)
        assert len(messages) == 4
        error_msg = messages[2]
        assert error_msg.role == "tool"
        assert "Invalid JSON" in (error_msg.content or "")

    @pytest.mark.asyncio
    async def test_unknown_tool(
        self, agent: AgentLoop, handler: MagicMock, ctx: SessionContext
    ) -> None:
        """An unregistered tool name yields an error result naming the tool."""
        scripted_replies = [
            _make_tool_call_message("nonexistent_tool", "{}"),
            _make_text_message("I'll try something else."),
        ]
        handler.process_stream = AsyncMock(side_effect=scripted_replies)

        await agent.run_turn("Use fake tool")

        error_msg = ctx.get_history()[2]
        error_text = error_msg.content or ""
        assert error_msg.role == "tool"
        assert "Unknown tool" in error_text
        assert "nonexistent_tool" in error_text

    @pytest.mark.asyncio
    async def test_permission_denied(
        self,
        agent: AgentLoop,
        handler: MagicMock,
        ctx: SessionContext,
        config: AppConfig,
    ) -> None:
        """A tool on the deny list yields a permission-denied error result."""
        # Deny list_dir, then rebuild the permissions service from the updated config.
        config.permissions.deny.append("list_dir")
        agent._permissions = PermissionsService(config.permissions)
        scripted_replies = [
            _make_tool_call_message("list_dir", '{"directory_path": "."}'),
            _make_text_message("Permission was denied."),
        ]
        handler.process_stream = AsyncMock(side_effect=scripted_replies)

        await agent.run_turn("List files")

        denied_msg = ctx.get_history()[2]
        assert denied_msg.role == "tool"
        assert "Permission denied" in (denied_msg.content or "")

    @pytest.mark.asyncio
    async def test_llm_connection_error_stops_loop(
        self, agent: AgentLoop, handler: MagicMock, ctx: SessionContext
    ) -> None:
        """A connection error ends the turn without recording an assistant message."""
        from app.services.llm import LLMConnectionError

        failure = LLMConnectionError("Connection refused")
        handler.process_stream = AsyncMock(side_effect=failure)

        await agent.run_turn("Hello")

        # Only the user message remains in history.
        messages = ctx.get_history()
        assert len(messages) == 1
        assert messages[0].role == "user"

    @pytest.mark.asyncio
    async def test_multiple_tool_calls_in_single_response(
        self, agent: AgentLoop, handler: MagicMock, ctx: SessionContext
    ) -> None:
        """Every tool call in a single response is executed, in order."""
        call_specs = (
            ("call_001", "list_dir", '{"directory_path": "."}'),
            ("call_002", "find_files", '{"pattern": "*.py"}'),
        )
        batched_calls = [
            ToolCall(
                id=call_id,
                type="function",
                function=ToolCallFunction(name=tool_name, arguments=raw_args),
            )
            for call_id, tool_name, raw_args in call_specs
        ]
        multi_tc_msg = Message(role="assistant", content=None, tool_calls=batched_calls)
        handler.process_stream = AsyncMock(
            side_effect=[multi_tc_msg, _make_text_message("Found everything.")]
        )

        await agent.run_turn("List and find files")

        messages = ctx.get_history()
        # user, assistant (2 tool calls), tool (result 1), tool (result 2), assistant (text)
        assert len(messages) == 5
        first_result, second_result = messages[2], messages[3]
        assert first_result.role == "tool"
        assert first_result.tool_call_id == "call_001"
        assert second_result.role == "tool"
        assert second_result.tool_call_id == "call_002"
        assert messages[4].content == "Found everything."

View File

@@ -96,12 +96,12 @@ class TestToolRegistry:
def test_create_default_registry(self, workspace: Path, config: AppConfig) -> None: def test_create_default_registry(self, workspace: Path, config: AppConfig) -> None:
registry = create_default_registry(workspace, config) registry = create_default_registry(workspace, config)
names = set(registry.get_all().keys()) names = set(registry.get_all().keys())
assert names == {"read_file", "list_dir", "grep_files", "find_files"} assert names == {"read_file", "list_dir", "grep_files", "find_files", "finish"}
def test_schema_export(self, workspace: Path, config: AppConfig) -> None: def test_schema_export(self, workspace: Path, config: AppConfig) -> None:
registry = create_default_registry(workspace, config) registry = create_default_registry(workspace, config)
schemas = registry.get_openai_tools_schema() schemas = registry.get_openai_tools_schema()
assert len(schemas) == 4 assert len(schemas) == 5
assert all(s["type"] == "function" for s in schemas) assert all(s["type"] == "function" for s in schemas)