"""
Tests for the Replicate API client.

Tests cover initialization, prompt formatting, generation,
retry logic, and error handling.
"""

import pytest
from unittest.mock import patch, MagicMock

from app.ai.replicate_client import (
    ReplicateClient,
    ReplicateResponse,
    ReplicateClientError,
    ReplicateAPIError,
    ReplicateRateLimitError,
    ReplicateTimeoutError,
    ModelType,
)

class TestReplicateClientInit:
    """Tests for ReplicateClient initialization."""

    @patch('app.ai.replicate_client.get_config')
    def test_init_with_token(self, mock_get_config):
        """An explicitly supplied API token is used as-is."""
        mock_get_config.return_value = MagicMock(
            replicate_api_token=None,
            REPLICATE_MODEL=None,
        )

        client = ReplicateClient(api_token="test_token_123")

        assert client.api_token == "test_token_123"
        assert client.model == ReplicateClient.DEFAULT_MODEL.value

    @patch('app.ai.replicate_client.get_config')
    def test_init_from_config(self, mock_get_config):
        """Token and model fall back to the values held in config."""
        mock_get_config.return_value = MagicMock(
            replicate_api_token="config_token",
            REPLICATE_MODEL="custom/model",
        )

        client = ReplicateClient()

        assert client.api_token == "config_token"
        assert client.model == "custom/model"

    @patch('app.ai.replicate_client.get_config')
    def test_init_missing_token(self, mock_get_config):
        """Construction fails when no API token is available anywhere."""
        mock_get_config.return_value = MagicMock(
            replicate_api_token=None,
            REPLICATE_MODEL=None,
        )

        with pytest.raises(ReplicateClientError) as exc_info:
            ReplicateClient()

        assert "API token not configured" in str(exc_info.value)

    @patch('app.ai.replicate_client.get_config')
    def test_init_custom_model(self, mock_get_config):
        """An explicit model argument overrides the default model."""
        mock_get_config.return_value = MagicMock(
            replicate_api_token="token",
            REPLICATE_MODEL=None,
        )

        client = ReplicateClient(model="meta/llama-2-70b")

        assert client.model == "meta/llama-2-70b"

class TestPromptFormatting:
    """Tests for Llama-3 prompt formatting."""

    @staticmethod
    def _make_client(mock_get_config):
        """Build a client backed by a minimal stubbed config."""
        mock_get_config.return_value = MagicMock(
            replicate_api_token="token",
            REPLICATE_MODEL=None,
        )
        return ReplicateClient()

    @patch('app.ai.replicate_client.get_config')
    def test_format_prompt_user_only(self, mock_get_config):
        """Without a system prompt only user/assistant headers appear."""
        client = self._make_client(mock_get_config)

        formatted = client._format_llama_prompt("Hello world")

        assert "<|begin_of_text|>" in formatted
        assert "<|start_header_id|>user<|end_header_id|>" in formatted
        assert "Hello world" in formatted
        assert "<|start_header_id|>assistant<|end_header_id|>" in formatted
        # A system header must not be emitted without a system prompt.
        assert "system<|end_header_id|>" not in formatted

    @patch('app.ai.replicate_client.get_config')
    def test_format_prompt_with_system(self, mock_get_config):
        """System and user prompts each get their own header section."""
        client = self._make_client(mock_get_config)

        formatted = client._format_llama_prompt(
            "What is 2+2?",
            system_prompt="You are a helpful assistant.",
        )

        assert "<|start_header_id|>system<|end_header_id|>" in formatted
        assert "You are a helpful assistant." in formatted
        assert "<|start_header_id|>user<|end_header_id|>" in formatted
        assert "What is 2+2?" in formatted

