"""Streaming response handler — accumulates SSE chunks into a complete Message."""

import time
from collections.abc import AsyncIterator, Callable

from app.models.config import DisplayConfig
from app.models.message import Message
from app.models.tool_call import ToolCall, ToolCallFunction
from app.utils.logging import get_logger
from app.utils.token_counter import TokenUsage

logger = get_logger(__name__)

# Minimum interval between content update callbacks (seconds)
_UPDATE_THROTTLE_INTERVAL = 0.1


class StreamHandler:
    """Processes an SSE chunk stream and produces a complete assistant Message.

    Accumulates content deltas, reasoning deltas and tool call fragments.
    Notifies the UI via optional callbacks during streaming.
    """

    def __init__(self, display_config: DisplayConfig) -> None:
        self._display_config = display_config
        self._accumulated_content: str = ""
        self._accumulated_reasoning: str = ""
        # Tool call fragments keyed by the stream's tool-call "index"; each
        # entry holds the "id", "name" and "arguments" strings seen so far.
        self._tool_calls: dict[int, dict[str, str]] = {}
        self._usage: TokenUsage | None = None
        self._on_content: Callable[[str], None] | None = None
        self._on_thinking: Callable[[], None] | None = None
        self._on_done: Callable[[], None] | None = None

    def set_callbacks(
        self,
        on_content: Callable[[str], None] | None = None,
        on_thinking: Callable[[], None] | None = None,
        on_done: Callable[[], None] | None = None,
    ) -> None:
        """Set UI callbacks for streaming updates.

        Args:
            on_content: Called with the accumulated content string (throttled
                to ~100ms, plus one final call after the stream ends).
                Only used when ``display_config.stream_output`` is enabled.
            on_thinking: Called once when reasoning tokens have arrived before
                any content. Only used when ``stream_output`` is enabled.
            on_done: Called once the stream has been fully consumed
                (regardless of ``stream_output``).
        """
        self._on_content = on_content
        self._on_thinking = on_thinking
        self._on_done = on_done

    async def process_stream(self, chunk_iter: AsyncIterator[dict]) -> Message:
        """Consume a chunk iterator and return the final Message.

        Args:
            chunk_iter: Async iterator of parsed SSE chunk dicts.

        Returns:
            Complete assistant Message with accumulated content and tool calls.
            ``content`` is None when no content was produced; ``tool_calls`` is
            None when no tool calls were produced.

        Note:
            If ``chunk_iter`` raises, the exception propagates and ``on_done``
            is not called; ``get_partial_message`` can be used to recover what
            was accumulated up to that point.
        """
        thinking_notified = False
        last_update_time = 0.0

        async for chunk in chunk_iter:
            self._process_chunk(chunk)

            if not self._display_config.stream_output:
                continue

            # Notify thinking once: reasoning has arrived but no content yet.
            if (
                not thinking_notified
                and not self._accumulated_content
                and self._accumulated_reasoning
                and self._on_thinking is not None
            ):
                self._on_thinking()
                thinking_notified = True

            # Throttled content updates
            if self._accumulated_content and self._on_content is not None:
                now = time.monotonic()
                if now - last_update_time >= _UPDATE_THROTTLE_INTERVAL:
                    self._on_content(self._accumulated_content)
                    last_update_time = now

        # Final content update (ensures the last throttled-away chunk is shown)
        if (
            self._display_config.stream_output
            and self._accumulated_content
            and self._on_content is not None
        ):
            self._on_content(self._accumulated_content)

        if self._on_done is not None:
            self._on_done()

        tool_calls = self._build_tool_calls() or None
        return Message(
            role="assistant",
            content=self._accumulated_content or None,
            tool_calls=tool_calls,
        )

    def _process_chunk(self, chunk: dict) -> None:
        """Extract content, reasoning, tool calls and usage from a single SSE chunk.

        Only the first choice is consumed. Fields that are present but set to
        JSON ``null`` are treated the same as missing fields.
        """
        choices = chunk.get("choices") or []
        if choices:
            # `or {}` / `or []`: a key may be present with an explicit null
            # value, which `.get(key, default)` would return as None.
            delta = choices[0].get("delta") or {}

            content_piece = delta.get("content")
            if content_piece:
                self._accumulated_content += content_piece

            reasoning_piece = delta.get("reasoning")
            if reasoning_piece:
                self._accumulated_reasoning += reasoning_piece

            for tc_delta in delta.get("tool_calls") or []:
                self._merge_tool_call_delta(tc_delta)

        usage_data = chunk.get("usage")
        if usage_data:
            self._usage = TokenUsage(
                prompt_tokens=usage_data.get("prompt_tokens") or 0,
                completion_tokens=usage_data.get("completion_tokens") or 0,
                total_tokens=usage_data.get("total_tokens") or 0,
            )

    def _merge_tool_call_delta(self, tc_delta: dict) -> None:
        """Fold one streamed tool-call fragment into the accumulator for its index.

        The id is replaced whenever a non-empty one arrives; name and argument
        fragments are concatenated in arrival order.
        """
        # A missing or null index is treated as 0 so keys stay sortable ints.
        idx = tc_delta.get("index") or 0
        entry = self._tool_calls.setdefault(
            idx, {"id": "", "name": "", "arguments": ""}
        )

        if tc_delta.get("id"):
            entry["id"] = tc_delta["id"]

        func = tc_delta.get("function") or {}
        if func.get("name"):
            entry["name"] += func["name"]
        if func.get("arguments"):
            entry["arguments"] += func["arguments"]

    def _build_tool_calls(self) -> list[ToolCall]:
        """Convert accumulated tool call fragments into a ToolCall list ordered by index."""
        return [
            ToolCall(
                id=self._tool_calls[idx]["id"],
                type="function",
                function=ToolCallFunction(
                    name=self._tool_calls[idx]["name"],
                    arguments=self._tool_calls[idx]["arguments"],
                ),
            )
            for idx in sorted(self._tool_calls)
        ]

    def get_partial_message(self) -> Message | None:
        """Return whatever content/tool_calls have been accumulated so far.

        Returns None when neither content nor tool calls have been accumulated.
        """
        tool_calls = self._build_tool_calls() or None
        if not self._accumulated_content and not tool_calls:
            return None
        return Message(
            role="assistant",
            content=self._accumulated_content or None,
            tool_calls=tool_calls,
        )

    @property
    def usage(self) -> TokenUsage | None:
        """Token usage reported by the API, if available."""
        return self._usage

    @property
    def had_reasoning_only(self) -> bool:
        """True if the model produced reasoning tokens but no content or tool calls."""
        return (
            bool(self._accumulated_reasoning)
            and not self._accumulated_content
            and not self._tool_calls
        )

    def reset(self) -> None:
        """Clear all accumulators and registered callbacks for the next turn.

        Callbacks must be registered again via ``set_callbacks`` after a reset.
        """
        self._accumulated_content = ""
        self._accumulated_reasoning = ""
        self._tool_calls.clear()
        self._usage = None
        self._on_content = None
        self._on_thinking = None
        self._on_done = None