"""LLM client wrapper for Ollama / OpenAI-compatible endpoints."""

import asyncio
import json
import random
from collections.abc import AsyncIterator
from contextlib import aclosing
from typing import Any, Self

import httpx

from app.models.config import LLMConfig
from app.models.message import Message
from app.utils.logging import get_logger

logger = get_logger(__name__)


# --- Exception hierarchy ---


class LLMError(Exception):
    """Base exception for LLM client errors."""


class LLMConnectionError(LLMError):
    """Connection or timeout failure when reaching the LLM endpoint."""


class LLMResponseError(LLMError):
    """Error response from the LLM endpoint (non-2xx status or missing model)."""

    def __init__(self, message: str, status_code: int | None = None) -> None:
        super().__init__(message)
        # HTTP status of the failing response; None when the error was not
        # produced by an HTTP status (e.g. the configured model is missing).
        self.status_code = status_code


class LLMStreamError(LLMError):
    """Malformed SSE data during streaming."""


# --- Client ---


class LLMClient:
    """Async streaming client for OpenAI-compatible chat completions.

    Designed for Ollama but works with any endpoint implementing the
    OpenAI /v1/chat/completions SSE streaming protocol.
    """

    def __init__(self, config: LLMConfig) -> None:
        """Initialize the LLM client.

        Args:
            config: LLM configuration (model, endpoint, timeout, etc.).
        """
        self._config = config
        self._client = httpx.AsyncClient(
            base_url=config.endpoint,
            timeout=httpx.Timeout(config.timeout, connect=10.0),
        )

    async def preflight_check(self) -> None:
        """Verify the endpoint is reachable and the configured model is available.

        If the model listing cannot be parsed, a warning is logged and the
        model check is skipped rather than failing.

        Raises:
            LLMConnectionError: If the endpoint is unreachable or times out.
            LLMResponseError: If the model is not found or the endpoint returns an error.
        """
        # Check endpoint is reachable. TimeoutException subclasses
        # httpx.HTTPError, so it must be handled first or its branch is dead.
        try:
            response = await self._client.get("/api/tags")
        except httpx.TimeoutException as e:
            raise LLMConnectionError(
                f"Timed out connecting to {self._config.endpoint}."
            ) from e
        except (httpx.ConnectError, httpx.HTTPError, OSError) as e:
            raise LLMConnectionError(
                f"Cannot reach Ollama at {self._config.endpoint}. Is Ollama running?"
            ) from e

        if response.status_code != 200:
            raise LLMResponseError(
                f"Ollama returned {response.status_code} from /api/tags.",
                status_code=response.status_code,
            )

        # Check model is available. An unexpected payload shape is treated the
        # same as unparseable JSON: warn and skip the model check.
        try:
            data = response.json()
        except ValueError:
            logger.warning("preflight_parse_error", msg="Could not parse /api/tags response")
            return
        if not isinstance(data, dict):
            logger.warning("preflight_parse_error", msg="Could not parse /api/tags response")
            return

        available = [
            entry.get("name", "")
            for entry in data.get("models") or []
            if isinstance(entry, dict)
        ]
        model = self._config.model
        # Match with or without tag suffix (e.g. "qwen3.5" matches "qwen3.5:latest")
        if not any(model == name or model == name.split(":")[0] for name in available):
            available_str = ", ".join(available) if available else "(none)"
            raise LLMResponseError(
                f"Model '{model}' not found. Available models: {available_str}"
            )

    async def stream_chat(
        self,
        messages: list[Message],
        tools: list[dict[str, Any]] | None = None,
    ) -> AsyncIterator[dict[str, Any]]:
        """Stream a chat completion request, yielding parsed SSE chunks.

        Lines that are not SSE ``data:`` fields are ignored; the stream ends
        at ``data: [DONE]`` or when the server closes it. Individual malformed
        data lines are logged and skipped.

        Args:
            messages: Conversation history to send to the model.
            tools: Optional OpenAI function-calling tool schemas.

        Yields:
            Parsed JSON dicts from each SSE data line.

        Raises:
            LLMConnectionError: On connection, timeout, or dropped-connection failures.
            LLMResponseError: On non-200 HTTP status.
            LLMStreamError: If data lines were received but every one was malformed.
            LLMError: On any other httpx error.
        """
        payload: dict[str, Any] = {
            "model": self._config.model,
            "messages": [m.to_api_dict() for m in messages],
            "stream": True,
            "temperature": self._config.temperature,
            "max_tokens": self._config.max_tokens,
        }
        if tools:
            payload["tools"] = tools

        parsed_count = 0
        malformed_count = 0
        try:
            async with self._client.stream(
                "POST", self._config.api_path, json=payload
            ) as response:
                if response.status_code != 200:
                    body = await response.aread()
                    raise LLMResponseError(
                        f"LLM returned {response.status_code}: {body.decode(errors='replace')}",
                        status_code=response.status_code,
                    )

                async for line in response.aiter_lines():
                    # SSE field syntax: the single space after "data:" is optional.
                    if not line.startswith("data:"):
                        continue
                    data = line[5:].removeprefix(" ")
                    if data.strip() == "[DONE]":
                        break
                    try:
                        chunk = json.loads(data)
                    except json.JSONDecodeError:
                        malformed_count += 1
                        logger.warning("malformed_sse_chunk", data=data[:200])
                        continue
                    parsed_count += 1
                    yield chunk
        except httpx.ConnectError as e:
            raise LLMConnectionError(f"Cannot connect to LLM endpoint: {e}") from e
        except httpx.TimeoutException as e:
            raise LLMConnectionError(f"LLM request timed out: {e}") from e
        except (httpx.NetworkError, httpx.RemoteProtocolError) as e:
            # Dropped or reset connections are transient transport failures;
            # classify them as connection errors so the retry wrapper retries them.
            raise LLMConnectionError(f"Connection to LLM endpoint failed: {e}") from e
        except httpx.HTTPError as e:
            raise LLMError(f"HTTP error communicating with LLM: {e}") from e

        # Honour the documented contract: fail only when nothing was usable.
        if malformed_count and not parsed_count:
            raise LLMStreamError(
                f"All {malformed_count} SSE data line(s) from the LLM were malformed."
            )

    async def stream_chat_with_retry(
        self,
        messages: list[Message],
        tools: list[dict[str, Any]] | None = None,
    ) -> AsyncIterator[dict[str, Any]]:
        """Stream chat with automatic retry on transient errors.

        Retries on LLMConnectionError, LLMStreamError and LLMResponseError with
        status >= 500 (or no status). Does NOT retry on 4xx errors (client-side,
        not transient). Retries happen only before the first chunk has been
        yielded: restarting after that would replay output the caller already
        received, so such errors are re-raised immediately. Uses exponential
        backoff with jitter, capped at ``retry_backoff_max``.

        Args:
            messages: Conversation history to send to the model.
            tools: Optional OpenAI function-calling tool schemas.

        Yields:
            Parsed JSON dicts from each SSE data line.

        Raises:
            LLMConnectionError: After exhausting retries on connection failures,
                or immediately if the failure happens mid-stream.
            LLMResponseError: After exhausting retries on server errors, or immediately on 4xx.
            LLMStreamError: After exhausting retries on fully malformed streams.
        """
        # A negative setting would otherwise skip the loop entirely.
        max_retries = max(self._config.max_retries, 0)
        error: LLMError

        for attempt in range(max_retries + 1):
            yielded = False
            try:
                # aclosing() closes the inner generator (and its HTTP response)
                # promptly if the caller stops iterating early.
                async with aclosing(self.stream_chat(messages, tools=tools)) as stream:
                    async for chunk in stream:
                        yielded = True
                        yield chunk
                return
            except LLMResponseError as e:
                if e.status_code is not None and e.status_code < 500:
                    raise
                if yielded or attempt == max_retries:
                    raise
                error = e
            except (LLMConnectionError, LLMStreamError) as e:
                if yielded or attempt == max_retries:
                    raise
                error = e

            backoff = min(
                self._config.retry_backoff_base * (2**attempt) + random.uniform(0, 1),
                self._config.retry_backoff_max,
            )
            logger.warning(
                "llm_retry",
                attempt=attempt + 1,
                max_retries=max_retries,
                backoff_seconds=round(backoff, 2),
                error=str(error),
            )
            await asyncio.sleep(backoff)

    async def close(self) -> None:
        """Close the underlying HTTP client."""
        await self._client.aclose()

    async def __aenter__(self) -> Self:
        return self

    async def __aexit__(self, *exc: object) -> None:
        await self.close()