Files
SneakyCode/app/services/llm.py
Phillip Tarrant f0d8ef8f0a feat: add thinking mode toggle to suppress reasoning-only response loops
Adds `llm.thinking` config option (default: true) that when disabled:
- Injects /no_think into the last user message for Qwen 3.x compatibility
- Sends chat_template_kwargs in API payload for backends that support it
- Silently and immediately nudges on reasoning-only responses instead of
  showing warnings and wasting retry iterations

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-11 19:34:36 -05:00

258 lines
8.8 KiB
Python

"""LLM client wrapper for Ollama / OpenAI-compatible endpoints."""
import asyncio
import json
import random
from collections.abc import AsyncIterator
from contextlib import aclosing
from typing import Any, Self

import httpx

from app.models.config import LLMConfig
from app.models.message import Message
from app.utils.logging import get_logger
# Module-level logger; calls in this module pass an event name plus keyword
# fields, e.g. logger.warning("malformed_sse_chunk", data=...).
logger = get_logger(__name__)
# --- Exception hierarchy ---
class LLMError(Exception):
    """Root of every error raised by the LLM client.

    Catch this type to handle any LLM failure regardless of its specific cause.
    """
class LLMConnectionError(LLMError):
    """The LLM endpoint could not be reached, or the request timed out."""
class LLMResponseError(LLMError):
    """The LLM endpoint answered, but with an error (e.g. a non-2xx HTTP status).

    Attributes:
        status_code: HTTP status of the failed response, or ``None`` when the
            error is not tied to a specific status (e.g. a missing model).
    """

    def __init__(self, message: str, status_code: int | None = None) -> None:
        self.status_code: int | None = status_code
        super().__init__(message)
class LLMStreamError(LLMError):
    """The SSE stream from the LLM endpoint contained unusable (malformed) data."""
