first commit
This commit is contained in:
462
api/tests/test_replicate_client.py
Normal file
462
api/tests/test_replicate_client.py
Normal file
@@ -0,0 +1,462 @@
|
||||
"""
|
||||
Tests for Replicate API client.
|
||||
|
||||
Tests cover initialization, prompt formatting, generation,
|
||||
retry logic, and error handling.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
from app.ai.replicate_client import (
|
||||
ReplicateClient,
|
||||
ReplicateResponse,
|
||||
ReplicateClientError,
|
||||
ReplicateAPIError,
|
||||
ReplicateRateLimitError,
|
||||
ReplicateTimeoutError,
|
||||
ModelType,
|
||||
)
|
||||
|
||||
|
||||
class TestReplicateClientInit:
    """Construction-time behaviour of ReplicateClient."""

    @staticmethod
    def _fake_config(token=None, model=None):
        # Stub object returned by the patched get_config().
        return MagicMock(replicate_api_token=token, REPLICATE_MODEL=model)

    @patch('app.ai.replicate_client.get_config')
    def test_init_with_token(self, mock_get_config):
        """An explicit api_token is used and the default model applies."""
        mock_get_config.return_value = self._fake_config()

        client = ReplicateClient(api_token="test_token_123")

        assert client.api_token == "test_token_123"
        assert client.model == ReplicateClient.DEFAULT_MODEL.value

    @patch('app.ai.replicate_client.get_config')
    def test_init_from_config(self, mock_get_config):
        """Token and model fall back to values from config."""
        mock_get_config.return_value = self._fake_config(
            token="config_token", model="custom/model"
        )

        client = ReplicateClient()

        assert client.api_token == "config_token"
        assert client.model == "custom/model"

    @patch('app.ai.replicate_client.get_config')
    def test_init_missing_token(self, mock_get_config):
        """With no token anywhere, construction raises ReplicateClientError."""
        mock_get_config.return_value = self._fake_config()

        with pytest.raises(ReplicateClientError) as exc_info:
            ReplicateClient()

        assert "API token not configured" in str(exc_info.value)

    @patch('app.ai.replicate_client.get_config')
    def test_init_custom_model(self, mock_get_config):
        """A model string passed to the constructor overrides the default."""
        mock_get_config.return_value = self._fake_config(token="token")

        client = ReplicateClient(model="meta/llama-2-70b")

        assert client.model == "meta/llama-2-70b"
|
||||
|
||||
class TestPromptFormatting:
    """Llama-3 chat-template formatting via _format_llama_prompt()."""

    @staticmethod
    def _fake_config():
        # Minimal config stub so the client can be constructed.
        return MagicMock(replicate_api_token="token", REPLICATE_MODEL=None)

    @patch('app.ai.replicate_client.get_config')
    def test_format_prompt_user_only(self, mock_get_config):
        """Without a system prompt, only user/assistant headers appear."""
        mock_get_config.return_value = self._fake_config()

        client = ReplicateClient()
        formatted = client._format_llama_prompt("Hello world")

        assert "<|begin_of_text|>" in formatted
        assert "<|start_header_id|>user<|end_header_id|>" in formatted
        assert "Hello world" in formatted
        assert "<|start_header_id|>assistant<|end_header_id|>" in formatted
        # No system header may be emitted when no system prompt was given.
        assert "system<|end_header_id|>" not in formatted

    @patch('app.ai.replicate_client.get_config')
    def test_format_prompt_with_system(self, mock_get_config):
        """A system prompt adds a system header block alongside the user turn."""
        mock_get_config.return_value = self._fake_config()

        client = ReplicateClient()
        formatted = client._format_llama_prompt(
            "What is 2+2?",
            system_prompt="You are a helpful assistant."
        )

        assert "<|start_header_id|>system<|end_header_id|>" in formatted
        assert "You are a helpful assistant." in formatted
        assert "<|start_header_id|>user<|end_header_id|>" in formatted
        assert "What is 2+2?" in formatted
|
||||
|
||||
class TestGenerate:
    """Text generation through ReplicateClient.generate()."""

    @staticmethod
    def _fake_config():
        # Minimal config stub so the client can be constructed.
        return MagicMock(replicate_api_token="token", REPLICATE_MODEL=None)

    @patch('app.ai.replicate_client.get_config')
    @patch('app.ai.replicate_client.replicate')
    def test_generate_success(self, mock_sdk, mock_get_config):
        """A streamed iterator of chunks is joined into a ReplicateResponse."""
        mock_get_config.return_value = self._fake_config()
        mock_sdk.run.return_value = iter(["Hello ", "world ", "!"])

        client = ReplicateClient()
        response = client.generate("Say hello")

        assert isinstance(response, ReplicateResponse)
        assert response.text == "Hello world !"
        assert response.tokens_used > 0
        assert response.model == ReplicateClient.DEFAULT_MODEL.value
        assert response.generation_time > 0

    @patch('app.ai.replicate_client.get_config')
    @patch('app.ai.replicate_client.replicate')
    def test_generate_with_parameters(self, mock_sdk, mock_get_config):
        """Caller-supplied sampling parameters are forwarded to replicate.run()."""
        mock_get_config.return_value = self._fake_config()
        mock_sdk.run.return_value = iter(["Response"])

        client = ReplicateClient()
        client.generate(
            prompt="Test",
            system_prompt="Be concise",
            max_tokens=100,
            temperature=0.5,
            top_p=0.8,
            timeout=60
        )

        # Inspect the keyword payload handed to the SDK.
        payload = mock_sdk.run.call_args[1]['input']
        assert payload['max_tokens'] == 100
        assert payload['temperature'] == 0.5
        assert payload['top_p'] == 0.8

    @patch('app.ai.replicate_client.get_config')
    @patch('app.ai.replicate_client.replicate')
    def test_generate_string_response(self, mock_sdk, mock_get_config):
        """A plain string result (non-streaming) is used verbatim."""
        mock_get_config.return_value = self._fake_config()
        mock_sdk.run.return_value = "Direct string response"

        client = ReplicateClient()
        response = client.generate("Test")

        assert response.text == "Direct string response"
||||
|
||||
|
||||
class TestRetryLogic:
    """Retry, backoff, and error-type mapping in generate()."""

    @staticmethod
    def _fake_config():
        # Minimal config stub so the client can be constructed.
        return MagicMock(replicate_api_token="token", REPLICATE_MODEL=None)

    @patch('app.ai.replicate_client.get_config')
    @patch('app.ai.replicate_client.replicate')
    @patch('app.ai.replicate_client.time.sleep')
    def test_retry_on_rate_limit(self, mock_sleep, mock_sdk, mock_get_config):
        """A rate-limit failure is retried after a backoff sleep."""
        mock_get_config.return_value = self._fake_config()

        # First attempt hits the rate limit; the second attempt succeeds.
        mock_sdk.exceptions.ReplicateError = Exception
        mock_sdk.run.side_effect = [
            Exception("Rate limit exceeded 429"),
            iter(["Success"])
        ]

        client = ReplicateClient()
        response = client.generate("Test")

        assert response.text == "Success"
        assert mock_sleep.called  # backoff must have happened between attempts

    @patch('app.ai.replicate_client.get_config')
    @patch('app.ai.replicate_client.replicate')
    @patch('app.ai.replicate_client.time.sleep')
    def test_max_retries_exceeded(self, mock_sleep, mock_sdk, mock_get_config):
        """Persistent rate limiting raises after exactly MAX_RETRIES attempts."""
        mock_get_config.return_value = self._fake_config()

        # Every attempt fails with a rate-limit error.
        mock_sdk.exceptions.ReplicateError = Exception
        mock_sdk.run.side_effect = Exception("Rate limit exceeded 429")

        client = ReplicateClient()

        with pytest.raises(ReplicateRateLimitError):
            client.generate("Test")

        assert mock_sdk.run.call_count == ReplicateClient.MAX_RETRIES

    @patch('app.ai.replicate_client.get_config')
    @patch('app.ai.replicate_client.replicate')
    def test_timeout_error(self, mock_sdk, mock_get_config):
        """A timeout-flavoured failure maps to ReplicateTimeoutError."""
        mock_get_config.return_value = self._fake_config()
        mock_sdk.exceptions.ReplicateError = Exception
        mock_sdk.run.side_effect = Exception("Request timeout")

        client = ReplicateClient()

        with pytest.raises(ReplicateTimeoutError):
            client.generate("Test")

    @patch('app.ai.replicate_client.get_config')
    @patch('app.ai.replicate_client.replicate')
    def test_api_error(self, mock_sdk, mock_get_config):
        """Any other failure maps to the generic ReplicateAPIError."""
        mock_get_config.return_value = self._fake_config()
        mock_sdk.exceptions.ReplicateError = Exception
        mock_sdk.run.side_effect = Exception("Invalid model")

        client = ReplicateClient()

        with pytest.raises(ReplicateAPIError):
            client.generate("Test")
||||
|
||||
|
||||
class TestValidation:
    """validate_api_key() against a mocked Replicate SDK."""

    @patch('app.ai.replicate_client.get_config')
    @patch('app.ai.replicate_client.replicate')
    def test_validate_api_key_success(self, mock_sdk, mock_get_config):
        """A successful models.get() lookup means the key is valid."""
        mock_get_config.return_value = MagicMock(
            replicate_api_token="valid_token", REPLICATE_MODEL=None
        )
        mock_sdk.models.get.return_value = MagicMock()

        client = ReplicateClient()
        assert client.validate_api_key() is True

    @patch('app.ai.replicate_client.get_config')
    @patch('app.ai.replicate_client.replicate')
    def test_validate_api_key_failure(self, mock_sdk, mock_get_config):
        """A failing models.get() lookup means the key is invalid."""
        mock_get_config.return_value = MagicMock(
            replicate_api_token="invalid_token", REPLICATE_MODEL=None
        )
        mock_sdk.models.get.side_effect = Exception("Invalid API token")

        client = ReplicateClient()
        assert client.validate_api_key() is False
||||
|
||||
|
||||
class TestResponseDataclass:
    """Shape and field access of the ReplicateResponse dataclass."""

    def test_response_creation(self):
        """All four fields round-trip through the constructor."""
        resp = ReplicateResponse(
            text="Hello world",
            tokens_used=50,
            model="meta/llama-3-8b",
            generation_time=1.5
        )

        assert resp.text == "Hello world"
        assert resp.tokens_used == 50
        assert resp.model == "meta/llama-3-8b"
        assert resp.generation_time == 1.5

    def test_response_immutability(self):
        """Fields remain accessible as plain attributes."""
        resp = ReplicateResponse(
            text="Test",
            tokens_used=10,
            model="test",
            generation_time=0.5
        )

        # Only attribute *presence* is asserted here, not frozen-ness.
        assert hasattr(resp, 'text')
        assert hasattr(resp, 'tokens_used')
||||
|
||||
|
||||
class TestModelType:
    """ModelType enum values and multi-model client wiring."""

    def test_model_type_values(self):
        """Enum members map to the expected Replicate model slugs."""
        assert ModelType.LLAMA_3_8B.value == "meta/meta-llama-3-8b-instruct"
        assert ModelType.CLAUDE_HAIKU.value == "anthropic/claude-3.5-haiku"
        assert ModelType.CLAUDE_SONNET.value == "anthropic/claude-3.5-sonnet"
        assert ModelType.CLAUDE_SONNET_4.value == "anthropic/claude-sonnet-4"

    @patch('app.ai.replicate_client.get_config')
    def test_init_with_model_type_enum(self, mock_get_config):
        """Passing a ModelType sets both the model slug and model_type."""
        mock_get_config.return_value = MagicMock(
            replicate_api_token="token", REPLICATE_MODEL=None
        )

        client = ReplicateClient(model=ModelType.CLAUDE_HAIKU)

        assert client.model == "anthropic/claude-3.5-haiku"
        assert client.model_type == ModelType.CLAUDE_HAIKU

    @patch('app.ai.replicate_client.get_config')
    def test_is_claude_model(self, mock_get_config):
        """_is_claude_model() distinguishes Claude models from Llama ones."""
        mock_get_config.return_value = MagicMock(
            replicate_api_token="token", REPLICATE_MODEL=None
        )

        # Llama is not Claude; both Claude variants are.
        assert ReplicateClient(model=ModelType.LLAMA_3_8B)._is_claude_model() is False
        assert ReplicateClient(model=ModelType.CLAUDE_HAIKU)._is_claude_model() is True
        assert ReplicateClient(model=ModelType.CLAUDE_SONNET)._is_claude_model() is True
||||
|
||||
|
||||
class TestClaudeModels:
    """Claude models driven through the Replicate API."""

    @staticmethod
    def _fake_config():
        # Minimal config stub so the client can be constructed.
        return MagicMock(replicate_api_token="token", REPLICATE_MODEL=None)

    @patch('app.ai.replicate_client.get_config')
    @patch('app.ai.replicate_client.replicate')
    def test_generate_with_claude_haiku(self, mock_sdk, mock_get_config):
        """Claude requests use a raw prompt, not the Llama chat template."""
        mock_get_config.return_value = self._fake_config()
        mock_sdk.run.return_value = iter(["Claude ", "response"])

        client = ReplicateClient(model=ModelType.CLAUDE_HAIKU)
        response = client.generate("Test prompt")

        assert response.text == "Claude response"
        assert response.model == "anthropic/claude-3.5-haiku"

        payload = mock_sdk.run.call_args[1]['input']
        assert "prompt" in payload
        # Llama-only special tokens must not leak into Claude prompts.
        assert "<|begin_of_text|>" not in payload['prompt']

    @patch('app.ai.replicate_client.get_config')
    @patch('app.ai.replicate_client.replicate')
    def test_generate_with_claude_system_prompt(self, mock_sdk, mock_get_config):
        """Claude receives the system prompt as a separate input field."""
        mock_get_config.return_value = self._fake_config()
        mock_sdk.run.return_value = iter(["Response"])

        client = ReplicateClient(model=ModelType.CLAUDE_SONNET)
        client.generate(
            prompt="User message",
            system_prompt="You are a DM"
        )

        payload = mock_sdk.run.call_args[1]['input']
        assert payload['system_prompt'] == "You are a DM"
        assert payload['prompt'] == "User message"

    @patch('app.ai.replicate_client.get_config')
    @patch('app.ai.replicate_client.replicate')
    def test_model_specific_defaults(self, mock_sdk, mock_get_config):
        """Claude Sonnet applies its own default max_tokens and temperature."""
        mock_get_config.return_value = self._fake_config()
        mock_sdk.run.return_value = iter(["Response"])

        client = ReplicateClient(model=ModelType.CLAUDE_SONNET)
        client.generate("Test")

        payload = mock_sdk.run.call_args[1]['input']
        # Sonnet's defaults: 1024 tokens, temperature 0.9.
        assert payload['max_tokens'] == 1024
        assert payload['temperature'] == 0.9

    @patch('app.ai.replicate_client.get_config')
    @patch('app.ai.replicate_client.replicate')
    def test_model_override_in_generate(self, mock_sdk, mock_get_config):
        """model= passed to generate() overrides the client's configured model."""
        mock_get_config.return_value = self._fake_config()
        mock_sdk.run.return_value = iter(["Response"])

        # Client configured for Llama, but the call asks for Claude.
        client = ReplicateClient(model=ModelType.LLAMA_3_8B)
        response = client.generate("Test", model=ModelType.CLAUDE_HAIKU)

        # The response reflects the overridden model...
        assert response.model == "anthropic/claude-3.5-haiku"
        # ...and the overridden slug is the positional arg to replicate.run().
        assert mock_sdk.run.call_args[0][0] == "anthropic/claude-3.5-haiku"
|
||||
Reference in New Issue
Block a user