class TestGenerate:
    """Tests for text generation."""

    @staticmethod
    def _stub_config(mock_get_config):
        """Point get_config() at a stub carrying a valid token."""
        mock_get_config.return_value = MagicMock(
            replicate_api_token="token",
            REPLICATE_MODEL=None,
        )

    @patch('app.ai.replicate_client.get_config')
    @patch('app.ai.replicate_client.replicate')
    def test_generate_success(self, mock_replicate, mock_get_config):
        """A streamed iterator of chunks is joined into one response."""
        self._stub_config(mock_get_config)
        mock_replicate.run.return_value = iter(["Hello ", "world ", "!"])

        response = ReplicateClient().generate("Say hello")

        assert isinstance(response, ReplicateResponse)
        assert response.text == "Hello world !"
        assert response.tokens_used > 0
        assert response.model == ReplicateClient.DEFAULT_MODEL.value
        assert response.generation_time > 0

    @patch('app.ai.replicate_client.get_config')
    @patch('app.ai.replicate_client.replicate')
    def test_generate_with_parameters(self, mock_replicate, mock_get_config):
        """Custom generation parameters are forwarded to replicate.run."""
        self._stub_config(mock_get_config)
        mock_replicate.run.return_value = iter(["Response"])

        ReplicateClient().generate(
            prompt="Test",
            system_prompt="Be concise",
            max_tokens=100,
            temperature=0.5,
            top_p=0.8,
            timeout=60,
        )

        # Inspect the input payload handed to replicate.run.
        payload = mock_replicate.run.call_args[1]['input']
        assert payload['max_tokens'] == 100
        assert payload['temperature'] == 0.5
        assert payload['top_p'] == 0.8

    @patch('app.ai.replicate_client.get_config')
    @patch('app.ai.replicate_client.replicate')
    def test_generate_string_response(self, mock_replicate, mock_get_config):
        """A plain string result (non-streaming) is used verbatim."""
        self._stub_config(mock_get_config)
        mock_replicate.run.return_value = "Direct string response"

        response = ReplicateClient().generate("Test")

        assert response.text == "Direct string response"

class TestRetryLogic:
    """Tests for retry and error handling."""

    @staticmethod
    def _stub_config(mock_get_config):
        """Install a stub config holding a valid API token."""
        mock_get_config.return_value = MagicMock(
            replicate_api_token="token",
            REPLICATE_MODEL=None,
        )

    @patch('app.ai.replicate_client.get_config')
    @patch('app.ai.replicate_client.replicate')
    @patch('app.ai.replicate_client.time.sleep')
    def test_retry_on_rate_limit(self, mock_sleep, mock_replicate, mock_get_config):
        """A rate-limit failure is retried (with backoff) until success."""
        self._stub_config(mock_get_config)

        # First call raises a rate-limit error, the retry succeeds.
        mock_replicate.exceptions.ReplicateError = Exception
        mock_replicate.run.side_effect = [
            Exception("Rate limit exceeded 429"),
            iter(["Success"]),
        ]

        response = ReplicateClient().generate("Test")

        assert response.text == "Success"
        assert mock_sleep.called  # Backoff must have been applied.

    @patch('app.ai.replicate_client.get_config')
    @patch('app.ai.replicate_client.replicate')
    @patch('app.ai.replicate_client.time.sleep')
    def test_max_retries_exceeded(self, mock_sleep, mock_replicate, mock_get_config):
        """Persistent rate-limit failures surface after MAX_RETRIES attempts."""
        self._stub_config(mock_get_config)

        # Every attempt fails with a rate-limit error.
        mock_replicate.exceptions.ReplicateError = Exception
        mock_replicate.run.side_effect = Exception("Rate limit exceeded 429")

        client = ReplicateClient()

        with pytest.raises(ReplicateRateLimitError):
            client.generate("Test")

        # The client must have attempted exactly MAX_RETRIES calls.
        assert mock_replicate.run.call_count == ReplicateClient.MAX_RETRIES

    @patch('app.ai.replicate_client.get_config')
    @patch('app.ai.replicate_client.replicate')
    def test_timeout_error(self, mock_replicate, mock_get_config):
        """Timeout-style failures map to ReplicateTimeoutError."""
        self._stub_config(mock_get_config)

        mock_replicate.exceptions.ReplicateError = Exception
        mock_replicate.run.side_effect = Exception("Request timeout")

        client = ReplicateClient()

        with pytest.raises(ReplicateTimeoutError):
            client.generate("Test")

    @patch('app.ai.replicate_client.get_config')
    @patch('app.ai.replicate_client.replicate')
    def test_api_error(self, mock_replicate, mock_get_config):
        """Other failures map to the generic ReplicateAPIError."""
        self._stub_config(mock_get_config)

        mock_replicate.exceptions.ReplicateError = Exception
        mock_replicate.run.side_effect = Exception("Invalid model")

        client = ReplicateClient()

        with pytest.raises(ReplicateAPIError):
            client.generate("Test")

