""" Replicate API client wrapper for the Sneaky Intel Feed Aggregator. Provides low-level API access with retry logic and error handling for all Replicate model interactions. """ import json import os import time from typing import Any, Dict, List, Optional, Tuple import replicate from feed_aggregator.models.config import ReplicateConfig from feed_aggregator.utils.exceptions import APIError from feed_aggregator.utils.logging import get_logger logger = get_logger(__name__) class ReplicateClient: """ Low-level wrapper for Replicate API calls. Handles: - API authentication via REPLICATE_API_TOKEN - Retry logic with exponential backoff - Error classification and wrapping - Response normalization Attributes: MAX_RETRIES: Maximum retry attempts for failed calls. INITIAL_BACKOFF_SECONDS: Initial delay before first retry. BACKOFF_MULTIPLIER: Multiplier for exponential backoff. LLAMA_MODEL: Model identifier for Llama 3 8B Instruct. EMBEDDING_MODEL: Model identifier for text embeddings. Example: >>> from feed_aggregator.services.replicate_client import ReplicateClient >>> from feed_aggregator.models.config import ReplicateConfig >>> >>> config = ReplicateConfig(api_key="r8_...") >>> client = ReplicateClient(config) >>> response, error = client.run_llama("Summarize: ...") """ # Retry configuration (matches DatabaseService pattern) MAX_RETRIES = 3 INITIAL_BACKOFF_SECONDS = 1.0 BACKOFF_MULTIPLIER = 2.0 # Model identifiers LLAMA_MODEL = "meta/meta-llama-3-8b-instruct" EMBEDDING_MODEL = "beautyyuyanli/multilingual-e5-large:a06276a89f1a902d5fc225a9ca32b6e8e6292b7f3b136518878da97c458e2bad" # Default parameters for Llama DEFAULT_MAX_TOKENS = 500 DEFAULT_TEMPERATURE = 0.3 def __init__(self, config: ReplicateConfig): """ Initialize the Replicate client. Sets the REPLICATE_API_TOKEN environment variable for the replicate library to use. Args: config: Replicate configuration with API key. """ self._config = config self._retry_count = 0 # Track total retries for metrics # Set API token for replicate library os.environ["REPLICATE_API_TOKEN"] = config.api_key logger.debug( "ReplicateClient initialized", model=self.LLAMA_MODEL, ) def run_llama( self, prompt: str, max_tokens: int = DEFAULT_MAX_TOKENS, temperature: float = DEFAULT_TEMPERATURE, ) -> Tuple[Optional[str], Optional[APIError]]: """ Run Llama 3 8B Instruct with the given prompt. Args: prompt: The prompt to send to the model. max_tokens: Maximum tokens in response. temperature: Sampling temperature (0-1). Returns: Tuple of (response_text, error). On success: (str, None) On failure: (None, APIError) """ return self._run_with_retry( model=self.LLAMA_MODEL, input_params={ "prompt": prompt, "max_tokens": max_tokens, "temperature": temperature, }, operation="summarization", ) # Expected embedding dimensions for snowflake-arctic-embed-l EMBEDDING_DIMENSIONS = 1024 def run_embedding( self, text: str, ) -> Tuple[Optional[List[float]], Optional[APIError]]: """ Generate embeddings using snowflake-arctic-embed-l. Uses the same retry logic as run_llama but preserves the list response format instead of converting to string. Args: text: Text to embed (title + summary, typically 200-500 tokens). Returns: Tuple of (embedding_vector, error). On success: (List[float] with 1024 dimensions, None) On failure: (None, APIError) Example: >>> embedding, error = client.run_embedding("Article about security") >>> if embedding: ... print(f"Generated {len(embedding)}-dimensional embedding") """ last_error: Optional[Exception] = None backoff = self.INITIAL_BACKOFF_SECONDS for attempt in range(1, self.MAX_RETRIES + 1): try: # E5 model expects texts as a JSON-formatted string output = replicate.run( self.EMBEDDING_MODEL, input={"texts": json.dumps([text])}, ) # Output is a list of embeddings (one per input text) # We only pass one text, so extract the first embedding if isinstance(output, list) and len(output) > 0: embedding = output[0] # Validate embedding dimensions if len(embedding) != self.EMBEDDING_DIMENSIONS: raise ValueError( f"Expected {self.EMBEDDING_DIMENSIONS} dimensions, " f"got {len(embedding)}" ) # Ensure all values are floats embedding = [float(v) for v in embedding] logger.debug( "Replicate embedding succeeded", model=self.EMBEDDING_MODEL, operation="embedding", attempt=attempt, dimensions=len(embedding), ) return embedding, None else: raise ValueError( f"Unexpected output format from embedding model: {type(output)}" ) except Exception as e: last_error = e self._retry_count += 1 # Determine if error is retryable based on error message error_str = str(e).lower() is_rate_limit = "rate" in error_str or "429" in error_str is_server_error = ( "500" in error_str or "502" in error_str or "503" in error_str or "504" in error_str ) is_timeout = "timeout" in error_str is_retryable = is_rate_limit or is_server_error or is_timeout if attempt < self.MAX_RETRIES and is_retryable: logger.warning( "Replicate embedding failed, retrying", model=self.EMBEDDING_MODEL, operation="embedding", attempt=attempt, max_retries=self.MAX_RETRIES, error=str(e), backoff_seconds=backoff, ) time.sleep(backoff) backoff *= self.BACKOFF_MULTIPLIER else: # Non-retryable error or max retries reached break # All retries exhausted or non-retryable error logger.error( "Replicate embedding failed after retries", model=self.EMBEDDING_MODEL, operation="embedding", retry_count=self.MAX_RETRIES, error=str(last_error), ) return None, APIError( message=f"Replicate embedding failed: {str(last_error)}", api_name="Replicate", operation="embedding", retry_count=self.MAX_RETRIES, context={"model": self.EMBEDDING_MODEL, "error": str(last_error)}, ) def _run_with_retry( self, model: str, input_params: Dict[str, Any], operation: str, ) -> Tuple[Optional[str], Optional[APIError]]: """ Execute model call with retry logic. Implements exponential backoff retry strategy for transient errors: - Rate limit errors (429) - Server errors (5xx) - Timeout errors Non-retryable errors (auth failures, etc.) fail immediately. Args: model: Replicate model identifier. input_params: Model input parameters. operation: Operation name for error reporting. Returns: Tuple of (response, error). """ last_error: Optional[Exception] = None backoff = self.INITIAL_BACKOFF_SECONDS for attempt in range(1, self.MAX_RETRIES + 1): try: output = replicate.run(model, input=input_params) # Replicate returns a generator for streaming models # Collect all output parts into a single string if hasattr(output, "__iter__") and not isinstance(output, str): response = "".join(str(part) for part in output) else: response = str(output) logger.debug( "Replicate API call succeeded", model=model, operation=operation, attempt=attempt, response_length=len(response), ) return response, None except Exception as e: last_error = e self._retry_count += 1 # Determine if error is retryable based on error message error_str = str(e).lower() is_rate_limit = "rate" in error_str or "429" in error_str is_server_error = ( "500" in error_str or "502" in error_str or "503" in error_str or "504" in error_str ) is_timeout = "timeout" in error_str is_retryable = is_rate_limit or is_server_error or is_timeout if attempt < self.MAX_RETRIES and is_retryable: logger.warning( "Replicate API call failed, retrying", model=model, operation=operation, attempt=attempt, max_retries=self.MAX_RETRIES, error=str(e), backoff_seconds=backoff, ) time.sleep(backoff) backoff *= self.BACKOFF_MULTIPLIER else: # Non-retryable error or max retries reached break # All retries exhausted or non-retryable error logger.error( "Replicate API call failed after retries", model=model, operation=operation, retry_count=self.MAX_RETRIES, error=str(last_error), ) return None, APIError( message=f"Replicate API call failed: {str(last_error)}", api_name="Replicate", operation=operation, retry_count=self.MAX_RETRIES, context={"model": model, "error": str(last_error)}, ) @property def total_retries(self) -> int: """ Return total retry count across all calls. Returns: Total number of retry attempts made. """ return self._retry_count def reset_retry_count(self) -> None: """ Reset retry counter. Useful for metrics collection between batch operations. """ self._retry_count = 0