# --- Client ---
class LLMClient:
    """Async streaming client for OpenAI-compatible chat completions.

    Designed for Ollama but works with any endpoint implementing the
    OpenAI /v1/chat/completions SSE streaming protocol. Use it as an async
    context manager (or call ``close()``) so the HTTP client is released.
    """

    def __init__(self, config: LLMConfig) -> None:
        """Initialize the LLM client.

        Args:
            config: LLM configuration (model, endpoint, timeout, etc.).
        """
        self._config = config
        # Overall timeout comes from config; connection setup is capped at 10s.
        self._client = httpx.AsyncClient(
            base_url=config.endpoint,
            timeout=httpx.Timeout(config.timeout, connect=10.0),
        )

    async def list_models(self) -> list[dict[str, str]]:
        """Query Ollama /api/tags for available models.

        Returns:
            List of dicts with 'name' and 'size' keys. 'size' is stringified;
            missing fields become empty strings.

        Raises:
            LLMConnectionError: If the endpoint is unreachable or times out.
            LLMResponseError: If the endpoint returns a non-200 status or a
                body that is not valid JSON.
        """
        try:
            response = await self._client.get("/api/tags")
        except httpx.HTTPError as e:
            # httpx.TimeoutException subclasses httpx.HTTPError, so timeouts land here too.
            raise LLMConnectionError(f"Failed to list models: {e}") from e

        if response.status_code != 200:
            raise LLMResponseError(
                f"Ollama returned {response.status_code} from /api/tags.",
                status_code=response.status_code,
            )

        try:
            data = response.json()
        except ValueError as e:  # includes json.JSONDecodeError / UnicodeDecodeError
            raise LLMResponseError(
                f"Failed to list models: invalid JSON from /api/tags ({e})",
                status_code=response.status_code,
            ) from e

        models = (data.get("models") or []) if isinstance(data, dict) else []
        return [
            {"name": m.get("name", ""), "size": str(m.get("size", ""))}
            for m in models
            if isinstance(m, dict)
        ]

    async def preflight_check(self) -> None:
        """Verify the endpoint is reachable and the configured model is available.

        An unparseable /api/tags body is logged and treated as a pass, because
        reachability has already been confirmed by that point.

        Raises:
            LLMConnectionError: If the endpoint is unreachable or times out.
            LLMResponseError: If the model is not found or the endpoint returns an error.
        """
        # Check endpoint is reachable
        try:
            response = await self._client.get("/api/tags")
        except httpx.TimeoutException as e:
            # Must precede the HTTPError clause: TimeoutException is an HTTPError
            # subclass, so checking it second would make this branch unreachable.
            raise LLMConnectionError(
                f"Timed out connecting to {self._config.endpoint}."
            ) from e
        except (httpx.HTTPError, OSError) as e:
            raise LLMConnectionError(
                f"Cannot reach Ollama at {self._config.endpoint}. Is Ollama running?"
            ) from e

        if response.status_code != 200:
            raise LLMResponseError(
                f"Ollama returned {response.status_code} from /api/tags.",
                status_code=response.status_code,
            )

        # Check model is available
        try:
            data = response.json()
        except ValueError:
            data = None
        if not isinstance(data, dict):
            logger.warning("preflight_parse_error", msg="Could not parse /api/tags response")
            return

        available = [
            m.get("name", "") for m in (data.get("models") or []) if isinstance(m, dict)
        ]
        model = self._config.model
        # Match with or without tag suffix (e.g. "qwen3.5" matches "qwen3.5:latest").
        # rsplit drops only the final ":tag", keeping registry "host:port/..." prefixes intact.
        if not any(model == name or model == name.rsplit(":", 1)[0] for name in available):
            available_str = ", ".join(available) if available else "(none)"
            raise LLMResponseError(
                f"Model '{model}' not found. Available models: {available_str}"
            )

    def _build_payload(
        self,
        messages: list[Message],
        tools: list[dict[str, Any]] | None,
    ) -> dict[str, Any]:
        """Assemble the JSON body for a streaming chat-completions request.

        Keys from ``config.extra_body`` override the base payload. When thinking
        is disabled, ``enable_thinking=False`` is merged into any
        ``chat_template_kwargs`` from ``extra_body`` rather than being replaced
        by them. An explicit ``enable_thinking`` in ``extra_body`` still wins.
        """
        payload: dict[str, Any] = {
            "model": self._config.model,
            "messages": [m.to_api_dict() for m in messages],
            "stream": True,
            "temperature": self._config.temperature,
            "max_tokens": self._config.max_tokens,
        }
        if tools:
            payload["tools"] = tools

        # Shallow copy so the merge below never mutates the config's own dict.
        extra: dict[str, Any] = dict(self._config.extra_body or {})
        if not self._config.thinking:
            # chat_template_kwargs is honoured by backends that support it.
            template_kwargs: dict[str, Any] = {"enable_thinking": False}
            template_kwargs.update(extra.get("chat_template_kwargs") or {})
            extra["chat_template_kwargs"] = template_kwargs
        # Merge model-specific extra parameters (e.g., reasoning_effort)
        payload.update(extra)
        return payload

    async def stream_chat(
        self,
        messages: list[Message],
        tools: list[dict[str, Any]] | None = None,
    ) -> AsyncIterator[dict]:
        """Stream a chat completion request, yielding parsed SSE chunks.

        Args:
            messages: Conversation history to send to the model.
            tools: Optional OpenAI function-calling tool schemas.

        Yields:
            Parsed JSON dicts from each SSE data line. Malformed lines are
            logged and skipped.

        Raises:
            LLMConnectionError: On connection or timeout failures.
            LLMResponseError: On non-2xx HTTP status.
            LLMStreamError: If the stream contained data lines but every one of
                them was malformed (so nothing was yielded).
            LLMError: On any other HTTP transport error.
        """
        payload = self._build_payload(messages, tools)
        parsed_count = 0
        malformed_count = 0
        try:
            async with self._client.stream(
                "POST", self._config.api_path, json=payload
            ) as response:
                if response.status_code != 200:
                    body = await response.aread()
                    raise LLMResponseError(
                        f"LLM returned {response.status_code}: {body.decode(errors='replace')}",
                        status_code=response.status_code,
                    )
                async for line in response.aiter_lines():
                    if not line.startswith("data:"):
                        continue
                    # Per the SSE spec, a single space after the colon is optional.
                    data = line[len("data:"):].removeprefix(" ")
                    if data.strip() == "[DONE]":
                        break
                    try:
                        chunk = json.loads(data)
                    except json.JSONDecodeError:
                        malformed_count += 1
                        logger.warning("malformed_sse_chunk", data=data[:200])
                        continue
                    parsed_count += 1
                    yield chunk
        except httpx.ConnectError as e:
            raise LLMConnectionError(f"Cannot connect to LLM endpoint: {e}") from e
        except httpx.TimeoutException as e:
            raise LLMConnectionError(f"LLM request timed out: {e}") from e
        except httpx.HTTPError as e:
            raise LLMError(f"HTTP error communicating with LLM: {e}") from e

        if malformed_count and not parsed_count:
            raise LLMStreamError(
                f"All {malformed_count} SSE data line(s) from the LLM were malformed"
            )

    async def stream_chat_with_retry(
        self,
        messages: list[Message],
        tools: list[dict[str, Any]] | None = None,
    ) -> AsyncIterator[dict]:
        """Stream chat with automatic retry on transient errors.

        Retries on LLMConnectionError, LLMStreamError, and LLMResponseError with
        status >= 500 (or no status). Does NOT retry on 4xx errors (client-side,
        not transient). Never retries once a chunk has been yielded, because
        restarting the request would replay output the consumer already received.
        Uses exponential backoff with jitter.

        Args:
            messages: Conversation history to send to the model.
            tools: Optional OpenAI function-calling tool schemas.

        Yields:
            Parsed JSON dicts from each SSE data line.

        Raises:
            LLMConnectionError: After exhausting retries on connection failures,
                or immediately if the failure happens mid-stream.
            LLMResponseError: After exhausting retries on server errors, or immediately on 4xx.
            LLMStreamError: After exhausting retries on fully malformed streams.
        """
        # Clamp so there is always at least one attempt (and last_exception is set).
        max_retries = max(self._config.max_retries, 0)
        last_exception: LLMError | None = None
        for attempt in range(max_retries + 1):
            yielded = False
            try:
                # aclosing guarantees the inner stream (and its HTTP response) is
                # closed promptly even if our consumer stops iterating early.
                async with aclosing(self.stream_chat(messages, tools=tools)) as stream:
                    async for chunk in stream:
                        yielded = True
                        yield chunk
                return
            except LLMResponseError as e:
                if yielded or (e.status_code is not None and e.status_code < 500):
                    raise
                last_exception = e
            except (LLMConnectionError, LLMStreamError) as e:
                if yielded:
                    raise
                last_exception = e

            if attempt < max_retries:
                backoff = min(
                    self._config.retry_backoff_base * (2 ** attempt) + random.uniform(0, 1),
                    self._config.retry_backoff_max,
                )
                logger.warning(
                    "llm_retry",
                    attempt=attempt + 1,
                    max_retries=max_retries,
                    backoff_seconds=round(backoff, 2),
                    error=str(last_exception),
                )
                await asyncio.sleep(backoff)

        # Every non-returning, non-raising attempt records its exception.
        assert last_exception is not None
        raise last_exception

    async def close(self) -> None:
        """Close the underlying HTTP client."""
        await self._client.aclose()

    async def __aenter__(self) -> Self:
        return self

    async def __aexit__(self, *exc: object) -> None:
        await self.close()