# Source: SneakyCode/app/services/streaming.py (Python; 194 lines, 6.8 KiB in the original listing)

"""Streaming response handler — accumulates SSE chunks into a complete Message."""
import time
from collections.abc import AsyncIterator, Callable
from app.models.config import DisplayConfig
from app.models.message import Message
from app.models.tool_call import ToolCall, ToolCallFunction
from app.utils.logging import get_logger
from app.utils.token_counter import TokenUsage
# Logger for this module, obtained from the project's get_logger helper using the module name.
logger = get_logger(__name__)
# Minimum interval between content update callbacks (seconds).
# While streaming, an on_content call is skipped when less than this much
# time has passed since the previous one.
_UPDATE_THROTTLE_INTERVAL = 0.1
class StreamHandler:
    """Processes an SSE chunk stream and produces a complete assistant Message.

    Accumulates content deltas, reasoning deltas, and tool call fragments,
    and records token usage when a chunk carries it. Notifies the UI via
    optional callbacks during streaming.
    """

    def __init__(self, display_config: DisplayConfig) -> None:
        """Create a handler with empty accumulators and no callbacks.

        Args:
            display_config: Display settings; its ``stream_output`` flag
                decides whether on_thinking/on_content fire during streaming.
        """
        self._display_config = display_config
        self._accumulated_content: str = ""
        self._accumulated_reasoning: str = ""
        # Tool call fragments keyed by their index in the stream; each entry
        # holds the "id", "name", and "arguments" strings gathered so far.
        self._tool_calls: dict[int, dict[str, str]] = {}
        self._usage: TokenUsage | None = None
        self._on_content: Callable[[str], None] | None = None
        self._on_thinking: Callable[[], None] | None = None
        self._on_done: Callable[[], None] | None = None

    def set_callbacks(
        self,
        on_content: Callable[[str], None] | None = None,
        on_thinking: Callable[[], None] | None = None,
        on_done: Callable[[], None] | None = None,
    ) -> None:
        """Set UI callbacks for streaming updates.

        Every call replaces all three callbacks; an omitted argument clears
        the corresponding callback.

        Args:
            on_content: Called with accumulated content string (throttled to ~100ms).
            on_thinking: Called once when first reasoning token arrives.
            on_done: Called when streaming is complete.
        """
        self._on_content = on_content
        self._on_thinking = on_thinking
        self._on_done = on_done

    async def process_stream(self, chunk_iter: AsyncIterator[dict]) -> Message:
        """Consume a chunk iterator and return the final Message.

        on_thinking and on_content only fire when ``stream_output`` is enabled
        in the display config; on_done fires once the iterator is exhausted
        regardless of that flag. If the iterator raises, the exception
        propagates and on_done is not called.

        Args:
            chunk_iter: Async iterator of parsed SSE chunk dicts.

        Returns:
            Complete assistant Message with accumulated content and tool calls.
            ``content`` and ``tool_calls`` are None when nothing was accumulated.
        """
        thinking_notified = False
        # time.monotonic() has an undefined reference point, so seed the clock
        # with -inf: the first content update is then never throttled.
        last_update_time = float("-inf")

        async for chunk in chunk_iter:
            self._process_chunk(chunk)
            if not self._display_config.stream_output:
                continue

            # Notify thinking once: reasoning has arrived but no content yet.
            if (
                not thinking_notified
                and not self._accumulated_content
                and self._accumulated_reasoning
                and self._on_thinking is not None
            ):
                self._on_thinking()
                thinking_notified = True

            # Throttled content updates
            if self._accumulated_content and self._on_content is not None:
                now = time.monotonic()
                if now - last_update_time >= _UPDATE_THROTTLE_INTERVAL:
                    self._on_content(self._accumulated_content)
                    last_update_time = now

        # Final content update (ensures last chunk is shown even if the
        # throttle suppressed the most recent in-loop update).
        if (
            self._display_config.stream_output
            and self._accumulated_content
            and self._on_content is not None
        ):
            self._on_content(self._accumulated_content)

        if self._on_done is not None:
            self._on_done()

        tool_calls = self._build_tool_calls() or None
        return Message(
            role="assistant",
            content=self._accumulated_content or None,
            tool_calls=tool_calls,
        )

    def _process_chunk(self, chunk: dict) -> None:
        """Extract content, tool calls, and usage from a single SSE chunk.

        Only the first entry of ``choices`` is read. Keys that are present
        with a null value (``delta``, ``tool_calls``, ``function``, ``id``,
        ``index``) are treated the same as missing keys.
        """
        choices = chunk.get("choices") or []
        if choices:
            # `or {}` / `or []` rather than a .get() default: the default only
            # applies to a missing key, not to a key holding an explicit null.
            delta = choices[0].get("delta") or {}

            content_piece = delta.get("content")
            if content_piece:
                self._accumulated_content += content_piece

            reasoning_piece = delta.get("reasoning")
            if reasoning_piece:
                self._accumulated_reasoning += reasoning_piece

            for tc_delta in delta.get("tool_calls") or []:
                idx = tc_delta.get("index")
                if idx is None:
                    # Missing or null index falls back to slot 0.
                    idx = 0
                entry = self._tool_calls.setdefault(
                    idx, {"id": "", "name": "", "arguments": ""}
                )
                # The id replaces any earlier value; name and arguments are
                # appended because they arrive as fragments.
                if tc_delta.get("id"):
                    entry["id"] = tc_delta["id"]
                func = tc_delta.get("function") or {}
                if func.get("name"):
                    entry["name"] += func["name"]
                if func.get("arguments"):
                    entry["arguments"] += func["arguments"]

        # Usage is read independently of choices, so a chunk that carries only
        # usage data is still recorded.
        usage_data = chunk.get("usage")
        if usage_data:
            self._usage = TokenUsage(
                prompt_tokens=usage_data.get("prompt_tokens", 0),
                completion_tokens=usage_data.get("completion_tokens", 0),
                total_tokens=usage_data.get("total_tokens", 0),
            )

    def _build_tool_calls(self) -> list[ToolCall]:
        """Convert accumulated tool call fragments into sorted ToolCall list.

        Returns:
            One ToolCall per accumulated index, ordered by index; an empty
            list when no tool call fragments were seen.
        """
        return [
            ToolCall(
                id=self._tool_calls[idx]["id"],
                type="function",
                function=ToolCallFunction(
                    name=self._tool_calls[idx]["name"],
                    arguments=self._tool_calls[idx]["arguments"],
                ),
            )
            for idx in sorted(self._tool_calls)
        ]

    def get_partial_message(self) -> Message | None:
        """Return whatever content/tool_calls have been accumulated so far.

        Returns:
            An assistant Message built from the current accumulators, or None
            when there is neither content nor any tool call.
        """
        tool_calls = self._build_tool_calls() or None
        if not self._accumulated_content and not tool_calls:
            return None
        return Message(
            role="assistant",
            content=self._accumulated_content or None,
            tool_calls=tool_calls,
        )

    @property
    def usage(self) -> TokenUsage | None:
        """Token usage reported by the API, if available."""
        return self._usage

    @property
    def had_reasoning_only(self) -> bool:
        """True if the model produced reasoning tokens but no content or tool calls."""
        return (
            bool(self._accumulated_reasoning)
            and not self._accumulated_content
            and not self._tool_calls
        )

    def reset(self) -> None:
        """Clear all accumulators for the next turn.

        Also clears the three UI callbacks, so set_callbacks must be called
        again before the next stream if notifications are wanted.
        """
        self._accumulated_content = ""
        self._accumulated_reasoning = ""
        self._tool_calls.clear()
        self._usage = None
        self._on_content = None
        self._on_thinking = None
        self._on_done = None