first commit

This commit is contained in:
2025-11-24 23:10:55 -06:00
commit 8315fa51c9
279 changed files with 74600 additions and 0 deletions

View File

@@ -0,0 +1,462 @@
"""
Tests for Replicate API client.
Tests cover initialization, prompt formatting, generation,
retry logic, and error handling.
"""
import pytest
from unittest.mock import patch, MagicMock
from app.ai.replicate_client import (
ReplicateClient,
ReplicateResponse,
ReplicateClientError,
ReplicateAPIError,
ReplicateRateLimitError,
ReplicateTimeoutError,
ModelType,
)
class TestReplicateClientInit:
    """Tests for ReplicateClient construction and token/model resolution."""

    @staticmethod
    def _cfg(token=None, model=None):
        """Build a config mock exposing the given token and model override."""
        return MagicMock(replicate_api_token=token, REPLICATE_MODEL=model)

    @patch('app.ai.replicate_client.get_config')
    def test_init_with_token(self, mock_config):
        """An explicitly passed API token wins even when config has none."""
        mock_config.return_value = self._cfg()
        client = ReplicateClient(api_token="test_token_123")
        assert client.api_token == "test_token_123"
        assert client.model == ReplicateClient.DEFAULT_MODEL.value

    @patch('app.ai.replicate_client.get_config')
    def test_init_from_config(self, mock_config):
        """Token and model fall back to values supplied by config."""
        mock_config.return_value = self._cfg(token="config_token", model="custom/model")
        client = ReplicateClient()
        assert client.api_token == "config_token"
        assert client.model == "custom/model"

    @patch('app.ai.replicate_client.get_config')
    def test_init_missing_token(self, mock_config):
        """Construction without any token source raises ReplicateClientError."""
        mock_config.return_value = self._cfg()
        with pytest.raises(ReplicateClientError) as exc_info:
            ReplicateClient()
        assert "API token not configured" in str(exc_info.value)

    @patch('app.ai.replicate_client.get_config')
    def test_init_custom_model(self, mock_config):
        """A model passed to the constructor overrides the default."""
        mock_config.return_value = self._cfg(token="token")
        client = ReplicateClient(model="meta/llama-2-70b")
        assert client.model == "meta/llama-2-70b"
class TestPromptFormatting:
    """Tests for Llama-3 chat-template prompt construction."""

    @staticmethod
    def _client():
        """Return a client backed by a minimal valid config mock."""
        with patch('app.ai.replicate_client.get_config') as mock_config:
            mock_config.return_value = MagicMock(
                replicate_api_token="token", REPLICATE_MODEL=None
            )
            return ReplicateClient()

    @patch('app.ai.replicate_client.get_config')
    def test_format_prompt_user_only(self, mock_config):
        """Without a system prompt, only user/assistant headers appear."""
        mock_config.return_value = MagicMock(
            replicate_api_token="token", REPLICATE_MODEL=None
        )
        result = ReplicateClient()._format_llama_prompt("Hello world")
        for fragment in (
            "<|begin_of_text|>",
            "<|start_header_id|>user<|end_header_id|>",
            "Hello world",
            "<|start_header_id|>assistant<|end_header_id|>",
        ):
            assert fragment in result
        # A system header must not be emitted when no system prompt is given.
        assert "system<|end_header_id|>" not in result

    @patch('app.ai.replicate_client.get_config')
    def test_format_prompt_with_system(self, mock_config):
        """A system prompt adds a system header before the user turn."""
        mock_config.return_value = MagicMock(
            replicate_api_token="token", REPLICATE_MODEL=None
        )
        result = ReplicateClient()._format_llama_prompt(
            "What is 2+2?",
            system_prompt="You are a helpful assistant."
        )
        for fragment in (
            "<|start_header_id|>system<|end_header_id|>",
            "You are a helpful assistant.",
            "<|start_header_id|>user<|end_header_id|>",
            "What is 2+2?",
        ):
            assert fragment in result
class TestGenerate:
    """Tests for text generation."""

    @staticmethod
    def _cfg():
        """Config mock with a valid token and no model override."""
        return MagicMock(replicate_api_token="token", REPLICATE_MODEL=None)

    @patch('app.ai.replicate_client.get_config')
    @patch('app.ai.replicate_client.replicate')
    def test_generate_success(self, mock_replicate, mock_config):
        """Streamed chunks are concatenated into a populated response."""
        mock_config.return_value = self._cfg()
        # Simulate the token-stream iterator Replicate returns.
        mock_replicate.run.return_value = iter(["Hello ", "world ", "!"])
        result = ReplicateClient().generate("Say hello")
        assert isinstance(result, ReplicateResponse)
        assert result.text == "Hello world !"
        assert result.tokens_used > 0
        assert result.model == ReplicateClient.DEFAULT_MODEL.value
        assert result.generation_time > 0

    @patch('app.ai.replicate_client.get_config')
    @patch('app.ai.replicate_client.replicate')
    def test_generate_with_parameters(self, mock_replicate, mock_config):
        """Caller-supplied sampling parameters reach the Replicate call."""
        mock_config.return_value = self._cfg()
        mock_replicate.run.return_value = iter(["Response"])
        ReplicateClient().generate(
            prompt="Test",
            system_prompt="Be concise",
            max_tokens=100,
            temperature=0.5,
            top_p=0.8,
            timeout=60
        )
        # Inspect the kwargs forwarded to replicate.run.
        sent = mock_replicate.run.call_args[1]['input']
        assert sent['max_tokens'] == 100
        assert sent['temperature'] == 0.5
        assert sent['top_p'] == 0.8

    @patch('app.ai.replicate_client.get_config')
    @patch('app.ai.replicate_client.replicate')
    def test_generate_string_response(self, mock_replicate, mock_config):
        """A plain string (non-streaming) result is used verbatim."""
        mock_config.return_value = self._cfg()
        mock_replicate.run.return_value = "Direct string response"
        result = ReplicateClient().generate("Test")
        assert result.text == "Direct string response"
