Add AI summary (disabled by default)

This commit is contained in:
2026-01-26 17:13:11 -06:00
parent 2820944ec6
commit 921b6a81a4
13 changed files with 1357 additions and 43 deletions

339
replicate_client.py Normal file
View File

@@ -0,0 +1,339 @@
"""
Replicate API client wrapper for the Sneaky Intel Feed Aggregator.
Provides low-level API access with retry logic and error handling
for all Replicate model interactions.
"""
import json
import os
import time
from typing import Any, Dict, List, Optional, Tuple
import replicate
from feed_aggregator.models.config import ReplicateConfig
from feed_aggregator.utils.exceptions import APIError
from feed_aggregator.utils.logging import get_logger
logger = get_logger(__name__)
class ReplicateClient:
    """
    Low-level wrapper for Replicate API calls.
    Handles:
    - API authentication via REPLICATE_API_TOKEN
    - Retry logic with exponential backoff
    - Error classification and wrapping
    - Response normalization
    Attributes:
        MAX_RETRIES: Maximum retry attempts for failed calls.
        INITIAL_BACKOFF_SECONDS: Initial delay before first retry.
        BACKOFF_MULTIPLIER: Multiplier for exponential backoff.
        LLAMA_MODEL: Model identifier for Llama 3 8B Instruct.
        EMBEDDING_MODEL: Model identifier for text embeddings
            (multilingual-e5-large).
        EMBEDDING_DIMENSIONS: Expected vector length returned by
            EMBEDDING_MODEL.
    Example:
        >>> from feed_aggregator.services.replicate_client import ReplicateClient
        >>> from feed_aggregator.models.config import ReplicateConfig
        >>>
        >>> config = ReplicateConfig(api_key="r8_...")
        >>> client = ReplicateClient(config)
        >>> response, error = client.run_llama("Summarize: ...")
    """
    # Retry configuration (matches DatabaseService pattern)
    MAX_RETRIES = 3
    INITIAL_BACKOFF_SECONDS = 1.0
    BACKOFF_MULTIPLIER = 2.0
    # Model identifiers
    LLAMA_MODEL = "meta/meta-llama-3-8b-instruct"
    EMBEDDING_MODEL = "beautyyuyanli/multilingual-e5-large:a06276a89f1a902d5fc225a9ca32b6e8e6292b7f3b136518878da97c458e2bad"
    # Expected embedding dimensions for multilingual-e5-large.
    # NOTE: earlier comments attributed this to snowflake-arctic-embed-l;
    # the model actually pinned above is multilingual-e5-large (also 1024-d).
    EMBEDDING_DIMENSIONS = 1024
    # Default parameters for Llama
    DEFAULT_MAX_TOKENS = 500
    DEFAULT_TEMPERATURE = 0.3
    def __init__(self, config: ReplicateConfig):
        """
        Initialize the Replicate client.
        Sets the REPLICATE_API_TOKEN environment variable for the
        replicate library to use.
        Args:
            config: Replicate configuration with API key.
        """
        self._config = config
        self._retry_count = 0  # Count of failed attempts across all calls, for metrics
        # Set API token for replicate library
        os.environ["REPLICATE_API_TOKEN"] = config.api_key
        logger.debug(
            "ReplicateClient initialized",
            model=self.LLAMA_MODEL,
        )
    @staticmethod
    def _is_retryable_error(error: Exception) -> bool:
        """
        Classify an exception as transient (retryable) or permanent.
        Retryable: rate limits (429), server errors (5xx), timeouts.
        Classification is string-based because the replicate library
        does not expose structured status codes on its exceptions.
        Args:
            error: Exception raised by a replicate.run call.
        Returns:
            True if the call is worth retrying, False otherwise.
        """
        error_str = str(error).lower()
        is_rate_limit = "rate" in error_str or "429" in error_str
        is_server_error = any(
            code in error_str for code in ("500", "502", "503", "504")
        )
        is_timeout = "timeout" in error_str
        return is_rate_limit or is_server_error or is_timeout
    def run_llama(
        self,
        prompt: str,
        max_tokens: int = DEFAULT_MAX_TOKENS,
        temperature: float = DEFAULT_TEMPERATURE,
    ) -> Tuple[Optional[str], Optional[APIError]]:
        """
        Run Llama 3 8B Instruct with the given prompt.
        Args:
            prompt: The prompt to send to the model.
            max_tokens: Maximum tokens in response.
            temperature: Sampling temperature (0-1).
        Returns:
            Tuple of (response_text, error).
            On success: (str, None)
            On failure: (None, APIError)
        """
        return self._run_with_retry(
            model=self.LLAMA_MODEL,
            input_params={
                "prompt": prompt,
                "max_tokens": max_tokens,
                "temperature": temperature,
            },
            operation="summarization",
        )
    def run_embedding(
        self,
        text: str,
    ) -> Tuple[Optional[List[float]], Optional[APIError]]:
        """
        Generate embeddings using multilingual-e5-large.
        Uses the same retry logic as run_llama but preserves the
        list response format instead of converting to string.
        Args:
            text: Text to embed (title + summary, typically 200-500 tokens).
        Returns:
            Tuple of (embedding_vector, error).
            On success: (List[float] with 1024 dimensions, None)
            On failure: (None, APIError)
        Example:
            >>> embedding, error = client.run_embedding("Article about security")
            >>> if embedding:
            ...     print(f"Generated {len(embedding)}-dimensional embedding")
        """
        last_error: Optional[Exception] = None
        backoff = self.INITIAL_BACKOFF_SECONDS
        attempts = 0  # Attempts actually made, for accurate error reporting
        for attempt in range(1, self.MAX_RETRIES + 1):
            attempts = attempt
            try:
                # E5 model expects texts as a JSON-formatted string
                output = replicate.run(
                    self.EMBEDDING_MODEL,
                    input={"texts": json.dumps([text])},
                )
                # Output is a list of embeddings (one per input text)
                # We only pass one text, so extract the first embedding
                if isinstance(output, list) and len(output) > 0:
                    embedding = output[0]
                    # Validate embedding dimensions
                    if len(embedding) != self.EMBEDDING_DIMENSIONS:
                        raise ValueError(
                            f"Expected {self.EMBEDDING_DIMENSIONS} dimensions, "
                            f"got {len(embedding)}"
                        )
                    # Ensure all values are floats
                    embedding = [float(v) for v in embedding]
                    logger.debug(
                        "Replicate embedding succeeded",
                        model=self.EMBEDDING_MODEL,
                        operation="embedding",
                        attempt=attempt,
                        dimensions=len(embedding),
                    )
                    return embedding, None
                else:
                    raise ValueError(
                        f"Unexpected output format from embedding model: {type(output)}"
                    )
            except Exception as e:
                last_error = e
                self._retry_count += 1
                if attempt < self.MAX_RETRIES and self._is_retryable_error(e):
                    logger.warning(
                        "Replicate embedding failed, retrying",
                        model=self.EMBEDDING_MODEL,
                        operation="embedding",
                        attempt=attempt,
                        max_retries=self.MAX_RETRIES,
                        error=str(e),
                        backoff_seconds=backoff,
                    )
                    time.sleep(backoff)
                    backoff *= self.BACKOFF_MULTIPLIER
                else:
                    # Non-retryable error or max retries reached
                    break
        # All retries exhausted or non-retryable error. Report the attempts
        # actually made (a non-retryable first-attempt failure previously
        # reported retry_count=MAX_RETRIES, which was misleading).
        logger.error(
            "Replicate embedding failed after retries",
            model=self.EMBEDDING_MODEL,
            operation="embedding",
            retry_count=attempts,
            error=str(last_error),
        )
        return None, APIError(
            message=f"Replicate embedding failed: {str(last_error)}",
            api_name="Replicate",
            operation="embedding",
            retry_count=attempts,
            context={"model": self.EMBEDDING_MODEL, "error": str(last_error)},
        )
    def _run_with_retry(
        self,
        model: str,
        input_params: Dict[str, Any],
        operation: str,
    ) -> Tuple[Optional[str], Optional[APIError]]:
        """
        Execute model call with retry logic.
        Implements exponential backoff retry strategy for transient errors:
        - Rate limit errors (429)
        - Server errors (5xx)
        - Timeout errors
        Non-retryable errors (auth failures, etc.) fail immediately.
        Args:
            model: Replicate model identifier.
            input_params: Model input parameters.
            operation: Operation name for error reporting.
        Returns:
            Tuple of (response, error).
        """
        last_error: Optional[Exception] = None
        backoff = self.INITIAL_BACKOFF_SECONDS
        attempts = 0  # Attempts actually made, for accurate error reporting
        for attempt in range(1, self.MAX_RETRIES + 1):
            attempts = attempt
            try:
                output = replicate.run(model, input=input_params)
                # Replicate returns a generator for streaming models
                # Collect all output parts into a single string
                if hasattr(output, "__iter__") and not isinstance(output, str):
                    response = "".join(str(part) for part in output)
                else:
                    response = str(output)
                logger.debug(
                    "Replicate API call succeeded",
                    model=model,
                    operation=operation,
                    attempt=attempt,
                    response_length=len(response),
                )
                return response, None
            except Exception as e:
                last_error = e
                self._retry_count += 1
                if attempt < self.MAX_RETRIES and self._is_retryable_error(e):
                    logger.warning(
                        "Replicate API call failed, retrying",
                        model=model,
                        operation=operation,
                        attempt=attempt,
                        max_retries=self.MAX_RETRIES,
                        error=str(e),
                        backoff_seconds=backoff,
                    )
                    time.sleep(backoff)
                    backoff *= self.BACKOFF_MULTIPLIER
                else:
                    # Non-retryable error or max retries reached
                    break
        # All retries exhausted or non-retryable error. Report the attempts
        # actually made rather than always MAX_RETRIES.
        logger.error(
            "Replicate API call failed after retries",
            model=model,
            operation=operation,
            retry_count=attempts,
            error=str(last_error),
        )
        return None, APIError(
            message=f"Replicate API call failed: {str(last_error)}",
            api_name="Replicate",
            operation=operation,
            retry_count=attempts,
            context={"model": model, "error": str(last_error)},
        )
    @property
    def total_retries(self) -> int:
        """
        Return total failed-attempt count across all calls.
        Returns:
            Total number of retry attempts made.
        """
        return self._retry_count
    def reset_retry_count(self) -> None:
        """
        Reset retry counter.
        Useful for metrics collection between batch operations.
        """
        self._retry_count = 0