"""
Tests for Replicate API client.

Tests cover initialization, prompt formatting, generation, retry logic,
error handling, and multi-model (Llama / Claude) support.
"""

import pytest
from unittest.mock import patch, MagicMock

from app.ai.replicate_client import (
    ReplicateClient,
    ReplicateResponse,
    ReplicateClientError,
    ReplicateAPIError,
    ReplicateRateLimitError,
    ReplicateTimeoutError,
    ModelType,
)


def _make_config(token="token", model=None):
    """Build a stand-in for the object returned by ``get_config``.

    Only the two attributes these tests care about are set explicitly;
    any other attribute access falls through to ``MagicMock`` defaults.

    Args:
        token: Value for ``replicate_api_token`` (``None`` = not configured).
        model: Value for ``REPLICATE_MODEL`` (``None`` = no override).

    Returns:
        A ``MagicMock`` configured with those attributes.
    """
    return MagicMock(replicate_api_token=token, REPLICATE_MODEL=model)


def _run_input(mock_replicate):
    """Return the ``input=`` kwarg of the most recent ``replicate.run`` call."""
    return mock_replicate.run.call_args[1]['input']


class TestReplicateClientInit:
    """Tests for ReplicateClient initialization."""

    @patch('app.ai.replicate_client.get_config')
    def test_init_with_token(self, mock_config):
        """Test initialization with explicit API token."""
        mock_config.return_value = _make_config(token=None)

        client = ReplicateClient(api_token="test_token_123")

        assert client.api_token == "test_token_123"
        assert client.model == ReplicateClient.DEFAULT_MODEL.value

    @patch('app.ai.replicate_client.get_config')
    def test_init_from_config(self, mock_config):
        """Test initialization from config."""
        mock_config.return_value = _make_config(
            token="config_token", model="custom/model"
        )

        client = ReplicateClient()

        assert client.api_token == "config_token"
        assert client.model == "custom/model"

    @patch('app.ai.replicate_client.get_config')
    def test_init_missing_token(self, mock_config):
        """Test initialization fails without API token."""
        mock_config.return_value = _make_config(token=None)

        with pytest.raises(ReplicateClientError) as exc_info:
            ReplicateClient()

        assert "API token not configured" in str(exc_info.value)

    @patch('app.ai.replicate_client.get_config')
    def test_init_custom_model(self, mock_config):
        """Test initialization with custom model."""
        mock_config.return_value = _make_config()

        client = ReplicateClient(model="meta/llama-2-70b")

        assert client.model == "meta/llama-2-70b"


class TestPromptFormatting:
    """Tests for Llama-3 prompt formatting."""

    @patch('app.ai.replicate_client.get_config')
    def test_format_prompt_user_only(self, mock_config):
        """Test formatting with only user prompt."""
        mock_config.return_value = _make_config()
        client = ReplicateClient()

        formatted = client._format_llama_prompt("Hello world")

        assert "<|begin_of_text|>" in formatted
        assert "<|start_header_id|>user<|end_header_id|>" in formatted
        assert "Hello world" in formatted
        assert "<|start_header_id|>assistant<|end_header_id|>" in formatted
        # No system header without system prompt
        assert "system<|end_header_id|>" not in formatted

    @patch('app.ai.replicate_client.get_config')
    def test_format_prompt_with_system(self, mock_config):
        """Test formatting with system and user prompts."""
        mock_config.return_value = _make_config()
        client = ReplicateClient()

        formatted = client._format_llama_prompt(
            "What is 2+2?",
            system_prompt="You are a helpful assistant."
        )

        assert "<|start_header_id|>system<|end_header_id|>" in formatted
        assert "You are a helpful assistant." in formatted
        assert "<|start_header_id|>user<|end_header_id|>" in formatted
        assert "What is 2+2?" in formatted


class TestGenerate:
    """Tests for text generation."""

    @patch('app.ai.replicate_client.get_config')
    @patch('app.ai.replicate_client.replicate')
    def test_generate_success(self, mock_replicate, mock_config):
        """Test successful text generation."""
        mock_config.return_value = _make_config()
        # Mock streaming response
        mock_replicate.run.return_value = iter(["Hello ", "world ", "!"])

        client = ReplicateClient()
        response = client.generate("Say hello")

        assert isinstance(response, ReplicateResponse)
        assert response.text == "Hello world !"
        assert response.tokens_used > 0
        assert response.model == ReplicateClient.DEFAULT_MODEL.value
        # The mocked run() returns instantly; with a coarse clock the measured
        # elapsed time can round to 0, so only require a non-negative value.
        assert response.generation_time >= 0

    @patch('app.ai.replicate_client.get_config')
    @patch('app.ai.replicate_client.replicate')
    def test_generate_with_parameters(self, mock_replicate, mock_config):
        """Test generation with custom parameters."""
        mock_config.return_value = _make_config()
        mock_replicate.run.return_value = iter(["Response"])

        client = ReplicateClient()
        client.generate(
            prompt="Test",
            system_prompt="Be concise",
            max_tokens=100,
            temperature=0.5,
            top_p=0.8,
            timeout=60
        )

        # Verify parameters were passed through to replicate.run
        run_input = _run_input(mock_replicate)
        assert run_input['max_tokens'] == 100
        assert run_input['temperature'] == 0.5
        assert run_input['top_p'] == 0.8

    @patch('app.ai.replicate_client.get_config')
    @patch('app.ai.replicate_client.replicate')
    def test_generate_string_response(self, mock_replicate, mock_config):
        """Test handling string response (non-streaming)."""
        mock_config.return_value = _make_config()
        mock_replicate.run.return_value = "Direct string response"

        client = ReplicateClient()
        response = client.generate("Test")

        assert response.text == "Direct string response"