class TestRetryLogic:
    """Tests for retry behavior and error-class mapping."""

    @staticmethod
    def _cfg():
        """Config mock with a valid token and no model override."""
        return MagicMock(replicate_api_token="token", REPLICATE_MODEL=None)

    @patch('app.ai.replicate_client.get_config')
    @patch('app.ai.replicate_client.replicate')
    @patch('app.ai.replicate_client.time.sleep')
    def test_retry_on_rate_limit(self, mock_sleep, mock_replicate, mock_config):
        """A 429-style failure is retried and the next attempt succeeds."""
        mock_config.return_value = self._cfg()
        mock_replicate.exceptions.ReplicateError = Exception
        # Attempt 1 hits the rate limit; attempt 2 streams a result.
        mock_replicate.run.side_effect = [
            Exception("Rate limit exceeded 429"),
            iter(["Success"]),
        ]
        result = ReplicateClient().generate("Test")
        assert result.text == "Success"
        # Backoff must have slept between attempts.
        assert mock_sleep.called

    @patch('app.ai.replicate_client.get_config')
    @patch('app.ai.replicate_client.replicate')
    @patch('app.ai.replicate_client.time.sleep')
    def test_max_retries_exceeded(self, mock_sleep, mock_replicate, mock_config):
        """Persistent rate limiting surfaces ReplicateRateLimitError."""
        mock_config.return_value = self._cfg()
        mock_replicate.exceptions.ReplicateError = Exception
        # Every attempt fails with a rate-limit error.
        mock_replicate.run.side_effect = Exception("Rate limit exceeded 429")
        with pytest.raises(ReplicateRateLimitError):
            ReplicateClient().generate("Test")
        # Exactly MAX_RETRIES attempts should have been made.
        assert mock_replicate.run.call_count == ReplicateClient.MAX_RETRIES

    @patch('app.ai.replicate_client.get_config')
    @patch('app.ai.replicate_client.replicate')
    def test_timeout_error(self, mock_replicate, mock_config):
        """Timeout-flavored failures map to ReplicateTimeoutError."""
        mock_config.return_value = self._cfg()
        mock_replicate.exceptions.ReplicateError = Exception
        mock_replicate.run.side_effect = Exception("Request timeout")
        with pytest.raises(ReplicateTimeoutError):
            ReplicateClient().generate("Test")

    @patch('app.ai.replicate_client.get_config')
    @patch('app.ai.replicate_client.replicate')
    def test_api_error(self, mock_replicate, mock_config):
        """Unclassified failures map to the generic ReplicateAPIError."""
        mock_config.return_value = self._cfg()
        mock_replicate.exceptions.ReplicateError = Exception
        mock_replicate.run.side_effect = Exception("Invalid model")
        with pytest.raises(ReplicateAPIError):
            ReplicateClient().generate("Test")
class TestValidation:
    """Tests for API key validation."""

    @patch('app.ai.replicate_client.get_config')
    @patch('app.ai.replicate_client.replicate')
    def test_validate_api_key_success(self, mock_replicate, mock_config):
        """validate_api_key returns True when a model lookup succeeds."""
        mock_config.return_value = MagicMock(
            replicate_api_token="valid_token",
            REPLICATE_MODEL=None,
        )
        mock_replicate.models.get.return_value = MagicMock()
        assert ReplicateClient().validate_api_key() is True

    @patch('app.ai.replicate_client.get_config')
    @patch('app.ai.replicate_client.replicate')
    def test_validate_api_key_failure(self, mock_replicate, mock_config):
        """validate_api_key returns False when the lookup raises."""
        mock_config.return_value = MagicMock(
            replicate_api_token="invalid_token",
            REPLICATE_MODEL=None,
        )
        mock_replicate.models.get.side_effect = Exception("Invalid API token")
        assert ReplicateClient().validate_api_key() is False
class TestResponseDataclass:
    """Tests for the ReplicateResponse dataclass."""

    def test_response_creation(self):
        """All constructor fields are stored and readable."""
        fields = {
            "text": "Hello world",
            "tokens_used": 50,
            "model": "meta/llama-3-8b",
            "generation_time": 1.5,
        }
        result = ReplicateResponse(**fields)
        for name, expected in fields.items():
            assert getattr(result, name) == expected

    def test_response_immutability(self):
        """Expected attributes exist on an instance (attribute access works)."""
        result = ReplicateResponse(
            text="Test",
            tokens_used=10,
            model="test",
            generation_time=0.5,
        )
        # Only attribute presence is checked here, not frozen-ness.
        assert hasattr(result, 'text')
        assert hasattr(result, 'tokens_used')
