"""
|
|
Replicate API client for AI model integration.
|
|
|
|
This module provides a client for interacting with the Replicate API
|
|
to generate text using various models including Llama-3 and Claude models.
|
|
All AI generation goes through Replicate for unified billing and management.
|
|
"""

import os
import time
from dataclasses import dataclass
from enum import Enum
from typing import Any

import replicate
import structlog

from app.config import get_config

logger = structlog.get_logger(__name__)


class ModelType(str, Enum):
    """Supported model types on Replicate."""

    # Free tier - Llama models
    LLAMA_3_8B = "meta/meta-llama-3-8b-instruct"

    # Paid tiers - Claude models via Replicate
    CLAUDE_HAIKU = "anthropic/claude-3.5-haiku"
    CLAUDE_SONNET = "anthropic/claude-3.5-sonnet"
    CLAUDE_SONNET_4_5 = "anthropic/claude-4.5-sonnet"


@dataclass
class ReplicateResponse:
    """Response from Replicate API generation."""

    text: str
    tokens_used: int  # Deprecated: prefer tokens_input and tokens_output
    tokens_input: int
    tokens_output: int
    model: str
    generation_time: float


class ReplicateClientError(Exception):
    """Base exception for Replicate client errors."""


class ReplicateAPIError(ReplicateClientError):
    """Error from Replicate API."""


class ReplicateRateLimitError(ReplicateClientError):
    """Rate limit exceeded on Replicate API."""


class ReplicateTimeoutError(ReplicateClientError):
    """Timeout waiting for Replicate response."""


class ReplicateClient:
    """
    Client for interacting with the Replicate API.

    Supports multiple models, including Llama-3 and the Claude family.
    Implements retry logic with exponential backoff for rate limits.
    """

    # Default model for free tier
    DEFAULT_MODEL = ModelType.LLAMA_3_8B

    # Retry configuration
    MAX_RETRIES = 3
    INITIAL_RETRY_DELAY = 1.0  # seconds

    # Default generation parameters
    DEFAULT_MAX_TOKENS = 256
    DEFAULT_TEMPERATURE = 0.7
    DEFAULT_TOP_P = 0.9
    DEFAULT_TIMEOUT = 30  # seconds

    # Model-specific defaults
    MODEL_DEFAULTS = {
        ModelType.LLAMA_3_8B: {
            "max_tokens": 256,
            "temperature": 0.7,
        },
        ModelType.CLAUDE_HAIKU: {
            "max_tokens": 512,
            "temperature": 0.8,
        },
        ModelType.CLAUDE_SONNET: {
            "max_tokens": 1024,
            "temperature": 0.9,
        },
        ModelType.CLAUDE_SONNET_4_5: {
            "max_tokens": 2048,
            "temperature": 0.9,
        },
    }

    def __init__(self, api_token: str | None = None, model: str | ModelType | None = None):
        """
        Initialize the Replicate client.

        Args:
            api_token: Replicate API token. If not provided, reads from config.
            model: Model identifier or ModelType enum. Defaults to Llama-3 8B Instruct.

        Raises:
            ReplicateClientError: If API token is not configured.
        """
        config = get_config()

        # Get API token from parameter or config
        self.api_token = api_token or getattr(config, 'replicate_api_token', None)
        if not self.api_token:
            raise ReplicateClientError(
                "Replicate API token not configured. "
                "Set REPLICATE_API_TOKEN in environment or config."
            )

        # Get model from parameter, config, or default
        if model is None:
            model = getattr(config, 'REPLICATE_MODEL', None) or self.DEFAULT_MODEL

        # Convert string to ModelType if needed, or keep as string for custom models
        if isinstance(model, ModelType):
            self.model = model.value
            self.model_type = model
        elif isinstance(model, str):
            self.model = model
            self.model_type = self._get_model_type(model)
        else:
            self.model = self.DEFAULT_MODEL.value
            self.model_type = self.DEFAULT_MODEL

        # Set the API token for the replicate library (it reads this
        # environment variable, so this affects the whole process)
        os.environ['REPLICATE_API_TOKEN'] = self.api_token

        logger.info(
            "Replicate client initialized",
            model=self.model,
            model_type=self.model_type.name if self.model_type else "custom"
        )

    def _get_model_type(self, model_string: str) -> ModelType | None:
        """Get ModelType enum from model string."""
        for model_type in ModelType:
            if model_type.value == model_string:
                return model_type
        return None

    # Claude model types; used to pick the right request format
    CLAUDE_MODELS = frozenset({
        ModelType.CLAUDE_HAIKU,
        ModelType.CLAUDE_SONNET,
        ModelType.CLAUDE_SONNET_4_5,
    })

    def _is_claude_model(self) -> bool:
        """Check if current model is a Claude model."""
        return self.model_type in self.CLAUDE_MODELS

    def generate(
        self,
        prompt: str,
        system_prompt: str | None = None,
        max_tokens: int | None = None,
        temperature: float | None = None,
        top_p: float | None = None,
        timeout: int | None = None,
        model: str | ModelType | None = None
    ) -> ReplicateResponse:
        """
        Generate text using the configured model.

        Args:
            prompt: The user prompt to send to the model.
            system_prompt: Optional system prompt for context setting.
            max_tokens: Maximum tokens to generate. Uses model defaults if not specified.
            temperature: Sampling temperature (0.0-1.0). Uses model defaults if not specified.
            top_p: Top-p sampling parameter. Defaults to 0.9.
            timeout: Timeout in seconds. Defaults to 30.
            model: Override the default model for this request.

        Returns:
            ReplicateResponse with generated text and metadata.

        Raises:
            ReplicateAPIError: For API errors.
            ReplicateRateLimitError: When rate limited.
            ReplicateTimeoutError: When request times out.
        """
        # Handle model override
        if model:
            if isinstance(model, ModelType):
                current_model = model.value
                current_model_type = model
            else:
                current_model = model
                current_model_type = self._get_model_type(model)
        else:
            current_model = self.model
            current_model_type = self.model_type

        # Get model-specific defaults
        model_defaults = self.MODEL_DEFAULTS.get(current_model_type, {})

        # Apply defaults (parameter > model default > class default).
        # Compare against None rather than using `or`, so explicit zero
        # values (e.g. temperature=0.0) are not silently overridden.
        if max_tokens is None:
            max_tokens = model_defaults.get("max_tokens", self.DEFAULT_MAX_TOKENS)
        if temperature is None:
            temperature = model_defaults.get("temperature", self.DEFAULT_TEMPERATURE)
        if top_p is None:
            top_p = self.DEFAULT_TOP_P
        if timeout is None:
            timeout = self.DEFAULT_TIMEOUT

        # Format prompt based on model type
        is_claude = current_model_type in self.CLAUDE_MODELS

        if is_claude:
            input_params = self._build_claude_params(
                prompt, system_prompt, max_tokens, temperature, top_p
            )
        else:
            # Llama-style formatting
            formatted_prompt = self._format_llama_prompt(prompt, system_prompt)
            input_params = {
                "prompt": formatted_prompt,
                "max_tokens": max_tokens,
                "temperature": temperature,
                "top_p": top_p,
            }

        logger.debug(
            "Generating text with Replicate",
            model=current_model,
            max_tokens=max_tokens,
            temperature=temperature,
            is_claude=is_claude
        )

        # Execute with retry logic
        start_time = time.time()
        output = self._execute_with_retry(current_model, input_params, timeout)
        generation_time = time.time() - start_time

        # Parse response
        text = self._parse_response(output)

        # Estimate tokens (rough approximation: ~4 chars per token).
        # Input tokens are estimated from the actual prompt sent.
        prompt_text = input_params.get("prompt", "")
        system_text = input_params.get("system_prompt", "")
        total_input_text = prompt_text + system_text
        tokens_input = len(total_input_text) // 4

        # Estimate output tokens from the response
        tokens_output = len(text) // 4

        # Total kept for backwards compatibility
        tokens_used = tokens_input + tokens_output

        logger.info(
            "Replicate generation complete",
            model=current_model,
            tokens_input=tokens_input,
            tokens_output=tokens_output,
            tokens_used=tokens_used,
            generation_time=f"{generation_time:.2f}s",
            response_length=len(text)
        )

        return ReplicateResponse(
            text=text.strip(),
            tokens_used=tokens_used,
            tokens_input=tokens_input,
            tokens_output=tokens_output,
            model=current_model,
            generation_time=generation_time
        )

    def _build_claude_params(
        self,
        prompt: str,
        system_prompt: str | None,
        max_tokens: int,
        temperature: float,
        top_p: float
    ) -> dict[str, Any]:
        """
        Build input parameters for Claude models on Replicate.

        Args:
            prompt: User prompt.
            system_prompt: Optional system prompt.
            max_tokens: Maximum tokens to generate.
            temperature: Sampling temperature.
            top_p: Top-p sampling parameter.

        Returns:
            Dictionary of input parameters for Replicate API.
        """
        params = {
            "prompt": prompt,
            "max_tokens": max_tokens,
            "temperature": temperature,
            "top_p": top_p,
        }

        if system_prompt:
            params["system_prompt"] = system_prompt

        return params
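
    # For illustration, _build_claude_params("Hi", "Be terse", 512, 0.8, 0.9)
    # returns:
    #     {"prompt": "Hi", "max_tokens": 512, "temperature": 0.8,
    #      "top_p": 0.9, "system_prompt": "Be terse"}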

    def _format_llama_prompt(self, prompt: str, system_prompt: str | None = None) -> str:
        """
        Format prompt for Llama-3 Instruct model.

        Llama-3 Instruct uses a specific format with special tokens.

        Args:
            prompt: User prompt.
            system_prompt: Optional system prompt.

        Returns:
            Formatted prompt string.
        """
        parts = []

        if system_prompt:
            parts.append(f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{system_prompt}<|eot_id|>")
        else:
            parts.append("<|begin_of_text|>")

        parts.append(f"<|start_header_id|>user<|end_header_id|>\n\n{prompt}<|eot_id|>")
        parts.append("<|start_header_id|>assistant<|end_header_id|>\n\n")

        return "".join(parts)

    def _parse_response(self, output: Any) -> str:
        """
        Parse response from Replicate API.

        Handles both streaming iterators and direct string responses.

        Args:
            output: Raw output from Replicate API.

        Returns:
            Parsed text string.
        """
        if hasattr(output, '__iter__') and not isinstance(output, str):
            return "".join(output)
        return str(output)

    def _execute_with_retry(
        self,
        model: str,
        input_params: dict[str, Any],
        timeout: int
    ) -> Any:
        """
        Execute Replicate API call with retry logic.

        Implements exponential backoff for rate limit errors.

        Args:
            model: Model identifier to run.
            input_params: Input parameters for the model.
            timeout: Timeout in seconds.

        Returns:
            API response output.

        Raises:
            ReplicateAPIError: For API errors after retries.
            ReplicateRateLimitError: When rate limit persists after retries.
            ReplicateTimeoutError: When the request times out.
        """
        last_error = None
        retry_delay = self.INITIAL_RETRY_DELAY

        for attempt in range(self.MAX_RETRIES):
            try:
                # Note: replicate.run blocks until the prediction finishes,
                # so `timeout` is not enforced here; timeouts raised by the
                # library are mapped to ReplicateTimeoutError below.
                output = replicate.run(
                    model,
                    input=input_params
                )
                return output

            except replicate.exceptions.ReplicateError as e:
                error_message = str(e).lower()

                if "rate limit" in error_message or "429" in error_message:
                    last_error = ReplicateRateLimitError(f"Rate limited: {e}")

                    if attempt < self.MAX_RETRIES - 1:
                        logger.warning(
                            "Rate limited, retrying",
                            attempt=attempt + 1,
                            retry_delay=retry_delay
                        )
                        time.sleep(retry_delay)
                        retry_delay *= 2
                        continue
                    else:
                        raise last_error

                elif "timeout" in error_message:
                    raise ReplicateTimeoutError(f"Request timed out: {e}")

                else:
                    raise ReplicateAPIError(f"API error: {e}")

            except Exception as e:
                error_message = str(e).lower()

                if "timeout" in error_message:
                    raise ReplicateTimeoutError(f"Request timed out: {e}")

                raise ReplicateAPIError(f"Unexpected error: {e}")

        if last_error:
            raise last_error
        raise ReplicateAPIError("Max retries exceeded")
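
    # With MAX_RETRIES = 3 and INITIAL_RETRY_DELAY = 1.0, a persistent rate
    # limit plays out as: attempt 1 -> sleep 1s -> attempt 2 -> sleep 2s ->
    # attempt 3 -> raise ReplicateRateLimitError.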

    def validate_api_key(self) -> bool:
        """
        Check whether the configured API key is valid.

        Makes a minimal API call to check credentials.

        Returns:
            True if the API key is valid, False otherwise.
        """
        try:
            # Strip any ":version" suffix before looking up the model
            model_name = self.model.split(':')[0]
            model = replicate.models.get(model_name)
            return model is not None
        except Exception as e:
            logger.warning(
                "API key validation failed",
                error=str(e)
            )
            return False
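

# Minimal smoke test (a sketch; assumes REPLICATE_API_TOKEN is set and that
# app.config provides get_config()). Run this file directly to check
# connectivity end to end.
if __name__ == "__main__":
    try:
        client = ReplicateClient()
    except ReplicateClientError as exc:
        raise SystemExit(f"Configuration error: {exc}")
    if client.validate_api_key():
        result = client.generate("Say hello in five words.")
        print(f"[{result.model}] {result.text} ({result.generation_time:.2f}s)")
    else:
        print("Replicate API key appears to be invalid.")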