# Code_of_Conquest/api/tests/test_replicate_client.py
# (snapshot 2025-11-24 23:10:55 -06:00 — 463 lines, 16 KiB, Python)
"""
Tests for Replicate API client.
Tests cover initialization, prompt formatting, generation,
retry logic, and error handling.
"""
import pytest
from unittest.mock import patch, MagicMock
from app.ai.replicate_client import (
ReplicateClient,
ReplicateResponse,
ReplicateClientError,
ReplicateAPIError,
ReplicateRateLimitError,
ReplicateTimeoutError,
ModelType,
)
class TestReplicateClientInit:
    """Behavioural tests for ReplicateClient construction."""

    @staticmethod
    def _stub_config(token=None, model=None):
        """Return a config double exposing the two attributes the client reads."""
        return MagicMock(replicate_api_token=token, REPLICATE_MODEL=model)

    @patch('app.ai.replicate_client.get_config')
    def test_init_with_token(self, mock_config):
        """An explicitly supplied API token is stored verbatim; model defaults."""
        mock_config.return_value = self._stub_config()
        client = ReplicateClient(api_token="test_token_123")
        assert client.api_token == "test_token_123"
        assert client.model == ReplicateClient.DEFAULT_MODEL.value

    @patch('app.ai.replicate_client.get_config')
    def test_init_from_config(self, mock_config):
        """With no explicit arguments, token and model are read from config."""
        mock_config.return_value = self._stub_config(
            token="config_token", model="custom/model"
        )
        client = ReplicateClient()
        assert client.api_token == "config_token"
        assert client.model == "custom/model"

    @patch('app.ai.replicate_client.get_config')
    def test_init_missing_token(self, mock_config):
        """Constructing without any token raises ReplicateClientError."""
        mock_config.return_value = self._stub_config()
        with pytest.raises(ReplicateClientError) as exc_info:
            ReplicateClient()
        assert "API token not configured" in str(exc_info.value)

    @patch('app.ai.replicate_client.get_config')
    def test_init_custom_model(self, mock_config):
        """A model passed to the constructor overrides the default."""
        mock_config.return_value = self._stub_config(token="token")
        client = ReplicateClient(model="meta/llama-2-70b")
        assert client.model == "meta/llama-2-70b"
class TestPromptFormatting:
    """Tests for the Llama-3 chat-template formatting helper."""

    @staticmethod
    def _stub_config():
        """Config double: token present, no model override."""
        return MagicMock(replicate_api_token="token", REPLICATE_MODEL=None)

    @patch('app.ai.replicate_client.get_config')
    def test_format_prompt_user_only(self, mock_config):
        """Without a system prompt, only user/assistant headers are emitted."""
        mock_config.return_value = self._stub_config()
        rendered = ReplicateClient()._format_llama_prompt("Hello world")
        for fragment in (
            "<|begin_of_text|>",
            "<|start_header_id|>user<|end_header_id|>",
            "Hello world",
            "<|start_header_id|>assistant<|end_header_id|>",
        ):
            assert fragment in rendered
        # A system header must not appear when no system prompt was given.
        assert "system<|end_header_id|>" not in rendered

    @patch('app.ai.replicate_client.get_config')
    def test_format_prompt_with_system(self, mock_config):
        """System and user messages each get their own header block."""
        mock_config.return_value = self._stub_config()
        rendered = ReplicateClient()._format_llama_prompt(
            "What is 2+2?",
            system_prompt="You are a helpful assistant."
        )
        for fragment in (
            "<|start_header_id|>system<|end_header_id|>",
            "You are a helpful assistant.",
            "<|start_header_id|>user<|end_header_id|>",
            "What is 2+2?",
        ):
            assert fragment in rendered
