Files
SneakyCode/app/services/llm.py
Phillip Tarrant 76ba490aa2 Add Phase 7: polish and hardening — retry, truncation, sessions, shutdown
- Config extensions: retry backoff, truncation threshold, session persistence
- LLM retry with exponential backoff + jitter on transient errors (5xx, connection)
- Conversation truncation: drops oldest messages preserving first user + recent N
- Session persistence: auto-save/restore with atomic writes, cleanup of old files
- Graceful shutdown: SIGTERM handler, cancel() on AgentLoop, save-on-exit
- Partial message recovery on mid-stream interruption
- New slash commands: /save, /session
- 18 new tests (5 retry, 5 truncation, 4 session, 4 integration workflows)
- README.md and docs/tools.md documentation

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-11 10:20:16 -05:00

231 lines
7.7 KiB
Python

"""LLM client wrapper for Ollama / OpenAI-compatible endpoints."""
import asyncio
import json
import random
from collections.abc import AsyncIterator
from contextlib import aclosing
from typing import Any, Self

import httpx

from app.models.config import LLMConfig
from app.models.message import Message
from app.utils.logging import get_logger
# Module-scoped logger; calls in this file pass an event name plus structured
# keyword fields, e.g. logger.warning("llm_retry", attempt=..., error=...).
logger = get_logger(__name__)
# --- Exception hierarchy ---
class LLMError(Exception):
    """Root of this module's exception hierarchy.

    Every LLM-specific exception defined here derives from this class, so
    ``except LLMError`` handles all of them at once.
    """
class LLMConnectionError(LLMError):
    """Network-level failure talking to the LLM endpoint.

    Covers being unable to connect at all as well as requests that time out.
    """
class LLMResponseError(LLMError):
    """The LLM endpoint answered, but not with an acceptable response.

    Attributes:
        status_code: HTTP status of the offending response, or ``None`` when the
            error is not tied to a particular status (for example, the configured
            model is missing from the endpoint's model list).
    """

    def __init__(self, message: str, status_code: int | None = None) -> None:
        self.status_code: int | None = status_code
        super().__init__(message)
class LLMStreamError(LLMError):
    """Server-sent-events (SSE) payload could not be interpreted while streaming."""
