"""Tests for callback-based StreamHandler."""

import asyncio
from collections.abc import AsyncIterator
from unittest.mock import MagicMock, call

import pytest

from app.models.config import DisplayConfig
from app.services.streaming import StreamHandler


def _make_chunk(content: str | None = None, reasoning: str | None = None) -> dict:
    """Helper to create a fake SSE chunk.

    Args:
        content: Text placed under ``delta["content"]``; omitted when ``None``.
        reasoning: Text placed under ``delta["reasoning"]``; omitted when ``None``.

    Returns:
        A dict shaped like ``{"choices": [{"delta": {...}}]}``.
    """
    delta: dict = {}
    if content is not None:
        delta["content"] = content
    if reasoning is not None:
        delta["reasoning"] = reasoning
    return {"choices": [{"delta": delta}]}


def _make_tool_call_chunk(index: int, tc_id: str = "", name: str = "", args: str = "") -> dict:
    """Helper to create a fake tool call chunk.

    Empty-string fields are left out of the delta entirely. This lets a test
    send, e.g., an id/name fragment followed by an arguments-only fragment
    for the same ``index``.

    Args:
        index: Tool-call slot the fragment belongs to.
        tc_id: Tool-call id; omitted when empty.
        name: Function name; omitted when empty.
        args: Raw arguments fragment; omitted when empty.

    Returns:
        A dict shaped like ``{"choices": [{"delta": {"tool_calls": [...]}}]}``.
    """
    tc_delta: dict = {"index": index}
    if tc_id:
        tc_delta["id"] = tc_id
    func: dict = {}
    if name:
        func["name"] = name
    if args:
        func["arguments"] = args
    if func:
        tc_delta["function"] = func
    return {"choices": [{"delta": {"tool_calls": [tc_delta]}}]}


async def _async_iter(items: list[dict]) -> AsyncIterator[dict]:
    """Yield *items* one at a time as an async iterator (stand-in for a live stream)."""
    for item in items:
        yield item


def _attach_mock_callbacks(handler: StreamHandler) -> tuple[MagicMock, MagicMock, MagicMock]:
    """Register fresh ``MagicMock`` callbacks on *handler*.

    Returns:
        ``(on_content, on_thinking, on_done)`` in that order.
    """
    on_content = MagicMock()
    on_thinking = MagicMock()
    on_done = MagicMock()
    handler.set_callbacks(on_content=on_content, on_thinking=on_thinking, on_done=on_done)
    return on_content, on_thinking, on_done


class TestStreamHandlerCallbacks:
    @pytest.mark.asyncio
    async def test_on_content_called_with_accumulated_text(self) -> None:
        """on_content receives the accumulated text; on_done fires once."""
        handler = StreamHandler(DisplayConfig(stream_output=True))
        on_content, _on_thinking, on_done = _attach_mock_callbacks(handler)

        chunks = [_make_chunk(content="Hello"), _make_chunk(content=" world")]
        msg = await handler.process_stream(_async_iter(chunks))

        assert msg.content == "Hello world"
        assert on_content.call_count >= 1
        # The final call must carry the full accumulated text, not just the last delta.
        last_content = on_content.call_args_list[-1][0][0]
        assert last_content == "Hello world"
        on_done.assert_called_once()

    @pytest.mark.asyncio
    async def test_on_thinking_called_for_reasoning(self) -> None:
        """Reasoning deltas go to on_thinking only, never to on_content."""
        handler = StreamHandler(DisplayConfig(stream_output=True))
        on_content, on_thinking, on_done = _attach_mock_callbacks(handler)

        chunks = [_make_chunk(reasoning="let me think")]
        msg = await handler.process_stream(_async_iter(chunks))

        on_thinking.assert_called_once()
        on_content.assert_not_called()
        on_done.assert_called_once()
        # Reasoning must not leak into the message body (content may be None or "").
        assert not msg.content

    @pytest.mark.asyncio
    async def test_display_callbacks_skip_when_stream_output_disabled(self) -> None:
        """on_content and on_thinking are suppressed, but on_done always fires."""
        handler = StreamHandler(DisplayConfig(stream_output=False))
        on_content, on_thinking, on_done = _attach_mock_callbacks(handler)

        chunks = [_make_chunk(content="Hello")]
        msg = await handler.process_stream(_async_iter(chunks))

        assert msg.content == "Hello"
        on_content.assert_not_called()
        on_thinking.assert_not_called()
        on_done.assert_called_once()

    @pytest.mark.asyncio
    async def test_no_callbacks_by_default(self) -> None:
        """process_stream works without set_callbacks (backward compat)."""
        handler = StreamHandler(DisplayConfig(stream_output=True))

        chunks = [_make_chunk(content="Hello")]
        msg = await handler.process_stream(_async_iter(chunks))

        assert msg.content == "Hello"

    @pytest.mark.asyncio
    async def test_tool_calls_still_accumulated(self) -> None:
        """Fragments sharing an index are merged into one complete tool call."""
        handler = StreamHandler(DisplayConfig(stream_output=True))

        chunks = [
            _make_tool_call_chunk(0, tc_id="call_1", name="read_file"),
            _make_tool_call_chunk(0, args='{"path": "foo.py"}'),
        ]
        msg = await handler.process_stream(_async_iter(chunks))

        assert msg.tool_calls is not None
        assert len(msg.tool_calls) == 1
        tool_call = msg.tool_calls[0]
        # The id and name come from the first fragment...
        assert tool_call.id == "call_1"
        assert tool_call.function.name == "read_file"
        # ...and the arguments only from the second, so this proves the merge.
        assert tool_call.function.arguments == '{"path": "foo.py"}'