class TestGenerate:
    """Tests for ReplicateClient.generate()."""

    @staticmethod
    def _stub_config():
        """Config double: token present, no model override."""
        return MagicMock(replicate_api_token="token", REPLICATE_MODEL=None)

    @patch('app.ai.replicate_client.get_config')
    @patch('app.ai.replicate_client.replicate')
    def test_generate_success(self, mock_replicate, mock_config):
        """A streaming iterator of chunks is joined into one response text."""
        mock_config.return_value = self._stub_config()
        mock_replicate.run.return_value = iter(["Hello ", "world ", "!"])
        result = ReplicateClient().generate("Say hello")
        assert isinstance(result, ReplicateResponse)
        assert result.text == "Hello world !"
        assert result.tokens_used > 0
        assert result.model == ReplicateClient.DEFAULT_MODEL.value
        assert result.generation_time > 0

    @patch('app.ai.replicate_client.get_config')
    @patch('app.ai.replicate_client.replicate')
    def test_generate_with_parameters(self, mock_replicate, mock_config):
        """Caller-supplied sampling parameters are forwarded to replicate.run."""
        mock_config.return_value = self._stub_config()
        mock_replicate.run.return_value = iter(["Response"])
        ReplicateClient().generate(
            prompt="Test",
            system_prompt="Be concise",
            max_tokens=100,
            temperature=0.5,
            top_p=0.8,
            timeout=60,
        )
        sent = mock_replicate.run.call_args[1]['input']
        assert sent['max_tokens'] == 100
        assert sent['temperature'] == 0.5
        assert sent['top_p'] == 0.8

    @patch('app.ai.replicate_client.get_config')
    @patch('app.ai.replicate_client.replicate')
    def test_generate_string_response(self, mock_replicate, mock_config):
        """A plain-string (non-streaming) result is passed through unchanged."""
        mock_config.return_value = self._stub_config()
        mock_replicate.run.return_value = "Direct string response"
        result = ReplicateClient().generate("Test")
        assert result.text == "Direct string response"
class TestRetryLogic:
    """Retry/backoff behaviour and mapping of failures onto typed errors."""

    @staticmethod
    def _stub_config():
        """Config double: token present, no model override."""
        return MagicMock(replicate_api_token="token", REPLICATE_MODEL=None)

    @patch('app.ai.replicate_client.get_config')
    @patch('app.ai.replicate_client.replicate')
    @patch('app.ai.replicate_client.time.sleep')
    def test_retry_on_rate_limit(self, mock_sleep, mock_replicate, mock_config):
        """A 429-style failure is retried after a backoff sleep."""
        mock_config.return_value = self._stub_config()
        mock_replicate.exceptions.ReplicateError = Exception
        # First attempt hits the rate limit; the retry succeeds.
        mock_replicate.run.side_effect = [
            Exception("Rate limit exceeded 429"),
            iter(["Success"]),
        ]
        result = ReplicateClient().generate("Test")
        assert result.text == "Success"
        assert mock_sleep.called  # backoff happened between attempts

    @patch('app.ai.replicate_client.get_config')
    @patch('app.ai.replicate_client.replicate')
    @patch('app.ai.replicate_client.time.sleep')
    def test_max_retries_exceeded(self, mock_sleep, mock_replicate, mock_config):
        """Persistent rate-limiting surfaces as ReplicateRateLimitError."""
        mock_config.return_value = self._stub_config()
        mock_replicate.exceptions.ReplicateError = Exception
        # Every attempt fails with a rate-limit message.
        mock_replicate.run.side_effect = Exception("Rate limit exceeded 429")
        with pytest.raises(ReplicateRateLimitError):
            ReplicateClient().generate("Test")
        # The client should have attempted exactly MAX_RETRIES calls.
        assert mock_replicate.run.call_count == ReplicateClient.MAX_RETRIES

    @patch('app.ai.replicate_client.get_config')
    @patch('app.ai.replicate_client.replicate')
    def test_timeout_error(self, mock_replicate, mock_config):
        """A timeout message is mapped onto ReplicateTimeoutError."""
        mock_config.return_value = self._stub_config()
        mock_replicate.exceptions.ReplicateError = Exception
        mock_replicate.run.side_effect = Exception("Request timeout")
        with pytest.raises(ReplicateTimeoutError):
            ReplicateClient().generate("Test")

    @patch('app.ai.replicate_client.get_config')
    @patch('app.ai.replicate_client.replicate')
    def test_api_error(self, mock_replicate, mock_config):
        """Any other failure is mapped onto the generic ReplicateAPIError."""
        mock_config.return_value = self._stub_config()
        mock_replicate.exceptions.ReplicateError = Exception
        mock_replicate.run.side_effect = Exception("Invalid model")
        with pytest.raises(ReplicateAPIError):
            ReplicateClient().generate("Test")
