194 lines
6.8 KiB
Python
194 lines
6.8 KiB
Python
"""Streaming response handler — accumulates SSE chunks into a complete Message."""
|
|
|
|
import time
|
|
from collections.abc import AsyncIterator, Callable
|
|
|
|
from app.models.config import DisplayConfig
|
|
from app.models.message import Message
|
|
from app.models.tool_call import ToolCall, ToolCallFunction
|
|
from app.utils.logging import get_logger
|
|
from app.utils.token_counter import TokenUsage
|
|
|
|
# Minimum gap, in seconds, between two intermediate on_content callbacks
# while a response is streaming (100 ms).
_UPDATE_THROTTLE_INTERVAL = 100e-3

logger = get_logger(__name__)
|
|
|
|
|
|
class StreamHandler:
    """Processes an SSE chunk stream and produces a complete assistant Message.

    Accumulates content deltas, reasoning deltas, tool call fragments and
    token usage. Notifies the UI via optional callbacks during streaming.

    The handler is stateful: accumulated data persists across calls to
    ``process_stream`` until ``reset()`` is called.
    """

    def __init__(self, display_config: DisplayConfig) -> None:
        self._display_config = display_config
        self._accumulated_content: str = ""
        self._accumulated_reasoning: str = ""
        # Tool call fragments keyed by the chunk-provided "index"; each entry
        # holds the concatenated "id", "name" and "arguments" strings.
        self._tool_calls: dict[int, dict[str, str]] = {}
        self._usage: TokenUsage | None = None
        self._on_content: Callable[[str], None] | None = None
        self._on_thinking: Callable[[], None] | None = None
        self._on_done: Callable[[], None] | None = None

    def set_callbacks(
        self,
        on_content: Callable[[str], None] | None = None,
        on_thinking: Callable[[], None] | None = None,
        on_done: Callable[[], None] | None = None,
    ) -> None:
        """Set UI callbacks for streaming updates.

        ``on_content`` and ``on_thinking`` only fire while
        ``display_config.stream_output`` is truthy; ``on_done`` always fires
        once the stream has been fully consumed.

        Args:
            on_content: Called with accumulated content string (throttled to
                ``_UPDATE_THROTTLE_INTERVAL`` seconds, plus one final call).
            on_thinking: Called once when reasoning tokens have arrived but no
                content has yet.
            on_done: Called when streaming is complete.
        """
        self._on_content = on_content
        self._on_thinking = on_thinking
        self._on_done = on_done

    async def process_stream(self, chunk_iter: AsyncIterator[dict]) -> Message:
        """Consume a chunk iterator and return the final Message.

        If ``chunk_iter`` raises, the exception propagates and ``on_done`` is
        not called; ``get_partial_message()`` still returns whatever was
        accumulated up to that point.

        Args:
            chunk_iter: Async iterator of parsed SSE chunk dicts.

        Returns:
            Complete assistant Message with accumulated content and tool calls.
            ``content`` is None when no content was streamed, and
            ``tool_calls`` is None when no tool calls were streamed.
        """
        thinking_notified = False
        # time.monotonic() has an undefined reference point, so seeding with
        # 0.0 could throttle the very first update; -inf never does.
        last_update_time = float("-inf")

        async for chunk in chunk_iter:
            self._process_chunk(chunk)

            if not self._display_config.stream_output:
                continue

            # Notify thinking once: reasoning has started, content has not.
            if (
                not thinking_notified
                and not self._accumulated_content
                and self._accumulated_reasoning
                and self._on_thinking is not None
            ):
                self._on_thinking()
                thinking_notified = True

            # Throttled content updates.
            if self._accumulated_content and self._on_content is not None:
                now = time.monotonic()
                if now - last_update_time >= _UPDATE_THROTTLE_INTERVAL:
                    self._on_content(self._accumulated_content)
                    last_update_time = now

        # Final content update: the last chunk may have landed inside a
        # throttle window and not been pushed yet.
        if (
            self._display_config.stream_output
            and self._accumulated_content
            and self._on_content is not None
        ):
            self._on_content(self._accumulated_content)

        if self._on_done is not None:
            self._on_done()

        return Message(
            role="assistant",
            content=self._accumulated_content or None,
            tool_calls=self._build_tool_calls() or None,
        )

    def _process_chunk(self, chunk: dict) -> None:
        """Extract content, reasoning, tool calls, and usage from one SSE chunk.

        Only the first entry of ``choices`` is read. A key that is absent and
        a key explicitly set to ``null`` are treated the same way; iterating or
        calling ``.get`` on ``None`` would otherwise raise.
        """
        choices = chunk.get("choices") or []
        if choices:
            delta = choices[0].get("delta") or {}

            content_piece = delta.get("content")
            if content_piece:
                self._accumulated_content += content_piece

            reasoning_piece = delta.get("reasoning")
            if reasoning_piece:
                self._accumulated_reasoning += reasoning_piece

            for tc_delta in delta.get("tool_calls") or []:
                idx = tc_delta.get("index") or 0
                entry = self._tool_calls.setdefault(
                    idx, {"id": "", "name": "", "arguments": ""}
                )
                # The id normally arrives on the first fragment only; keep the
                # latest non-empty value and never store None.
                tc_id = tc_delta.get("id")
                if tc_id:
                    entry["id"] = tc_id
                func = tc_delta.get("function") or {}
                # Name and arguments may be split across fragments; concatenate.
                name_piece = func.get("name")
                if name_piece:
                    entry["name"] += name_piece
                args_piece = func.get("arguments")
                if args_piece:
                    entry["arguments"] += args_piece

        usage_data = chunk.get("usage")
        if usage_data:
            self._usage = TokenUsage(
                prompt_tokens=usage_data.get("prompt_tokens") or 0,
                completion_tokens=usage_data.get("completion_tokens") or 0,
                total_tokens=usage_data.get("total_tokens") or 0,
            )

    def _build_tool_calls(self) -> list[ToolCall]:
        """Convert accumulated tool call fragments into a ToolCall list ordered by index."""
        return [
            ToolCall(
                id=entry["id"],
                type="function",
                function=ToolCallFunction(
                    name=entry["name"],
                    arguments=entry["arguments"],
                ),
            )
            for _, entry in sorted(self._tool_calls.items())
        ]

    def get_partial_message(self) -> Message | None:
        """Return whatever content/tool_calls have been accumulated so far.

        Returns None when neither content nor tool calls have accumulated.
        """
        tool_calls = self._build_tool_calls() or None
        if not self._accumulated_content and not tool_calls:
            return None
        return Message(
            role="assistant",
            content=self._accumulated_content or None,
            tool_calls=tool_calls,
        )

    @property
    def usage(self) -> TokenUsage | None:
        """Token usage from the most recent chunk carrying a ``usage`` object, if any."""
        return self._usage

    @property
    def had_reasoning_only(self) -> bool:
        """True if the model produced reasoning tokens but no content or tool calls."""
        return (
            bool(self._accumulated_reasoning)
            and not self._accumulated_content
            and not self._tool_calls
        )

    def reset(self) -> None:
        """Clear all accumulators and registered callbacks for the next turn."""
        self._accumulated_content = ""
        self._accumulated_reasoning = ""
        self._tool_calls.clear()
        self._usage = None
        self._on_content = None
        self._on_thinking = None
        self._on_done = None
|