class TestValidation:
    """Tests for API key validation."""

    @patch('app.ai.replicate_client.get_config')
    @patch('app.ai.replicate_client.replicate')
    def test_validate_api_key_success(self, mock_replicate, mock_get_config):
        """validate_api_key() is True when the model lookup succeeds."""
        mock_get_config.return_value = MagicMock(
            replicate_api_token="valid_token",
            REPLICATE_MODEL=None,
        )
        mock_replicate.models.get.return_value = MagicMock()

        client = ReplicateClient()

        assert client.validate_api_key() is True

    @patch('app.ai.replicate_client.get_config')
    @patch('app.ai.replicate_client.replicate')
    def test_validate_api_key_failure(self, mock_replicate, mock_get_config):
        """validate_api_key() is False when the model lookup raises."""
        mock_get_config.return_value = MagicMock(
            replicate_api_token="invalid_token",
            REPLICATE_MODEL=None,
        )
        mock_replicate.models.get.side_effect = Exception("Invalid API token")

        client = ReplicateClient()

        assert client.validate_api_key() is False

class TestResponseDataclass:
    """Tests for ReplicateResponse dataclass."""

    def test_response_creation(self):
        """All constructor fields round-trip through attribute access."""
        resp = ReplicateResponse(
            text="Hello world",
            tokens_used=50,
            model="meta/llama-3-8b",
            generation_time=1.5,
        )

        assert resp.text == "Hello world"
        assert resp.tokens_used == 50
        assert resp.model == "meta/llama-3-8b"
        assert resp.generation_time == 1.5

    def test_response_immutability(self):
        """Response fields are exposed as readable attributes."""
        resp = ReplicateResponse(
            text="Test",
            tokens_used=10,
            model="test",
            generation_time=0.5,
        )

        # The dataclass must expose its fields for attribute access.
        assert hasattr(resp, 'text')
        assert hasattr(resp, 'tokens_used')

class TestModelType:
    """Tests for ModelType enum and multi-model support."""

    def test_model_type_values(self):
        """ModelType members map to the expected Replicate model slugs."""
        assert ModelType.LLAMA_3_8B.value == "meta/meta-llama-3-8b-instruct"
        assert ModelType.CLAUDE_HAIKU.value == "anthropic/claude-3.5-haiku"
        assert ModelType.CLAUDE_SONNET.value == "anthropic/claude-3.5-sonnet"
        assert ModelType.CLAUDE_SONNET_4.value == "anthropic/claude-sonnet-4"

    @patch('app.ai.replicate_client.get_config')
    def test_init_with_model_type_enum(self, mock_get_config):
        """Passing a ModelType enum resolves to its string value."""
        mock_get_config.return_value = MagicMock(
            replicate_api_token="token",
            REPLICATE_MODEL=None,
        )

        client = ReplicateClient(model=ModelType.CLAUDE_HAIKU)

        assert client.model == "anthropic/claude-3.5-haiku"
        assert client.model_type == ModelType.CLAUDE_HAIKU

    @patch('app.ai.replicate_client.get_config')
    def test_is_claude_model(self, mock_get_config):
        """_is_claude_model() is True only for Claude model types."""
        mock_get_config.return_value = MagicMock(
            replicate_api_token="token",
            REPLICATE_MODEL=None,
        )

        # Llama is not a Claude model.
        llama = ReplicateClient(model=ModelType.LLAMA_3_8B)
        assert llama._is_claude_model() is False

        # Both Claude variants are.
        haiku = ReplicateClient(model=ModelType.CLAUDE_HAIKU)
        assert haiku._is_claude_model() is True

        sonnet = ReplicateClient(model=ModelType.CLAUDE_SONNET)
        assert sonnet._is_claude_model() is True