class TestValidation:
    """Tests for ReplicateClient.validate_api_key()."""

    @staticmethod
    def _stub_config(token):
        """Config double carrying the given token and no model override."""
        return MagicMock(replicate_api_token=token, REPLICATE_MODEL=None)

    @patch('app.ai.replicate_client.get_config')
    @patch('app.ai.replicate_client.replicate')
    def test_validate_api_key_success(self, mock_replicate, mock_config):
        """A successful models.get lookup means the key is valid."""
        mock_config.return_value = self._stub_config("valid_token")
        mock_replicate.models.get.return_value = MagicMock()
        assert ReplicateClient().validate_api_key() is True

    @patch('app.ai.replicate_client.get_config')
    @patch('app.ai.replicate_client.replicate')
    def test_validate_api_key_failure(self, mock_replicate, mock_config):
        """A failing models.get lookup means the key is invalid."""
        mock_config.return_value = self._stub_config("invalid_token")
        mock_replicate.models.get.side_effect = Exception("Invalid API token")
        assert ReplicateClient().validate_api_key() is False
class TestResponseDataclass:
    """Tests for the ReplicateResponse dataclass."""

    def test_response_creation(self):
        """All constructor arguments are stored on the instance unchanged."""
        response = ReplicateResponse(
            text="Hello world",
            tokens_used=50,
            model="meta/llama-3-8b",
            generation_time=1.5
        )
        assert response.text == "Hello world"
        assert response.tokens_used == 50
        assert response.model == "meta/llama-3-8b"
        assert response.generation_time == 1.5

    def test_response_immutability(self):
        """Fields are accessible as attributes and retain their assigned values.

        NOTE(review): despite the method name, this does not verify that the
        dataclass is frozen — only that its fields exist and hold the values
        given at construction. (The original asserted hasattr() alone, which
        passes for any attribute regardless of value.)
        """
        response = ReplicateResponse(
            text="Test",
            tokens_used=10,
            model="test",
            generation_time=0.5
        )
        assert hasattr(response, 'text')
        assert hasattr(response, 'tokens_used')
        # Strengthened: attribute access must round-trip the stored values.
        assert response.text == "Test"
        assert response.tokens_used == 10
        assert response.model == "test"
        assert response.generation_time == 0.5
class TestModelType:
    """Tests for the ModelType enum and model selection on the client."""

    def test_model_type_values(self):
        """Each enum member maps to its expected Replicate model slug."""
        expected = {
            ModelType.LLAMA_3_8B: "meta/meta-llama-3-8b-instruct",
            ModelType.CLAUDE_HAIKU: "anthropic/claude-3.5-haiku",
            ModelType.CLAUDE_SONNET: "anthropic/claude-3.5-sonnet",
            ModelType.CLAUDE_SONNET_4: "anthropic/claude-sonnet-4",
        }
        for member, slug in expected.items():
            assert member.value == slug

    @patch('app.ai.replicate_client.get_config')
    def test_init_with_model_type_enum(self, mock_config):
        """Passing a ModelType enum sets both the slug and the enum field."""
        mock_config.return_value = MagicMock(
            replicate_api_token="token", REPLICATE_MODEL=None
        )
        client = ReplicateClient(model=ModelType.CLAUDE_HAIKU)
        assert client.model == "anthropic/claude-3.5-haiku"
        assert client.model_type == ModelType.CLAUDE_HAIKU

    @patch('app.ai.replicate_client.get_config')
    def test_is_claude_model(self, mock_config):
        """_is_claude_model is False for Llama and True for Claude variants."""
        mock_config.return_value = MagicMock(
            replicate_api_token="token", REPLICATE_MODEL=None
        )
        cases = [
            (ModelType.LLAMA_3_8B, False),
            (ModelType.CLAUDE_HAIKU, True),
            (ModelType.CLAUDE_SONNET, True),
        ]
        for model_type, is_claude in cases:
            client = ReplicateClient(model=model_type)
            assert client._is_claude_model() is is_claude
