- Config extensions: retry backoff, truncation threshold, session persistence - LLM retry with exponential backoff + jitter on transient errors (5xx, connection) - Conversation truncation: drops oldest messages preserving first user + recent N - Session persistence: auto-save/restore with atomic writes, cleanup of old files - Graceful shutdown: SIGTERM handler, cancel() on AgentLoop, save-on-exit - Partial message recovery on mid-stream interruption - New slash commands: /save, /session - 18 new tests (5 retry, 5 truncation, 4 session, 4 integration workflows) - README.md and docs/tools.md documentation Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
231 lines
7.7 KiB
Python
231 lines
7.7 KiB
Python
"""LLM client wrapper for Ollama / OpenAI-compatible endpoints."""
|
|
|
|
import asyncio
import json
import random
from collections.abc import AsyncIterator
from contextlib import aclosing
from typing import Any, Self

import httpx

from app.models.config import LLMConfig
from app.models.message import Message
from app.utils.logging import get_logger
|
|
|
|
# Module-level logger; events in this module are logged by name with keyword
# fields, e.g. logger.warning("llm_retry", attempt=..., backoff_seconds=...).
logger = get_logger(__name__)
|
|
|
|
|
|
# --- Exception hierarchy ---
|
|
|
|
|
|
class LLMError(Exception):
    """Root of the LLM client's exception hierarchy.

    The connection, response and stream error types defined in this module
    all derive from it, so catching ``LLMError`` covers each of them.
    """
|
|
|
|
|
|
class LLMConnectionError(LLMError):
    """The LLM endpoint could not be reached: the connection failed or timed out."""
|
|
|
|
|
|
class LLMResponseError(LLMError):
    """The LLM endpoint answered, but not with a usable response.

    Attributes:
        status_code: HTTP status that triggered the error, or ``None`` when the
            failure is not tied to a particular status (for example, the
            configured model is missing from the endpoint's model listing).
    """

    def __init__(self, message: str, status_code: int | None = None) -> None:
        # Record the status first; Exception.__init__ only stores ``message`` in args.
        self.status_code = status_code
        super().__init__(message)
|
|
|
|
|
|
class LLMStreamError(LLMError):
    """Signals SSE data received while streaming that could not be decoded."""
