Files
SneakyCode/tests/unit/test_streaming.py

112 lines
4.0 KiB
Python

"""Tests for callback-based StreamHandler."""
import asyncio
from unittest.mock import MagicMock, call
import pytest
from app.models.config import DisplayConfig
from app.services.streaming import StreamHandler
def _make_chunk(content: str | None = None, reasoning: str | None = None) -> dict:
"""Helper to create a fake SSE chunk."""
delta: dict = {}
if content is not None:
delta["content"] = content
if reasoning is not None:
delta["reasoning"] = reasoning
return {"choices": [{"delta": delta}]}
def _make_tool_call_chunk(index: int, tc_id: str = "", name: str = "", args: str = "") -> dict:
"""Helper to create a fake tool call chunk."""
tc_delta: dict = {"index": index}
if tc_id:
tc_delta["id"] = tc_id
func: dict = {}
if name:
func["name"] = name
if args:
func["arguments"] = args
if func:
tc_delta["function"] = func
return {"choices": [{"delta": {"tool_calls": [tc_delta]}}]}
async def _async_iter(items: list[dict]):
for item in items:
yield item
class TestStreamHandlerCallbacks:
    """Tests for ``StreamHandler.process_stream`` with and without callbacks."""

    @staticmethod
    def _attach_mocks(handler: "StreamHandler") -> tuple[MagicMock, MagicMock, MagicMock]:
        """Register three fresh mocks as the handler's callbacks.

        Returns:
            The ``(on_content, on_thinking, on_done)`` mocks, in that order.
        """
        on_content, on_thinking, on_done = MagicMock(), MagicMock(), MagicMock()
        handler.set_callbacks(on_content=on_content, on_thinking=on_thinking, on_done=on_done)
        return on_content, on_thinking, on_done

    @pytest.mark.asyncio
    async def test_on_content_called_with_accumulated_text(self) -> None:
        """The final on_content call receives the full accumulated content."""
        handler = StreamHandler(DisplayConfig(stream_output=True))
        on_content, _on_thinking, on_done = self._attach_mocks(handler)

        chunks = [_make_chunk(content="Hello"), _make_chunk(content=" world")]
        msg = await handler.process_stream(_async_iter(chunks))

        assert msg.content == "Hello world"
        assert on_content.call_count >= 1
        # Last call should have full accumulated content
        assert on_content.call_args.args[0] == "Hello world"
        on_done.assert_called_once()

    @pytest.mark.asyncio
    async def test_on_thinking_called_for_reasoning(self) -> None:
        """A reasoning-only chunk triggers on_thinking but not on_content."""
        handler = StreamHandler(DisplayConfig(stream_output=True))
        on_content, on_thinking, on_done = self._attach_mocks(handler)

        chunks = [_make_chunk(reasoning="let me think")]
        # The returned message is not inspected here; only the callbacks are.
        await handler.process_stream(_async_iter(chunks))

        on_thinking.assert_called_once()
        on_content.assert_not_called()
        on_done.assert_called_once()

    @pytest.mark.asyncio
    async def test_display_callbacks_skip_when_stream_output_disabled(self) -> None:
        """on_content and on_thinking are suppressed, but on_done always fires."""
        handler = StreamHandler(DisplayConfig(stream_output=False))
        on_content, on_thinking, on_done = self._attach_mocks(handler)

        chunks = [_make_chunk(content="Hello")]
        msg = await handler.process_stream(_async_iter(chunks))

        assert msg.content == "Hello"
        on_content.assert_not_called()
        on_thinking.assert_not_called()
        on_done.assert_called_once()

    @pytest.mark.asyncio
    async def test_no_callbacks_by_default(self) -> None:
        """process_stream works without set_callbacks (backward compat)."""
        handler = StreamHandler(DisplayConfig(stream_output=True))

        chunks = [_make_chunk(content="Hello")]
        msg = await handler.process_stream(_async_iter(chunks))

        assert msg.content == "Hello"

    @pytest.mark.asyncio
    async def test_tool_calls_still_accumulated(self) -> None:
        """Tool-call deltas sharing an index are merged into one tool call."""
        handler = StreamHandler(DisplayConfig(stream_output=True))

        chunks = [
            _make_tool_call_chunk(0, tc_id="call_1", name="read_file"),
            _make_tool_call_chunk(0, args='{"path": "foo.py"}'),
        ]
        msg = await handler.process_stream(_async_iter(chunks))

        assert msg.tool_calls is not None
        assert len(msg.tool_calls) == 1
        assert msg.tool_calls[0].function.name == "read_file"