112 lines
4.0 KiB
Python
112 lines
4.0 KiB
Python
"""Tests for callback-based StreamHandler."""
|
|
|
|
import asyncio
|
|
from unittest.mock import MagicMock, call
|
|
|
|
import pytest
|
|
|
|
from app.models.config import DisplayConfig
|
|
from app.services.streaming import StreamHandler
|
|
|
|
|
|
def _make_chunk(content: str | None = None, reasoning: str | None = None) -> dict:
|
|
"""Helper to create a fake SSE chunk."""
|
|
delta: dict = {}
|
|
if content is not None:
|
|
delta["content"] = content
|
|
if reasoning is not None:
|
|
delta["reasoning"] = reasoning
|
|
return {"choices": [{"delta": delta}]}
|
|
|
|
|
|
def _make_tool_call_chunk(index: int, tc_id: str = "", name: str = "", args: str = "") -> dict:
|
|
"""Helper to create a fake tool call chunk."""
|
|
tc_delta: dict = {"index": index}
|
|
if tc_id:
|
|
tc_delta["id"] = tc_id
|
|
func: dict = {}
|
|
if name:
|
|
func["name"] = name
|
|
if args:
|
|
func["arguments"] = args
|
|
if func:
|
|
tc_delta["function"] = func
|
|
return {"choices": [{"delta": {"tool_calls": [tc_delta]}}]}
|
|
|
|
|
|
async def _async_iter(items: list[dict]):
|
|
for item in items:
|
|
yield item
|
|
|
|
|
|
class TestStreamHandlerCallbacks:
    """Callback wiring of ``StreamHandler.process_stream``.

    Covers which of ``on_content`` / ``on_thinking`` / ``on_done`` fire for
    content and reasoning chunks, and how ``DisplayConfig.stream_output``
    gates them. Also checks that the returned message is still assembled,
    including tool calls, whether or not callbacks were registered.
    """

    @staticmethod
    def _handler_with_mocks(
        stream_output: bool,
    ) -> "tuple[StreamHandler, MagicMock, MagicMock, MagicMock]":
        """Build a handler with fresh ``MagicMock`` callbacks attached.

        Returns ``(handler, on_content, on_thinking, on_done)`` so each test
        can assert on the individual callbacks.
        """
        handler = StreamHandler(DisplayConfig(stream_output=stream_output))
        on_content = MagicMock()
        on_thinking = MagicMock()
        on_done = MagicMock()
        handler.set_callbacks(on_content=on_content, on_thinking=on_thinking, on_done=on_done)
        return handler, on_content, on_thinking, on_done

    @pytest.mark.asyncio
    async def test_on_content_called_with_accumulated_text(self) -> None:
        """on_content receives the running text; on_done fires once at the end."""
        handler, on_content, _on_thinking, on_done = self._handler_with_mocks(stream_output=True)
        chunks = [_make_chunk(content="Hello"), _make_chunk(content=" world")]

        msg = await handler.process_stream(_async_iter(chunks))

        assert msg.content == "Hello world"
        assert on_content.call_count >= 1
        # The callback is handed accumulated text, not the latest delta, so
        # the most recent call must carry the whole message.
        assert on_content.call_args.args[0] == "Hello world"
        on_done.assert_called_once()

    @pytest.mark.asyncio
    async def test_on_thinking_called_for_reasoning(self) -> None:
        """A reasoning-only chunk fires on_thinking and never on_content."""
        handler, on_content, on_thinking, on_done = self._handler_with_mocks(stream_output=True)

        await handler.process_stream(_async_iter([_make_chunk(reasoning="let me think")]))

        on_thinking.assert_called_once()
        on_content.assert_not_called()
        on_done.assert_called_once()

    @pytest.mark.asyncio
    async def test_display_callbacks_skip_when_stream_output_disabled(self) -> None:
        """on_content and on_thinking are suppressed, but on_done always fires."""
        handler, on_content, on_thinking, on_done = self._handler_with_mocks(stream_output=False)

        msg = await handler.process_stream(_async_iter([_make_chunk(content="Hello")]))

        # Content is still accumulated even though nothing is displayed.
        assert msg.content == "Hello"
        on_content.assert_not_called()
        on_thinking.assert_not_called()
        on_done.assert_called_once()

    @pytest.mark.asyncio
    async def test_no_callbacks_by_default(self) -> None:
        """process_stream works without set_callbacks (backward compat)."""
        handler = StreamHandler(DisplayConfig(stream_output=True))

        msg = await handler.process_stream(_async_iter([_make_chunk(content="Hello")]))

        assert msg.content == "Hello"

    @pytest.mark.asyncio
    async def test_tool_calls_still_accumulated(self) -> None:
        """Tool-call deltas that share an index merge into a single tool call."""
        handler = StreamHandler(DisplayConfig(stream_output=True))
        chunks = [
            # First delta opens call 0 with its id and name...
            _make_tool_call_chunk(0, tc_id="call_1", name="read_file"),
            # ...the second carries only an argument fragment for the same index.
            _make_tool_call_chunk(0, args='{"path": "foo.py"}'),
        ]

        msg = await handler.process_stream(_async_iter(chunks))

        assert msg.tool_calls is not None
        assert len(msg.tool_calls) == 1
        assert msg.tool_calls[0].function.name == "read_file"
|