class TestRetryLogic:
    """Tests for retry and error handling."""

    @patch('app.ai.replicate_client.get_config')
    @patch('app.ai.replicate_client.replicate')
    @patch('app.ai.replicate_client.time.sleep')
    def test_retry_on_rate_limit(self, mock_sleep, mock_replicate, mock_config):
        """Test retry logic on rate limit errors."""
        mock_config.return_value = _make_config()
        # First call raises rate limit, second succeeds
        mock_replicate.exceptions.ReplicateError = Exception
        mock_replicate.run.side_effect = [
            Exception("Rate limit exceeded 429"),
            iter(["Success"])
        ]

        client = ReplicateClient()
        response = client.generate("Test")

        assert response.text == "Success"
        assert mock_sleep.called  # Verify backoff happened

    @patch('app.ai.replicate_client.get_config')
    @patch('app.ai.replicate_client.replicate')
    @patch('app.ai.replicate_client.time.sleep')
    def test_max_retries_exceeded(self, mock_sleep, mock_replicate, mock_config):
        """Test that max retries raises error."""
        mock_config.return_value = _make_config()
        # All calls fail with rate limit
        mock_replicate.exceptions.ReplicateError = Exception
        mock_replicate.run.side_effect = Exception("Rate limit exceeded 429")

        client = ReplicateClient()

        with pytest.raises(ReplicateRateLimitError):
            client.generate("Test")

        # Should have retried MAX_RETRIES times
        assert mock_replicate.run.call_count == ReplicateClient.MAX_RETRIES

    # time.sleep is patched so that, if the client backs off on this error
    # class, the test does not sleep for real.
    @patch('app.ai.replicate_client.get_config')
    @patch('app.ai.replicate_client.replicate')
    @patch('app.ai.replicate_client.time.sleep')
    def test_timeout_error(self, mock_sleep, mock_replicate, mock_config):
        """Test timeout error handling."""
        mock_config.return_value = _make_config()
        mock_replicate.exceptions.ReplicateError = Exception
        mock_replicate.run.side_effect = Exception("Request timeout")

        client = ReplicateClient()

        with pytest.raises(ReplicateTimeoutError):
            client.generate("Test")

    # time.sleep is patched for the same reason as in test_timeout_error.
    @patch('app.ai.replicate_client.get_config')
    @patch('app.ai.replicate_client.replicate')
    @patch('app.ai.replicate_client.time.sleep')
    def test_api_error(self, mock_sleep, mock_replicate, mock_config):
        """Test generic API error handling."""
        mock_config.return_value = _make_config()
        mock_replicate.exceptions.ReplicateError = Exception
        mock_replicate.run.side_effect = Exception("Invalid model")

        client = ReplicateClient()

        with pytest.raises(ReplicateAPIError):
            client.generate("Test")


class TestValidation:
    """Tests for API key validation."""

    @patch('app.ai.replicate_client.get_config')
    @patch('app.ai.replicate_client.replicate')
    def test_validate_api_key_success(self, mock_replicate, mock_config):
        """Test successful API key validation."""
        mock_config.return_value = _make_config(token="valid_token")
        mock_replicate.models.get.return_value = MagicMock()

        client = ReplicateClient()

        assert client.validate_api_key() is True

    @patch('app.ai.replicate_client.get_config')
    @patch('app.ai.replicate_client.replicate')
    def test_validate_api_key_failure(self, mock_replicate, mock_config):
        """Test failed API key validation."""
        mock_config.return_value = _make_config(token="invalid_token")
        mock_replicate.models.get.side_effect = Exception("Invalid API token")

        client = ReplicateClient()

        assert client.validate_api_key() is False


class TestResponseDataclass:
    """Tests for ReplicateResponse dataclass."""

    def test_response_creation(self):
        """Test creating ReplicateResponse."""
        response = ReplicateResponse(
            text="Hello world",
            tokens_used=50,
            model="meta/llama-3-8b",
            generation_time=1.5
        )

        assert response.text == "Hello world"
        assert response.tokens_used == 50
        assert response.model == "meta/llama-3-8b"
        assert response.generation_time == 1.5

    def test_response_immutability(self):
        """Test that response fields are accessible."""
        # NOTE: despite the name, this only checks attribute access; it does
        # not assert that the dataclass rejects assignment (i.e. is frozen).
        response = ReplicateResponse(
            text="Test",
            tokens_used=10,
            model="test",
            generation_time=0.5
        )

        # Dataclass should allow attribute access
        assert hasattr(response, 'text')
        assert hasattr(response, 'tokens_used')


