156 lines
5.6 KiB
Python
156 lines
5.6 KiB
Python
"""Streaming response handler — accumulates SSE chunks into a complete Message."""
|
|
|
|
from collections.abc import AsyncIterator
|
|
|
|
from rich.live import Live
|
|
from rich.markdown import Markdown
|
|
from rich.panel import Panel
|
|
|
|
from app.models.config import DisplayConfig
|
|
from app.models.message import Message
|
|
from app.models.tool_call import ToolCall, ToolCallFunction
|
|
from app.utils.logging import console, get_logger
|
|
from app.utils.token_counter import TokenUsage
|
|
|
|
# Module-level logger obtained from the project's get_logger helper, keyed by this module's name.
logger = get_logger(__name__)
|
|
|
|
|
|
class StreamHandler:
    """Turn a stream of parsed SSE chunks into a live Rich display and a final Message.

    Content deltas, reasoning deltas, and tool-call fragments are accumulated
    as chunks arrive. While streaming, and only when
    ``DisplayConfig.stream_output`` is enabled, the accumulated content is
    rendered as Markdown inside an "Assistant" panel. When the stream is
    exhausted, a complete assistant ``Message`` is returned.

    Accumulators are not cleared automatically; call :meth:`reset` before
    reusing the handler for another turn.
    """

    def __init__(self, display_config: DisplayConfig) -> None:
        """Initialize the stream handler.

        Args:
            display_config: Display preferences; its ``stream_output`` flag
                controls whether content is rendered live while streaming.
        """
        self._display_config = display_config
        # Visible assistant text, concatenated across deltas.
        self._accumulated_content: str = ""
        # Reasoning ("thinking") text. It is only used to decide whether to
        # show a placeholder before any visible content has arrived.
        self._accumulated_reasoning: str = ""
        # Tool-call fragments keyed by stream index -> {"id", "name", "arguments"}.
        self._tool_calls: dict[int, dict[str, str]] = {}
        # Usage from the most recent chunk that carried a non-empty "usage" object.
        self._usage: TokenUsage | None = None

    async def process_stream(self, chunk_iter: AsyncIterator[dict]) -> Message:
        """Consume a chunk iterator, rendering live output and returning the final Message.

        The Rich ``Live`` region is always entered. When ``stream_output`` is
        disabled it is simply never updated.

        Args:
            chunk_iter: Async iterator of parsed SSE chunk dicts.

        Returns:
            Assistant Message with the accumulated content (``None`` if empty)
            and the reassembled tool calls (``None`` if there were none).
        """
        # Text of the panel currently shown, used to skip redundant redraws.
        last_rendered: str | None = None

        with Live(console=console, refresh_per_second=8) as live:
            async for chunk in chunk_iter:
                self._process_chunk(chunk)

                # Show a placeholder while the model is reasoning and no
                # visible content has arrived yet.
                display_text = self._accumulated_content
                if not display_text and self._accumulated_reasoning:
                    display_text = "*thinking...*"

                if not (display_text and self._display_config.stream_output):
                    continue

                # Reasoning and tool-call fragments usually leave the visible
                # text unchanged. Building a Markdown renderable parses the
                # whole text, so skip redraws that would show the same thing.
                if display_text == last_rendered:
                    continue
                last_rendered = display_text

                # Render inside the same Assistant panel used for final output
                # so the live display and final frame are visually consistent.
                live.update(
                    Panel(
                        Markdown(display_text),
                        title="Assistant",
                        border_style="green",
                        expand=True,
                    )
                )

        tool_calls = self._build_tool_calls() or None
        return Message(
            role="assistant",
            content=self._accumulated_content or None,
            tool_calls=tool_calls,
        )

    def _process_chunk(self, chunk: dict) -> None:
        """Fold one SSE chunk's content, reasoning, tool-call, and usage data into state.

        Args:
            chunk: Parsed JSON dict from one SSE data line.
        """
        # Some OpenAI-compatible servers send explicit nulls instead of
        # omitting keys, so normalise with ``or`` rather than .get defaults.
        choices = chunk.get("choices") or []
        if choices:
            delta = choices[0].get("delta") or {}

            content_piece = delta.get("content")
            if content_piece:
                self._accumulated_content += content_piece

            # Reasoning tokens (e.g. qwen3.5 thinking mode). Servers disagree
            # on the key: "reasoning" vs "reasoning_content" (DeepSeek-style).
            reasoning_piece = delta.get("reasoning") or delta.get("reasoning_content")
            if reasoning_piece:
                self._accumulated_reasoning += reasoning_piece

            # Tool-call deltas, accumulated by stream index.
            for position, tc_delta in enumerate(delta.get("tool_calls") or []):
                self._merge_tool_call_delta(position, tc_delta)

        # Token usage (typically in the final chunk).
        usage_data = chunk.get("usage")
        if usage_data:
            self._usage = TokenUsage(
                prompt_tokens=usage_data.get("prompt_tokens") or 0,
                completion_tokens=usage_data.get("completion_tokens") or 0,
                total_tokens=usage_data.get("total_tokens") or 0,
            )

    def _merge_tool_call_delta(self, position: int, tc_delta: dict) -> None:
        """Merge one streamed tool-call fragment into the per-index accumulator.

        ``id`` replaces the stored value when present. ``function.name`` and
        ``function.arguments`` are appended to what has been seen so far.

        Args:
            position: Position of this fragment within its chunk's
                ``tool_calls`` list. It is used as the index when the server
                omits ``index`` (or sends it as null).
            tc_delta: One element of ``delta["tool_calls"]``.
        """
        idx = tc_delta.get("index")
        if idx is None:
            # Without an explicit index, treat each list position as its own
            # call so several whole calls in one chunk are not concatenated.
            idx = position

        entry = self._tool_calls.setdefault(idx, {"id": "", "name": "", "arguments": ""})
        if tc_delta.get("id"):
            entry["id"] = tc_delta["id"]

        func = tc_delta.get("function") or {}
        if func.get("name"):
            entry["name"] += func["name"]
        if func.get("arguments"):
            entry["arguments"] += func["arguments"]

    def _build_tool_calls(self) -> list[ToolCall]:
        """Convert accumulated tool-call fragments into a ToolCall list.

        Returns:
            ToolCall objects ordered by stream index; empty if none were seen.
        """
        # Indices are unique dict keys, so sorting the items never compares
        # the entry dicts themselves.
        return [
            ToolCall(
                id=entry["id"],
                type="function",
                function=ToolCallFunction(
                    name=entry["name"],
                    arguments=entry["arguments"],
                ),
            )
            for _, entry in sorted(self._tool_calls.items())
        ]

    @property
    def usage(self) -> TokenUsage | None:
        """Token usage reported by the API, if available."""
        return self._usage

    def reset(self) -> None:
        """Clear all accumulators for the next turn."""
        self._accumulated_content = ""
        self._accumulated_reasoning = ""
        self._tool_calls.clear()
        self._usage = None