class TestClaudeModels:
    """Tests for Claude model generation via Replicate."""

    @staticmethod
    def _stub_config(mock_get_config):
        """Install a stub config holding a valid API token."""
        mock_get_config.return_value = MagicMock(
            replicate_api_token="token",
            REPLICATE_MODEL=None,
        )

    @patch('app.ai.replicate_client.get_config')
    @patch('app.ai.replicate_client.replicate')
    def test_generate_with_claude_haiku(self, mock_replicate, mock_get_config):
        """Claude models stream text without Llama prompt formatting."""
        self._stub_config(mock_get_config)
        mock_replicate.run.return_value = iter(["Claude ", "response"])

        client = ReplicateClient(model=ModelType.CLAUDE_HAIKU)
        response = client.generate("Test prompt")

        assert response.text == "Claude response"
        assert response.model == "anthropic/claude-3.5-haiku"

        # Claude inputs carry a raw prompt, not Llama special tokens.
        payload = mock_replicate.run.call_args[1]['input']
        assert "prompt" in payload
        assert "<|begin_of_text|>" not in payload['prompt']

    @patch('app.ai.replicate_client.get_config')
    @patch('app.ai.replicate_client.replicate')
    def test_generate_with_claude_system_prompt(self, mock_replicate, mock_get_config):
        """Claude generation passes system_prompt as its own input key."""
        self._stub_config(mock_get_config)
        mock_replicate.run.return_value = iter(["Response"])

        client = ReplicateClient(model=ModelType.CLAUDE_SONNET)
        client.generate(
            prompt="User message",
            system_prompt="You are a DM",
        )

        payload = mock_replicate.run.call_args[1]['input']
        assert payload['system_prompt'] == "You are a DM"
        assert payload['prompt'] == "User message"

    @patch('app.ai.replicate_client.get_config')
    @patch('app.ai.replicate_client.replicate')
    def test_model_specific_defaults(self, mock_replicate, mock_get_config):
        """Claude Sonnet applies its own default max_tokens/temperature."""
        self._stub_config(mock_get_config)
        mock_replicate.run.return_value = iter(["Response"])

        client = ReplicateClient(model=ModelType.CLAUDE_SONNET)
        client.generate("Test")

        payload = mock_replicate.run.call_args[1]['input']
        # Sonnet defaults: 1024 tokens at temperature 0.9.
        assert payload['max_tokens'] == 1024
        assert payload['temperature'] == 0.9

    @patch('app.ai.replicate_client.get_config')
    @patch('app.ai.replicate_client.replicate')
    def test_model_override_in_generate(self, mock_replicate, mock_get_config):
        """A model passed to generate() overrides the client's default."""
        self._stub_config(mock_get_config)
        mock_replicate.run.return_value = iter(["Response"])

        # Client initialized with Llama, but the call requests Claude.
        client = ReplicateClient(model=ModelType.LLAMA_3_8B)
        response = client.generate("Test", model=ModelType.CLAUDE_HAIKU)

        # The response reflects the overridden model.
        assert response.model == "anthropic/claude-3.5-haiku"

        # replicate.run received the overridden model slug positionally.
        assert mock_replicate.run.call_args[0][0] == "anthropic/claude-3.5-haiku"
