adding ai summary - default is disabled
replicate_client.py (new file, 339 lines)
@@ -0,0 +1,339 @@
"""
Replicate API client wrapper for the Sneaky Intel Feed Aggregator.

Provides low-level API access with retry logic and error handling
for all Replicate model interactions.
"""

import json
import os
import time
from typing import Any, Dict, List, Optional, Tuple

import replicate

from feed_aggregator.models.config import ReplicateConfig
from feed_aggregator.utils.exceptions import APIError
from feed_aggregator.utils.logging import get_logger

logger = get_logger(__name__)


class ReplicateClient:
    """
    Low-level wrapper for Replicate API calls.

    Handles:
    - API authentication via REPLICATE_API_TOKEN
    - Retry logic with exponential backoff
    - Error classification and wrapping
    - Response normalization

    Attributes:
        MAX_RETRIES: Maximum retry attempts for failed calls.
        INITIAL_BACKOFF_SECONDS: Initial delay before first retry.
        BACKOFF_MULTIPLIER: Multiplier for exponential backoff.
        LLAMA_MODEL: Model identifier for Llama 3 8B Instruct.
        EMBEDDING_MODEL: Model identifier for text embeddings.

    Example:
        >>> from feed_aggregator.services.replicate_client import ReplicateClient
        >>> from feed_aggregator.models.config import ReplicateConfig
        >>>
        >>> config = ReplicateConfig(api_key="r8_...")
        >>> client = ReplicateClient(config)
        >>> response, error = client.run_llama("Summarize: ...")
    """

    # Retry configuration (matches DatabaseService pattern)
    MAX_RETRIES = 3
    INITIAL_BACKOFF_SECONDS = 1.0
    BACKOFF_MULTIPLIER = 2.0
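    # With these defaults, the sleeps between attempts are 1.0s then 2.0s;
    # after the third failed attempt the error is returned to the caller.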

    # Model identifiers
    LLAMA_MODEL = "meta/meta-llama-3-8b-instruct"
    EMBEDDING_MODEL = "beautyyuyanli/multilingual-e5-large:a06276a89f1a902d5fc225a9ca32b6e8e6292b7f3b136518878da97c458e2bad"

    # Default parameters for Llama
    DEFAULT_MAX_TOKENS = 500
    DEFAULT_TEMPERATURE = 0.3

    def __init__(self, config: ReplicateConfig):
        """
        Initialize the Replicate client.

        Sets the REPLICATE_API_TOKEN environment variable for the
        replicate library to use.

        Args:
            config: Replicate configuration with API key.
        """
        self._config = config
        self._retry_count = 0  # Track total retries for metrics

        # Set API token for replicate library
        os.environ["REPLICATE_API_TOKEN"] = config.api_key

        logger.debug(
            "ReplicateClient initialized",
            model=self.LLAMA_MODEL,
        )

    def run_llama(
        self,
        prompt: str,
        max_tokens: int = DEFAULT_MAX_TOKENS,
        temperature: float = DEFAULT_TEMPERATURE,
    ) -> Tuple[Optional[str], Optional[APIError]]:
        """
        Run Llama 3 8B Instruct with the given prompt.

        Args:
            prompt: The prompt to send to the model.
            max_tokens: Maximum tokens in response.
            temperature: Sampling temperature (0-1).

        Returns:
            Tuple of (response_text, error).
            On success: (str, None)
            On failure: (None, APIError)
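
        Example (illustrative):
            >>> summary, error = client.run_llama("Summarize: ...", max_tokens=200)
            >>> if error is None:
            ...     print(summary)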
        """
        return self._run_with_retry(
            model=self.LLAMA_MODEL,
            input_params={
                "prompt": prompt,
                "max_tokens": max_tokens,
                "temperature": temperature,
            },
            operation="summarization",
        )

    # Expected embedding dimensions for multilingual-e5-large
    EMBEDDING_DIMENSIONS = 1024

    def run_embedding(
        self,
        text: str,
    ) -> Tuple[Optional[List[float]], Optional[APIError]]:
        """
        Generate embeddings using multilingual-e5-large.

        Uses the same retry logic as run_llama but preserves the
        list response format instead of converting to string.

        Args:
            text: Text to embed (title + summary, typically 200-500 tokens).

        Returns:
            Tuple of (embedding_vector, error).
            On success: (List[float] with 1024 dimensions, None)
            On failure: (None, APIError)

        Example:
            >>> embedding, error = client.run_embedding("Article about security")
            >>> if embedding:
            ...     print(f"Generated {len(embedding)}-dimensional embedding")
        """
        last_error: Optional[Exception] = None
        backoff = self.INITIAL_BACKOFF_SECONDS

        for attempt in range(1, self.MAX_RETRIES + 1):
            try:
                # E5 model expects texts as a JSON-formatted string
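                # e.g. text="hello" -> input={"texts": '["hello"]'}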
                output = replicate.run(
                    self.EMBEDDING_MODEL,
                    input={"texts": json.dumps([text])},
                )

                # Output is a list of embeddings (one per input text)
                # We only pass one text, so extract the first embedding
                if isinstance(output, list) and len(output) > 0:
                    embedding = output[0]

                    # Validate embedding dimensions
                    if len(embedding) != self.EMBEDDING_DIMENSIONS:
                        raise ValueError(
                            f"Expected {self.EMBEDDING_DIMENSIONS} dimensions, "
                            f"got {len(embedding)}"
                        )

                    # Ensure all values are floats
                    embedding = [float(v) for v in embedding]

                    logger.debug(
                        "Replicate embedding succeeded",
                        model=self.EMBEDDING_MODEL,
                        operation="embedding",
                        attempt=attempt,
                        dimensions=len(embedding),
                    )

                    return embedding, None
                else:
                    raise ValueError(
                        f"Unexpected output format from embedding model: {type(output)}"
                    )

            except Exception as e:
                last_error = e
                self._retry_count += 1

                # Determine if error is retryable based on error message
                error_str = str(e).lower()
                is_rate_limit = "rate" in error_str or "429" in error_str
                is_server_error = (
                    "500" in error_str
                    or "502" in error_str
                    or "503" in error_str
                    or "504" in error_str
                )
                is_timeout = "timeout" in error_str
                is_retryable = is_rate_limit or is_server_error or is_timeout
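                # Note: classifying by substrings of str(e) is a heuristic;
                # messages like "HTTP 429", "503 Service Unavailable", or
                # "rate limit exceeded" are all treated as transient here.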

                if attempt < self.MAX_RETRIES and is_retryable:
                    logger.warning(
                        "Replicate embedding failed, retrying",
                        model=self.EMBEDDING_MODEL,
                        operation="embedding",
                        attempt=attempt,
                        max_retries=self.MAX_RETRIES,
                        error=str(e),
                        backoff_seconds=backoff,
                    )
                    time.sleep(backoff)
                    backoff *= self.BACKOFF_MULTIPLIER
                else:
                    # Non-retryable error or max retries reached
                    break

        # All retries exhausted or non-retryable error
        logger.error(
            "Replicate embedding failed after retries",
            model=self.EMBEDDING_MODEL,
            operation="embedding",
            retry_count=self.MAX_RETRIES,
            error=str(last_error),
        )

        return None, APIError(
            message=f"Replicate embedding failed: {str(last_error)}",
            api_name="Replicate",
            operation="embedding",
            retry_count=self.MAX_RETRIES,
            context={"model": self.EMBEDDING_MODEL, "error": str(last_error)},
        )

    def _run_with_retry(
        self,
        model: str,
        input_params: Dict[str, Any],
        operation: str,
    ) -> Tuple[Optional[str], Optional[APIError]]:
        """
        Execute a model call with retry logic.

        Implements an exponential backoff retry strategy for transient errors:
        - Rate limit errors (429)
        - Server errors (5xx)
        - Timeout errors

        Non-retryable errors (auth failures, etc.) fail immediately.

        Args:
            model: Replicate model identifier.
            input_params: Model input parameters.
            operation: Operation name for error reporting.

        Returns:
            Tuple of (response, error).
        """
        last_error: Optional[Exception] = None
        backoff = self.INITIAL_BACKOFF_SECONDS

        for attempt in range(1, self.MAX_RETRIES + 1):
            try:
                output = replicate.run(model, input=input_params)

                # Replicate returns a generator for streaming models
                # Collect all output parts into a single string
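                # e.g. ["The", " feed", " covers", " ..."] -> "The feed covers ..."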
                if hasattr(output, "__iter__") and not isinstance(output, str):
                    response = "".join(str(part) for part in output)
                else:
                    response = str(output)

                logger.debug(
                    "Replicate API call succeeded",
                    model=model,
                    operation=operation,
                    attempt=attempt,
                    response_length=len(response),
                )

                return response, None

            except Exception as e:
                last_error = e
                self._retry_count += 1

                # Determine if error is retryable based on error message
                error_str = str(e).lower()
                is_rate_limit = "rate" in error_str or "429" in error_str
                is_server_error = (
                    "500" in error_str
                    or "502" in error_str
                    or "503" in error_str
                    or "504" in error_str
                )
                is_timeout = "timeout" in error_str
                is_retryable = is_rate_limit or is_server_error or is_timeout

                if attempt < self.MAX_RETRIES and is_retryable:
                    logger.warning(
                        "Replicate API call failed, retrying",
                        model=model,
                        operation=operation,
                        attempt=attempt,
                        max_retries=self.MAX_RETRIES,
                        error=str(e),
                        backoff_seconds=backoff,
                    )
                    time.sleep(backoff)
                    backoff *= self.BACKOFF_MULTIPLIER
                else:
                    # Non-retryable error or max retries reached
                    break

        # All retries exhausted or non-retryable error
        logger.error(
            "Replicate API call failed after retries",
            model=model,
            operation=operation,
            retry_count=self.MAX_RETRIES,
            error=str(last_error),
        )

        return None, APIError(
            message=f"Replicate API call failed: {str(last_error)}",
            api_name="Replicate",
            operation=operation,
            retry_count=self.MAX_RETRIES,
            context={"model": model, "error": str(last_error)},
        )

    @property
    def total_retries(self) -> int:
        """
        Return total retry count across all calls.

        Returns:
            Total number of retry attempts made.
        """
        return self._retry_count

    def reset_retry_count(self) -> None:
        """
        Reset retry counter.

        Useful for metrics collection between batch operations.
        """
        self._retry_count = 0
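
# Minimal usage sketch, assuming ReplicateConfig(api_key=...) as in the class
# docstring and a valid REPLICATE_API_TOKEN in the environment:
if __name__ == "__main__":
    client = ReplicateClient(
        ReplicateConfig(api_key=os.environ.get("REPLICATE_API_TOKEN", ""))
    )

    summary, err = client.run_llama("Summarize: feed aggregation in one line.")
    print(summary if err is None else f"summarization failed: {err}")

    embedding, err = client.run_embedding("Article about security")
    if err is None:
        print(f"embedding dimensions: {len(embedding)}")

    print(f"total retries: {client.total_retries}")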