class TestModelType:
    """Tests for the ModelType enum and multi-model selection."""

    def test_model_type_values(self):
        """Each enum member maps to its expected Replicate model slug."""
        expected = {
            ModelType.LLAMA_3_8B: "meta/meta-llama-3-8b-instruct",
            ModelType.CLAUDE_HAIKU: "anthropic/claude-3.5-haiku",
            ModelType.CLAUDE_SONNET: "anthropic/claude-3.5-sonnet",
            ModelType.CLAUDE_SONNET_4: "anthropic/claude-sonnet-4",
        }
        for member, slug in expected.items():
            assert member.value == slug

    @patch('app.ai.replicate_client.get_config')
    def test_init_with_model_type_enum(self, mock_config):
        """Passing a ModelType member sets both model string and model_type."""
        mock_config.return_value = MagicMock(
            replicate_api_token="token",
            REPLICATE_MODEL=None,
        )
        client = ReplicateClient(model=ModelType.CLAUDE_HAIKU)
        assert client.model == "anthropic/claude-3.5-haiku"
        assert client.model_type == ModelType.CLAUDE_HAIKU

    @patch('app.ai.replicate_client.get_config')
    def test_is_claude_model(self, mock_config):
        """_is_claude_model is False for Llama and True for Claude models."""
        mock_config.return_value = MagicMock(
            replicate_api_token="token",
            REPLICATE_MODEL=None,
        )
        cases = [
            (ModelType.LLAMA_3_8B, False),
            (ModelType.CLAUDE_HAIKU, True),
            (ModelType.CLAUDE_SONNET, True),
        ]
        for model, is_claude in cases:
            assert ReplicateClient(model=model)._is_claude_model() is is_claude
class TestClaudeModels:
    """Tests for routing generation to Claude models via Replicate."""

    @staticmethod
    def _cfg():
        """Config mock with a valid token and no model override."""
        return MagicMock(replicate_api_token="token", REPLICATE_MODEL=None)

    @patch('app.ai.replicate_client.get_config')
    @patch('app.ai.replicate_client.replicate')
    def test_generate_with_claude_haiku(self, mock_replicate, mock_config):
        """Claude Haiku gets a raw prompt, not a Llama-templated one."""
        mock_config.return_value = self._cfg()
        mock_replicate.run.return_value = iter(["Claude ", "response"])
        result = ReplicateClient(model=ModelType.CLAUDE_HAIKU).generate("Test prompt")
        assert result.text == "Claude response"
        assert result.model == "anthropic/claude-3.5-haiku"
        # Claude path should pass the prompt through without Llama tokens.
        sent = mock_replicate.run.call_args[1]['input']
        assert "prompt" in sent
        assert "<|begin_of_text|>" not in sent['prompt']

    @patch('app.ai.replicate_client.get_config')
    @patch('app.ai.replicate_client.replicate')
    def test_generate_with_claude_system_prompt(self, mock_replicate, mock_config):
        """The system prompt is sent as a separate system_prompt parameter."""
        mock_config.return_value = self._cfg()
        mock_replicate.run.return_value = iter(["Response"])
        ReplicateClient(model=ModelType.CLAUDE_SONNET).generate(
            prompt="User message",
            system_prompt="You are a DM"
        )
        sent = mock_replicate.run.call_args[1]['input']
        assert sent['system_prompt'] == "You are a DM"
        assert sent['prompt'] == "User message"

    @patch('app.ai.replicate_client.get_config')
    @patch('app.ai.replicate_client.replicate')
    def test_model_specific_defaults(self, mock_replicate, mock_config):
        """Per-model defaults (Sonnet: 1024 tokens, temp 0.9) are applied."""
        mock_config.return_value = self._cfg()
        mock_replicate.run.return_value = iter(["Response"])
        ReplicateClient(model=ModelType.CLAUDE_SONNET).generate("Test")
        sent = mock_replicate.run.call_args[1]['input']
        assert sent['max_tokens'] == 1024
        assert sent['temperature'] == 0.9

    @patch('app.ai.replicate_client.get_config')
    @patch('app.ai.replicate_client.replicate')
    def test_model_override_in_generate(self, mock_replicate, mock_config):
        """A per-call model override supersedes the constructor model."""
        mock_config.return_value = self._cfg()
        mock_replicate.run.return_value = iter(["Response"])
        # Construct with Llama, then override to Claude for this call.
        client = ReplicateClient(model=ModelType.LLAMA_3_8B)
        result = client.generate("Test", model=ModelType.CLAUDE_HAIKU)
        assert result.model == "anthropic/claude-3.5-haiku"
        # The overridden model slug must be the first positional argument.
        assert mock_replicate.run.call_args[0][0] == "anthropic/claude-3.5-haiku"