|
|
|
|
|
|
# --- Client ---
|
|
|
|
|
|
class LLMClient:
    """Async streaming client for OpenAI-compatible chat completions.

    Designed for Ollama but works with any endpoint implementing the
    OpenAI /v1/chat/completions SSE streaming protocol. The optional
    ``preflight_check`` uses Ollama's ``/api/tags`` model listing.
    """

    def __init__(self, config: LLMConfig) -> None:
        """Initialize the LLM client.

        Args:
            config: LLM configuration (model, endpoint, timeout, etc.).
        """
        self._config = config
        # config.timeout is the default for every phase; connecting is capped at 10 s.
        self._client = httpx.AsyncClient(
            base_url=config.endpoint,
            timeout=httpx.Timeout(config.timeout, connect=10.0),
        )

    async def preflight_check(self) -> None:
        """Verify the endpoint is reachable and the configured model is available.

        The model check is best-effort: if the ``/api/tags`` listing cannot be
        parsed or has an unexpected shape, a warning is logged and the check is
        skipped rather than failing.

        Raises:
            LLMConnectionError: If the endpoint is unreachable or the request times out.
            LLMResponseError: If the model is not found or the endpoint returns an error.
        """
        # Check endpoint is reachable. TimeoutException is a subclass of
        # HTTPError, so it must be handled first or its branch is unreachable.
        try:
            response = await self._client.get("/api/tags")
        except httpx.TimeoutException as e:
            raise LLMConnectionError(
                f"Timed out connecting to {self._config.endpoint}."
            ) from e
        except (httpx.HTTPError, OSError) as e:
            raise LLMConnectionError(
                f"Cannot reach Ollama at {self._config.endpoint}. Is Ollama running?"
            ) from e

        if response.status_code != 200:
            raise LLMResponseError(
                f"Ollama returned {response.status_code} from /api/tags.",
                status_code=response.status_code,
            )

        # Check model is available (best-effort, see docstring).
        try:
            data = response.json()
        except ValueError:  # json.JSONDecodeError is a ValueError
            data = None
        models = data.get("models", []) if isinstance(data, dict) else None
        if not isinstance(models, list):
            logger.warning("preflight_parse_error", msg="Could not parse /api/tags response")
            return

        available = [
            m["name"] for m in models if isinstance(m, dict) and isinstance(m.get("name"), str)
        ]
        model = self._config.model

        # Match with or without tag suffix (e.g. "qwen3.5" matches "qwen3.5:latest")
        if not any(model in (name, name.partition(":")[0]) for name in available):
            available_str = ", ".join(available) if available else "(none)"
            raise LLMResponseError(
                f"Model '{model}' not found. Available models: {available_str}"
            )

    @staticmethod
    def _sse_data(line: str) -> str | None:
        """Return the value of an SSE ``data:`` field line, or None for any other line."""
        if not line.startswith("data:"):
            return None
        value = line[len("data:"):]
        # Per the SSE spec, one space after the colon is optional and is stripped.
        return value[1:] if value.startswith(" ") else value

    async def stream_chat(
        self,
        messages: list[Message],
        tools: list[dict[str, Any]] | None = None,
    ) -> AsyncIterator[dict]:
        """Stream a chat completion request, yielding parsed SSE chunks.

        Args:
            messages: Conversation history to send to the model.
            tools: Optional OpenAI function-calling tool schemas.

        Yields:
            Parsed JSON values from each SSE data line, until ``[DONE]`` or EOF.
            Individual malformed lines are logged and skipped.

        Raises:
            LLMConnectionError: On connection or timeout failures.
            LLMResponseError: On a non-200 HTTP status.
            LLMStreamError: If the stream carried data lines but none of them
                were valid JSON.
            LLMError: On any other httpx transport or protocol error.
        """
        payload: dict[str, Any] = {
            "model": self._config.model,
            "messages": [m.to_api_dict() for m in messages],
            "stream": True,
            "temperature": self._config.temperature,
            "max_tokens": self._config.max_tokens,
        }
        if tools:
            payload["tools"] = tools

        data_lines = 0  # non-blank data payloads seen, excluding the [DONE] sentinel
        parsed_lines = 0  # how many of those decoded as JSON
        try:
            async with self._client.stream(
                "POST", self._config.api_path, json=payload
            ) as response:
                if response.status_code != 200:
                    body = await response.aread()
                    raise LLMResponseError(
                        f"LLM returned {response.status_code}: {body.decode(errors='replace')}",
                        status_code=response.status_code,
                    )

                async for line in response.aiter_lines():
                    data = self._sse_data(line)
                    if data is None or not data.strip():
                        continue  # blank lines, comments, non-data fields, empty payloads
                    if data.strip() == "[DONE]":
                        break
                    data_lines += 1
                    try:
                        chunk = json.loads(data)
                    except json.JSONDecodeError:
                        logger.warning("malformed_sse_chunk", data=data[:200])
                        continue
                    parsed_lines += 1
                    yield chunk
        except httpx.ConnectError as e:
            raise LLMConnectionError(f"Cannot connect to LLM endpoint: {e}") from e
        except httpx.TimeoutException as e:
            raise LLMConnectionError(f"LLM request timed out: {e}") from e
        except httpx.HTTPError as e:
            raise LLMError(f"HTTP error communicating with LLM: {e}") from e

        if data_lines and not parsed_lines:
            raise LLMStreamError(
                f"LLM stream contained {data_lines} data line(s) and none were valid JSON."
            )

    @staticmethod
    def _is_retryable(error: LLMError) -> bool:
        """Return True for transient failures: connection/stream errors and 5xx (or status-less) responses."""
        if isinstance(error, LLMResponseError):
            return error.status_code is None or error.status_code >= 500
        return isinstance(error, (LLMConnectionError, LLMStreamError))

    async def stream_chat_with_retry(
        self,
        messages: list[Message],
        tools: list[dict[str, Any]] | None = None,
    ) -> AsyncIterator[dict]:
        """Stream chat with automatic retry on transient errors.

        Retries on LLMConnectionError, LLMStreamError and LLMResponseError with
        status >= 500 (or no status). Does NOT retry on 4xx errors (client-side,
        not transient) or on other LLMError types.

        A failure is retried only if it occurs before the first chunk has been
        yielded. Once output has reached the caller, restarting the request
        would replay it, so the error is re-raised instead.

        Uses exponential backoff with jitter, capped at ``retry_backoff_max``.

        Args:
            messages: Conversation history to send to the model.
            tools: Optional OpenAI function-calling tool schemas.

        Yields:
            Parsed JSON dicts from each SSE data line.

        Raises:
            LLMConnectionError: After exhausting retries on connection failures.
            LLMResponseError: After exhausting retries on server errors, or immediately on 4xx.
            LLMStreamError: After exhausting retries on undecodable streams.
            LLMError: Immediately on other errors, or on any error after output began.
        """
        # A negative setting would otherwise skip every attempt.
        max_retries = max(0, self._config.max_retries)

        for attempt in range(max_retries + 1):
            yielded_any = False
            try:
                # aclosing() closes the inner generator (and its HTTP response)
                # promptly when our consumer stops iterating early.
                async with aclosing(self.stream_chat(messages, tools=tools)) as stream:
                    async for chunk in stream:
                        yielded_any = True
                        yield chunk
                return
            except LLMError as e:
                if yielded_any or attempt >= max_retries or not self._is_retryable(e):
                    raise
                backoff = min(
                    self._config.retry_backoff_base * (2 ** attempt) + random.uniform(0, 1),
                    self._config.retry_backoff_max,
                )
                logger.warning(
                    "llm_retry",
                    attempt=attempt + 1,
                    max_retries=max_retries,
                    backoff_seconds=round(backoff, 2),
                    error=str(e),
                )
            # Sleep outside the except block so the handled exception is released.
            await asyncio.sleep(backoff)

    async def close(self) -> None:
        """Close the underlying HTTP client."""
        await self._client.aclose()

    async def __aenter__(self) -> Self:
        """Enter the async context; the client is already usable after __init__."""
        return self

    async def __aexit__(self, *exc: object) -> None:
        """Close the HTTP client on context exit; exceptions are not suppressed."""
        await self.close()
|