# --- Client ---
class LLMClient:
    """Async streaming client for OpenAI-compatible chat completions.

    Designed for Ollama but works with any endpoint implementing the
    OpenAI /v1/chat/completions SSE streaming protocol.
    """

    def __init__(self, config: LLMConfig) -> None:
        """Initialize the LLM client.

        Args:
            config: LLM configuration (model, endpoint, timeout, etc.).
        """
        self._config = config
        # config.timeout is the default for every phase; connecting gets its
        # own fixed 10s budget.
        self._client = httpx.AsyncClient(
            base_url=config.endpoint,
            timeout=httpx.Timeout(config.timeout, connect=10.0),
        )

    async def preflight_check(self) -> None:
        """Verify the endpoint is reachable and the configured model is available.

        Queries Ollama's ``/api/tags`` listing. If the listing cannot be parsed
        (invalid JSON or an unexpected shape), a warning is logged and the model
        check is skipped instead of failing.

        Raises:
            LLMConnectionError: If the endpoint is unreachable or the request times out.
            LLMResponseError: If the model is not found or the endpoint returns an error.
        """
        # Check endpoint is reachable. httpx.TimeoutException is a subclass of
        # httpx.HTTPError, so it must be handled first; otherwise the broader
        # clause would swallow it and report the wrong cause.
        try:
            response = await self._client.get("/api/tags")
        except httpx.TimeoutException as e:
            raise LLMConnectionError(
                f"Timed out connecting to {self._config.endpoint}."
            ) from e
        except (httpx.HTTPError, OSError) as e:
            raise LLMConnectionError(
                f"Cannot reach Ollama at {self._config.endpoint}. Is Ollama running?"
            ) from e
        if response.status_code != 200:
            raise LLMResponseError(
                f"Ollama returned {response.status_code} from /api/tags.",
                status_code=response.status_code,
            )

        # Check model is available. Only {"models": [...]} (or a dict without
        # "models") is understood; any other payload is treated as unparseable.
        try:
            data = response.json()
        except ValueError:  # includes json.JSONDecodeError / bad encoding
            data = None
        models = data.get("models", []) if isinstance(data, dict) else None
        if not isinstance(models, list):
            logger.warning("preflight_parse_error", msg="Could not parse /api/tags response")
            return
        available = [
            m["name"]
            for m in models
            if isinstance(m, dict) and isinstance(m.get("name"), str)
        ]
        model = self._config.model
        # Match with or without tag suffix (e.g. "qwen3.5" matches "qwen3.5:latest")
        if not any(model == name or model == name.split(":")[0] for name in available):
            available_str = ", ".join(available) if available else "(none)"
            raise LLMResponseError(
                f"Model '{model}' not found. Available models: {available_str}"
            )

    async def stream_chat(
        self,
        messages: list[Message],
        tools: list[dict[str, Any]] | None = None,
    ) -> AsyncIterator[dict]:
        """Stream a chat completion request, yielding parsed SSE chunks.

        Only ``data:`` lines are considered. Reading stops at the ``[DONE]``
        sentinel or when the server ends the stream. Malformed lines are logged
        and skipped.

        Args:
            messages: Conversation history to send to the model.
            tools: Optional OpenAI function-calling tool schemas.

        Yields:
            Parsed JSON values from each SSE data line.

        Raises:
            LLMConnectionError: On connection failures, dropped connections, or timeouts.
            LLMResponseError: On a non-200 HTTP status.
            LLMStreamError: If data lines were received but every one was malformed.
            LLMError: On any other httpx error.
        """
        payload: dict[str, Any] = {
            "model": self._config.model,
            "messages": [m.to_api_dict() for m in messages],
            "stream": True,
            "temperature": self._config.temperature,
            "max_tokens": self._config.max_tokens,
        }
        if tools:
            payload["tools"] = tools

        parsed_count = 0
        malformed_count = 0
        try:
            async with self._client.stream(
                "POST", self._config.api_path, json=payload
            ) as response:
                if response.status_code != 200:
                    body = await response.aread()
                    raise LLMResponseError(
                        f"LLM returned {response.status_code}: {body.decode(errors='replace')}",
                        status_code=response.status_code,
                    )
                async for line in response.aiter_lines():
                    if not line.startswith("data:"):
                        continue
                    # SSE allows "data:<value>" or "data: <value>"; a single
                    # space after the colon is not part of the value.
                    data = line[5:].removeprefix(" ")
                    if data.strip() == "[DONE]":
                        break
                    try:
                        chunk = json.loads(data)
                    except json.JSONDecodeError:
                        malformed_count += 1
                        logger.warning("malformed_sse_chunk", data=data[:200])
                        continue
                    parsed_count += 1
                    # Yield outside the try so an exception delivered to the
                    # generator here is never mistaken for a parse error.
                    yield chunk
        except httpx.ConnectError as e:
            raise LLMConnectionError(f"Cannot connect to LLM endpoint: {e}") from e
        except httpx.TimeoutException as e:
            raise LLMConnectionError(f"LLM request timed out: {e}") from e
        except (httpx.NetworkError, httpx.RemoteProtocolError) as e:
            # Connection reset / closed by the server: transient, like ConnectError.
            raise LLMConnectionError(f"Connection to LLM endpoint lost: {e}") from e
        except httpx.HTTPError as e:
            raise LLMError(f"HTTP error communicating with LLM: {e}") from e

        if malformed_count and not parsed_count:
            raise LLMStreamError(
                f"LLM stream contained {malformed_count} malformed SSE chunk(s) "
                "and no valid data."
            )

    @staticmethod
    def _is_retryable(error: LLMError) -> bool:
        """Return False only for an LLMResponseError whose status is below 500."""
        if isinstance(error, LLMResponseError):
            return error.status_code is None or error.status_code >= 500
        return True

    async def stream_chat_with_retry(
        self,
        messages: list[Message],
        tools: list[dict[str, Any]] | None = None,
    ) -> AsyncIterator[dict]:
        """Stream chat with automatic retry on transient errors.

        Retries on LLMConnectionError, LLMStreamError, and LLMResponseError with
        status >= 500 (or no status). Does NOT retry on 4xx errors (client-side,
        not transient). Uses exponential backoff with jitter, capped at
        ``retry_backoff_max``.

        Retries happen only before the first chunk is yielded. Once output has
        reached the caller, restarting the request would replay the response from
        the beginning and duplicate content, so later errors are raised as-is.

        Args:
            messages: Conversation history to send to the model.
            tools: Optional OpenAI function-calling tool schemas.

        Yields:
            Parsed JSON dicts from each SSE data line.

        Raises:
            LLMConnectionError: After exhausting retries on connection failures,
                or immediately if the failure happens after output has started.
            LLMResponseError: After exhausting retries on server errors, or immediately on 4xx.
            LLMStreamError: After exhausting retries on entirely malformed streams.
            LLMError: Immediately, for other HTTP errors raised by ``stream_chat``.
        """
        # Clamp so a negative setting still makes one attempt instead of
        # falling through to `raise None`.
        max_retries = max(0, self._config.max_retries)
        for attempt in range(max_retries + 1):
            yielded_any = False
            try:
                # aclosing() closes the inner generator (and its HTTP response)
                # deterministically, even if our caller stops iterating early.
                async with aclosing(self.stream_chat(messages, tools=tools)) as stream:
                    async for chunk in stream:
                        yielded_any = True
                        yield chunk
                return
            except (LLMConnectionError, LLMResponseError, LLMStreamError) as e:
                if yielded_any or attempt >= max_retries or not self._is_retryable(e):
                    raise
                last_error = e

            backoff = min(
                self._config.retry_backoff_base * (2 ** attempt) + random.uniform(0, 1),
                self._config.retry_backoff_max,
            )
            logger.warning(
                "llm_retry",
                attempt=attempt + 1,
                max_retries=max_retries,
                backoff_seconds=round(backoff, 2),
                error=str(last_error),
            )
            await asyncio.sleep(backoff)

    async def close(self) -> None:
        """Close the underlying HTTP client."""
        await self._client.aclose()

    async def __aenter__(self) -> Self:
        """Enter an ``async with`` block; the client is ready from construction."""
        return self

    async def __aexit__(self, *exc: object) -> None:
        """Close the HTTP client on leaving the ``async with`` block."""
        await self.close()