class TestClaudeModels:
    """Tests for driving Anthropic Claude models through Replicate."""

    @staticmethod
    def _stub_config():
        """Config double: token present, no model override."""
        return MagicMock(replicate_api_token="token", REPLICATE_MODEL=None)

    @patch('app.ai.replicate_client.get_config')
    @patch('app.ai.replicate_client.replicate')
    def test_generate_with_claude_haiku(self, mock_replicate, mock_config):
        """Claude Haiku generation streams text and skips Llama templating."""
        mock_config.return_value = self._stub_config()
        mock_replicate.run.return_value = iter(["Claude ", "response"])
        result = ReplicateClient(model=ModelType.CLAUDE_HAIKU).generate("Test prompt")
        assert result.text == "Claude response"
        assert result.model == "anthropic/claude-3.5-haiku"
        # The Claude input must carry a raw prompt, not a Llama-templated one.
        sent = mock_replicate.run.call_args[1]['input']
        assert "prompt" in sent
        assert "<|begin_of_text|>" not in sent['prompt']

    @patch('app.ai.replicate_client.get_config')
    @patch('app.ai.replicate_client.replicate')
    def test_generate_with_claude_system_prompt(self, mock_replicate, mock_config):
        """For Claude, the system prompt is sent as a separate input field."""
        mock_config.return_value = self._stub_config()
        mock_replicate.run.return_value = iter(["Response"])
        ReplicateClient(model=ModelType.CLAUDE_SONNET).generate(
            prompt="User message",
            system_prompt="You are a DM"
        )
        sent = mock_replicate.run.call_args[1]['input']
        assert sent['system_prompt'] == "You are a DM"
        assert sent['prompt'] == "User message"

    @patch('app.ai.replicate_client.get_config')
    @patch('app.ai.replicate_client.replicate')
    def test_model_specific_defaults(self, mock_replicate, mock_config):
        """Claude Sonnet gets its own defaults (1024 tokens, temperature 0.9)."""
        mock_config.return_value = self._stub_config()
        mock_replicate.run.return_value = iter(["Response"])
        ReplicateClient(model=ModelType.CLAUDE_SONNET).generate("Test")
        sent = mock_replicate.run.call_args[1]['input']
        assert sent['max_tokens'] == 1024
        assert sent['temperature'] == 0.9

    @patch('app.ai.replicate_client.get_config')
    @patch('app.ai.replicate_client.replicate')
    def test_model_override_in_generate(self, mock_replicate, mock_config):
        """A model passed to generate() overrides the constructor's choice."""
        mock_config.return_value = self._stub_config()
        mock_replicate.run.return_value = iter(["Response"])
        # Build the client on Llama, then call with Claude Haiku.
        client = ReplicateClient(model=ModelType.LLAMA_3_8B)
        result = client.generate("Test", model=ModelType.CLAUDE_HAIKU)
        assert result.model == "anthropic/claude-3.5-haiku"
        # The overridden slug is what was actually invoked.
        assert mock_replicate.run.call_args[0][0] == "anthropic/claude-3.5-haiku"