"""Unit tests for LLM retry with exponential backoff.""" from unittest.mock import AsyncMock, patch import pytest from app.models.config import LLMConfig from app.models.message import Message from app.services.llm import LLMClient, LLMConnectionError, LLMResponseError @pytest.fixture def llm_config() -> LLMConfig: return LLMConfig( model="test-model", endpoint="http://localhost:11434", max_retries=3, retry_backoff_base=0.01, retry_backoff_max=0.05, ) @pytest.fixture def client(llm_config: LLMConfig) -> LLMClient: return LLMClient(llm_config) @pytest.fixture def messages() -> list[Message]: return [Message(role="user", content="Hello")] class TestRetry: @pytest.mark.asyncio async def test_succeeds_without_retry(self, client: LLMClient, messages: list[Message]) -> None: """Successful stream doesn't retry.""" call_count = 0 async def fake_stream(*args, **kwargs): nonlocal call_count call_count += 1 yield {"choices": [{"delta": {"content": "Hi"}}]} client.stream_chat = fake_stream # type: ignore[assignment] collected = [] async for chunk in client.stream_chat_with_retry(messages): collected.append(chunk) assert len(collected) == 1 assert call_count == 1 @pytest.mark.asyncio async def test_retries_on_connection_error(self, client: LLMClient, messages: list[Message]) -> None: """Retries on LLMConnectionError, then succeeds.""" call_count = 0 async def flaky_stream(*args, **kwargs): nonlocal call_count call_count += 1 if call_count < 3: raise LLMConnectionError("Connection refused") yield {"choices": [{"delta": {"content": "OK"}}]} client.stream_chat = flaky_stream # type: ignore[assignment] with patch("app.services.llm.asyncio.sleep", new_callable=AsyncMock): collected = [] async for chunk in client.stream_chat_with_retry(messages): collected.append(chunk) assert len(collected) == 1 assert call_count == 3 @pytest.mark.asyncio async def test_retries_on_5xx(self, client: LLMClient, messages: list[Message]) -> None: """Retries on 5xx LLMResponseError.""" call_count = 0 async def server_error_stream(*args, **kwargs): nonlocal call_count call_count += 1 if call_count < 2: raise LLMResponseError("Internal Server Error", status_code=500) yield {"choices": [{"delta": {"content": "OK"}}]} client.stream_chat = server_error_stream # type: ignore[assignment] with patch("app.services.llm.asyncio.sleep", new_callable=AsyncMock): collected = [] async for chunk in client.stream_chat_with_retry(messages): collected.append(chunk) assert len(collected) == 1 assert call_count == 2 @pytest.mark.asyncio async def test_no_retry_on_4xx(self, client: LLMClient, messages: list[Message]) -> None: """Does NOT retry on 4xx errors — raises immediately.""" async def bad_request_stream(*args, **kwargs): raise LLMResponseError("Bad Request", status_code=400) yield # pragma: no cover — make this an async generator client.stream_chat = bad_request_stream # type: ignore[assignment] with pytest.raises(LLMResponseError, match="Bad Request"): async for _ in client.stream_chat_with_retry(messages): pass # pragma: no cover @pytest.mark.asyncio async def test_respects_max_retries(self, client: LLMClient, messages: list[Message]) -> None: """After exhausting retries, re-raises the last exception.""" async def always_fail(*args, **kwargs): raise LLMConnectionError("Down forever") yield # pragma: no cover client.stream_chat = always_fail # type: ignore[assignment] with patch("app.services.llm.asyncio.sleep", new_callable=AsyncMock) as mock_sleep: with pytest.raises(LLMConnectionError, match="Down forever"): async for _ in client.stream_chat_with_retry(messages): pass # pragma: no cover # Should have slept max_retries times (3 retries after initial attempt) assert mock_sleep.call_count == 3