class TestModelType:
    """Tests for ModelType enum and multi-model support."""

    def test_model_type_values(self):
        """Test ModelType enum has expected values."""
        assert ModelType.LLAMA_3_8B.value == "meta/meta-llama-3-8b-instruct"
        assert ModelType.CLAUDE_HAIKU.value == "anthropic/claude-3.5-haiku"
        assert ModelType.CLAUDE_SONNET.value == "anthropic/claude-3.5-sonnet"
        assert ModelType.CLAUDE_SONNET_4.value == "anthropic/claude-sonnet-4"

    @patch('app.ai.replicate_client.get_config')
    def test_init_with_model_type_enum(self, mock_config):
        """Test initialization with ModelType enum."""
        mock_config.return_value = _make_config()

        client = ReplicateClient(model=ModelType.CLAUDE_HAIKU)

        assert client.model == "anthropic/claude-3.5-haiku"
        assert client.model_type == ModelType.CLAUDE_HAIKU

    @patch('app.ai.replicate_client.get_config')
    def test_is_claude_model(self, mock_config):
        """Test _is_claude_model helper."""
        mock_config.return_value = _make_config()

        # Llama model
        client = ReplicateClient(model=ModelType.LLAMA_3_8B)
        assert client._is_claude_model() is False

        # Claude models
        client = ReplicateClient(model=ModelType.CLAUDE_HAIKU)
        assert client._is_claude_model() is True

        client = ReplicateClient(model=ModelType.CLAUDE_SONNET)
        assert client._is_claude_model() is True


class TestClaudeModels:
    """Tests for Claude model generation via Replicate."""

    @patch('app.ai.replicate_client.get_config')
    @patch('app.ai.replicate_client.replicate')
    def test_generate_with_claude_haiku(self, mock_replicate, mock_config):
        """Test generation with Claude Haiku model."""
        mock_config.return_value = _make_config()
        mock_replicate.run.return_value = iter(["Claude ", "response"])

        client = ReplicateClient(model=ModelType.CLAUDE_HAIKU)
        response = client.generate("Test prompt")

        assert response.text == "Claude response"
        assert response.model == "anthropic/claude-3.5-haiku"

        # Verify Claude-style params (not Llama formatted prompt)
        run_input = _run_input(mock_replicate)
        assert "prompt" in run_input
        # Claude params don't include Llama special tokens
        assert "<|begin_of_text|>" not in run_input['prompt']

    @patch('app.ai.replicate_client.get_config')
    @patch('app.ai.replicate_client.replicate')
    def test_generate_with_claude_system_prompt(self, mock_replicate, mock_config):
        """Test Claude generation includes system_prompt parameter."""
        mock_config.return_value = _make_config()
        mock_replicate.run.return_value = iter(["Response"])

        client = ReplicateClient(model=ModelType.CLAUDE_SONNET)
        client.generate(
            prompt="User message",
            system_prompt="You are a DM"
        )

        run_input = _run_input(mock_replicate)
        assert run_input['system_prompt'] == "You are a DM"
        assert run_input['prompt'] == "User message"

    @patch('app.ai.replicate_client.get_config')
    @patch('app.ai.replicate_client.replicate')
    def test_model_specific_defaults(self, mock_replicate, mock_config):
        """Test that model-specific defaults are applied."""
        mock_config.return_value = _make_config()
        mock_replicate.run.return_value = iter(["Response"])

        # Claude Sonnet should use higher max_tokens by default
        client = ReplicateClient(model=ModelType.CLAUDE_SONNET)
        client.generate("Test")

        run_input = _run_input(mock_replicate)
        # Sonnet default is 1024 tokens
        assert run_input['max_tokens'] == 1024
        assert run_input['temperature'] == 0.9

    @patch('app.ai.replicate_client.get_config')
    @patch('app.ai.replicate_client.replicate')
    def test_model_override_in_generate(self, mock_replicate, mock_config):
        """Test overriding model in generate() call."""
        mock_config.return_value = _make_config()
        mock_replicate.run.return_value = iter(["Response"])

        # Init with Llama, but call with Claude
        client = ReplicateClient(model=ModelType.LLAMA_3_8B)
        response = client.generate("Test", model=ModelType.CLAUDE_HAIKU)

        # Response should reflect the overridden model
        assert response.model == "anthropic/claude-3.5-haiku"

        # Verify correct model was called (first positional arg to run)
        call_args = mock_replicate.run.call_args
        assert call_args[0][0] == "anthropic/claude-3.5-haiku"