""" Replicate API client for AI model integration. This module provides a client for interacting with the Replicate API to generate text using various models including Llama-3 and Claude models. All AI generation goes through Replicate for unified billing and management. """ import time from dataclasses import dataclass from enum import Enum from typing import Any import replicate import structlog from app.config import get_config logger = structlog.get_logger(__name__) class ModelType(str, Enum): """Supported model types on Replicate.""" # Free tier - Llama models LLAMA_3_8B = "meta/meta-llama-3-8b-instruct" # Paid tiers - Claude models via Replicate CLAUDE_HAIKU = "anthropic/claude-3.5-haiku" CLAUDE_SONNET = "anthropic/claude-3.5-sonnet" CLAUDE_SONNET_4 = "anthropic/claude-4.5-sonnet" @dataclass class ReplicateResponse: """Response from Replicate API generation.""" text: str tokens_used: int # Deprecated: use tokens_output instead tokens_input: int tokens_output: int model: str generation_time: float class ReplicateClientError(Exception): """Base exception for Replicate client errors.""" pass class ReplicateAPIError(ReplicateClientError): """Error from Replicate API.""" pass class ReplicateRateLimitError(ReplicateClientError): """Rate limit exceeded on Replicate API.""" pass class ReplicateTimeoutError(ReplicateClientError): """Timeout waiting for Replicate response.""" pass class ReplicateClient: """ Client for interacting with Replicate API. Supports multiple models including Llama-3 and Claude models. Implements retry logic with exponential backoff for rate limits. """ # Default model for free tier DEFAULT_MODEL = ModelType.LLAMA_3_8B # Retry configuration MAX_RETRIES = 3 INITIAL_RETRY_DELAY = 1.0 # seconds # Default generation parameters DEFAULT_MAX_TOKENS = 256 DEFAULT_TEMPERATURE = 0.7 DEFAULT_TOP_P = 0.9 DEFAULT_TIMEOUT = 30 # seconds # Model-specific defaults MODEL_DEFAULTS = { ModelType.LLAMA_3_8B: { "max_tokens": 256, "temperature": 0.7, }, ModelType.CLAUDE_HAIKU: { "max_tokens": 512, "temperature": 0.8, }, ModelType.CLAUDE_SONNET: { "max_tokens": 1024, "temperature": 0.9, }, ModelType.CLAUDE_SONNET_4: { "max_tokens": 2048, "temperature": 0.9, }, } def __init__(self, api_token: str | None = None, model: str | ModelType | None = None): """ Initialize the Replicate client. Args: api_token: Replicate API token. If not provided, reads from config. model: Model identifier or ModelType enum. Defaults to Llama-3 8B Instruct. Raises: ReplicateClientError: If API token is not configured. """ config = get_config() # Get API token from parameter or config self.api_token = api_token or getattr(config, 'replicate_api_token', None) if not self.api_token: raise ReplicateClientError( "Replicate API token not configured. " "Set REPLICATE_API_TOKEN in environment or config." 
        # Get the model from the parameter, config, or default
        if model is None:
            model = getattr(config, 'REPLICATE_MODEL', None) or self.DEFAULT_MODEL

        # Convert a string to ModelType if possible, or keep it as a string for custom models
        if isinstance(model, ModelType):
            self.model = model.value
            self.model_type = model
        elif isinstance(model, str):
            # Try to match the string to a known ModelType
            self.model = model
            self.model_type = self._get_model_type(model)
        else:
            self.model = self.DEFAULT_MODEL.value
            self.model_type = self.DEFAULT_MODEL

        # Set the API token for the replicate library
        os.environ['REPLICATE_API_TOKEN'] = self.api_token

        logger.info(
            "Replicate client initialized",
            model=self.model,
            model_type=self.model_type.name if self.model_type else "custom"
        )

    def _get_model_type(self, model_string: str) -> ModelType | None:
        """Get the ModelType enum matching a model string, if any."""
        for model_type in ModelType:
            if model_type.value == model_string:
                return model_type
        return None

    def _is_claude_model(self) -> bool:
        """Check whether the current model is a Claude model."""
        return self.model_type in [
            ModelType.CLAUDE_HAIKU,
            ModelType.CLAUDE_SONNET,
            ModelType.CLAUDE_SONNET_4,
        ]

    def generate(
        self,
        prompt: str,
        system_prompt: str | None = None,
        max_tokens: int | None = None,
        temperature: float | None = None,
        top_p: float | None = None,
        timeout: int | None = None,
        model: str | ModelType | None = None
    ) -> ReplicateResponse:
        """
        Generate text using the configured model.

        Args:
            prompt: The user prompt to send to the model.
            system_prompt: Optional system prompt for context setting.
            max_tokens: Maximum tokens to generate. Uses model defaults if not specified.
            temperature: Sampling temperature (0.0-1.0). Uses model defaults if not specified.
            top_p: Top-p sampling parameter. Defaults to 0.9.
            timeout: Timeout in seconds. Defaults to 30.
            model: Override the default model for this request.

        Returns:
            ReplicateResponse with the generated text and metadata.

        Raises:
            ReplicateAPIError: For API errors.
            ReplicateRateLimitError: When rate limited.
            ReplicateTimeoutError: When the request times out.
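
        Example (illustrative sketch; assumes a valid Replicate API token is
        configured and the default Llama-3 model is available):

            client = ReplicateClient()
            response = client.generate(
                prompt="Explain exponential backoff in one sentence.",
                max_tokens=64,
            )
            print(response.text)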
""" # Handle model override if model: if isinstance(model, ModelType): current_model = model.value current_model_type = model else: current_model = model current_model_type = self._get_model_type(model) else: current_model = self.model current_model_type = self.model_type # Get model-specific defaults model_defaults = self.MODEL_DEFAULTS.get(current_model_type, {}) # Apply defaults (parameter > model default > class default) max_tokens = max_tokens or model_defaults.get("max_tokens", self.DEFAULT_MAX_TOKENS) temperature = temperature or model_defaults.get("temperature", self.DEFAULT_TEMPERATURE) top_p = top_p or self.DEFAULT_TOP_P timeout = timeout or self.DEFAULT_TIMEOUT # Format prompt based on model type is_claude = current_model_type in [ ModelType.CLAUDE_HAIKU, ModelType.CLAUDE_SONNET, ModelType.CLAUDE_SONNET_4 ] if is_claude: input_params = self._build_claude_params( prompt, system_prompt, max_tokens, temperature, top_p ) else: # Llama-style formatting formatted_prompt = self._format_llama_prompt(prompt, system_prompt) input_params = { "prompt": formatted_prompt, "max_tokens": max_tokens, "temperature": temperature, "top_p": top_p, } logger.debug( "Generating text with Replicate", model=current_model, max_tokens=max_tokens, temperature=temperature, is_claude=is_claude ) # Execute with retry logic start_time = time.time() output = self._execute_with_retry(current_model, input_params, timeout) generation_time = time.time() - start_time # Parse response text = self._parse_response(output) # Estimate tokens (rough approximation: ~4 chars per token) # Calculate input tokens from the actual prompt sent prompt_text = input_params.get("prompt", "") system_text = input_params.get("system_prompt", "") total_input_text = prompt_text + system_text tokens_input = len(total_input_text) // 4 # Calculate output tokens from response tokens_output = len(text) // 4 # Total for backwards compatibility tokens_used = tokens_input + tokens_output logger.info( "Replicate generation complete", model=current_model, tokens_input=tokens_input, tokens_output=tokens_output, tokens_used=tokens_used, generation_time=f"{generation_time:.2f}s", response_length=len(text) ) return ReplicateResponse( text=text.strip(), tokens_used=tokens_used, tokens_input=tokens_input, tokens_output=tokens_output, model=current_model, generation_time=generation_time ) def _build_claude_params( self, prompt: str, system_prompt: str | None, max_tokens: int, temperature: float, top_p: float ) -> dict[str, Any]: """ Build input parameters for Claude models on Replicate. Args: prompt: User prompt. system_prompt: Optional system prompt. max_tokens: Maximum tokens to generate. temperature: Sampling temperature. top_p: Top-p sampling parameter. Returns: Dictionary of input parameters for Replicate API. """ params = { "prompt": prompt, "max_tokens": max_tokens, "temperature": temperature, "top_p": top_p, } if system_prompt: params["system_prompt"] = system_prompt return params def _format_llama_prompt(self, prompt: str, system_prompt: str | None = None) -> str: """ Format prompt for Llama-3 Instruct model. Llama-3 Instruct uses a specific format with special tokens. Args: prompt: User prompt. system_prompt: Optional system prompt. Returns: Formatted prompt string. 
""" parts = [] if system_prompt: parts.append(f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{system_prompt}<|eot_id|>") else: parts.append("<|begin_of_text|>") parts.append(f"<|start_header_id|>user<|end_header_id|>\n\n{prompt}<|eot_id|>") parts.append("<|start_header_id|>assistant<|end_header_id|>\n\n") return "".join(parts) def _parse_response(self, output: Any) -> str: """ Parse response from Replicate API. Handles both streaming iterators and direct string responses. Args: output: Raw output from Replicate API. Returns: Parsed text string. """ if hasattr(output, '__iter__') and not isinstance(output, str): return "".join(output) return str(output) def _execute_with_retry( self, model: str, input_params: dict[str, Any], timeout: int ) -> Any: """ Execute Replicate API call with retry logic. Implements exponential backoff for rate limit errors. Args: model: Model identifier to run. input_params: Input parameters for the model. timeout: Timeout in seconds. Returns: API response output. Raises: ReplicateAPIError: For API errors after retries. ReplicateRateLimitError: When rate limit persists after retries. ReplicateTimeoutError: When request times out. """ last_error = None retry_delay = self.INITIAL_RETRY_DELAY for attempt in range(self.MAX_RETRIES): try: output = replicate.run( model, input=input_params ) return output except replicate.exceptions.ReplicateError as e: error_message = str(e).lower() if "rate limit" in error_message or "429" in error_message: last_error = ReplicateRateLimitError(f"Rate limited: {e}") if attempt < self.MAX_RETRIES - 1: logger.warning( "Rate limited, retrying", attempt=attempt + 1, retry_delay=retry_delay ) time.sleep(retry_delay) retry_delay *= 2 continue else: raise last_error elif "timeout" in error_message: raise ReplicateTimeoutError(f"Request timed out: {e}") else: raise ReplicateAPIError(f"API error: {e}") except Exception as e: error_message = str(e).lower() if "timeout" in error_message: raise ReplicateTimeoutError(f"Request timed out: {e}") raise ReplicateAPIError(f"Unexpected error: {e}") if last_error: raise last_error raise ReplicateAPIError("Max retries exceeded") def validate_api_key(self) -> bool: """ Validate that the API key is valid. Makes a minimal API call to check credentials. Returns: True if API key is valid, False otherwise. """ try: model_name = self.model.split(':')[0] model = replicate.models.get(model_name) return model is not None except Exception as e: logger.warning( "API key validation failed", error=str(e) ) return False