first commit
api/app/ai/replicate_client.py | 450 lines | new file
@@ -0,0 +1,450 @@
"""
Replicate API client for AI model integration.

This module provides a client for interacting with the Replicate API
to generate text using various models, including Llama-3 and Claude models.
All AI generation goes through Replicate for unified billing and management.
"""

import os
import time
from dataclasses import dataclass
from enum import Enum
from typing import Any

import replicate
import structlog

from app.config import get_config

logger = structlog.get_logger(__name__)


class ModelType(str, Enum):
    """Supported model types on Replicate."""

    # Free tier - Llama models
    LLAMA_3_8B = "meta/meta-llama-3-8b-instruct"

    # Paid tiers - Claude models via Replicate
    CLAUDE_HAIKU = "anthropic/claude-3.5-haiku"
    CLAUDE_SONNET = "anthropic/claude-3.5-sonnet"
    CLAUDE_SONNET_4 = "anthropic/claude-4.5-sonnet"


@dataclass
class ReplicateResponse:
    """Response from a Replicate API generation."""

    text: str
    tokens_used: int  # Deprecated: equals tokens_input + tokens_output; prefer those fields
    tokens_input: int
    tokens_output: int
    model: str
    generation_time: float


class ReplicateClientError(Exception):
    """Base exception for Replicate client errors."""


class ReplicateAPIError(ReplicateClientError):
    """Error returned by the Replicate API."""


class ReplicateRateLimitError(ReplicateClientError):
    """Rate limit exceeded on the Replicate API."""


class ReplicateTimeoutError(ReplicateClientError):
    """Timeout waiting for a Replicate response."""


class ReplicateClient:
    """
    Client for interacting with the Replicate API.

    Supports multiple models, including Llama-3 and Claude models, and
    implements retry logic with exponential backoff for rate limits.
    """

    # Default model for the free tier
    DEFAULT_MODEL = ModelType.LLAMA_3_8B

    # Retry configuration
    MAX_RETRIES = 3
    INITIAL_RETRY_DELAY = 1.0  # seconds

    # Default generation parameters
    DEFAULT_MAX_TOKENS = 256
    DEFAULT_TEMPERATURE = 0.7
    DEFAULT_TOP_P = 0.9
    DEFAULT_TIMEOUT = 30  # seconds

    # Model-specific defaults (take precedence over the class defaults above)
    MODEL_DEFAULTS = {
        ModelType.LLAMA_3_8B: {"max_tokens": 256, "temperature": 0.7},
        ModelType.CLAUDE_HAIKU: {"max_tokens": 512, "temperature": 0.8},
        ModelType.CLAUDE_SONNET: {"max_tokens": 1024, "temperature": 0.9},
        ModelType.CLAUDE_SONNET_4: {"max_tokens": 2048, "temperature": 0.9},
    }

    def __init__(self, api_token: str | None = None, model: str | ModelType | None = None):
        """
        Initialize the Replicate client.

        Args:
            api_token: Replicate API token. If not provided, reads from config.
            model: Model identifier or ModelType enum. Defaults to Llama-3 8B Instruct.

        Raises:
            ReplicateClientError: If no API token is configured.
        """
        config = get_config()

        # Get the API token from the parameter or config
        self.api_token = api_token or getattr(config, 'replicate_api_token', None)
        if not self.api_token:
            raise ReplicateClientError(
                "Replicate API token not configured. "
                "Set REPLICATE_API_TOKEN in the environment or config."
            )

        # Get the model from the parameter, config, or default
        if model is None:
            model = getattr(config, 'REPLICATE_MODEL', None) or self.DEFAULT_MODEL

        # Resolve known models to a ModelType; keep custom models as plain strings
        if isinstance(model, ModelType):
            self.model = model.value
            self.model_type = model
        elif isinstance(model, str):
            self.model = model
            self.model_type = self._get_model_type(model)
        else:
            self.model = self.DEFAULT_MODEL.value
            self.model_type = self.DEFAULT_MODEL

        # Make the token visible to the replicate library
        os.environ['REPLICATE_API_TOKEN'] = self.api_token

        logger.info(
            "Replicate client initialized",
            model=self.model,
            model_type=self.model_type.name if self.model_type else "custom",
        )

    def _get_model_type(self, model_string: str) -> ModelType | None:
        """Return the ModelType matching a model string, or None for custom models."""
        for model_type in ModelType:
            if model_type.value == model_string:
                return model_type
        return None

    def _is_claude_model(self) -> bool:
        """Check whether the current model is a Claude model."""
        return self.model_type in [
            ModelType.CLAUDE_HAIKU,
            ModelType.CLAUDE_SONNET,
            ModelType.CLAUDE_SONNET_4,
        ]

    def generate(
        self,
        prompt: str,
        system_prompt: str | None = None,
        max_tokens: int | None = None,
        temperature: float | None = None,
        top_p: float | None = None,
        timeout: int | None = None,
        model: str | ModelType | None = None,
    ) -> ReplicateResponse:
        """
        Generate text using the configured model.

        Args:
            prompt: The user prompt to send to the model.
            system_prompt: Optional system prompt for context setting.
            max_tokens: Maximum tokens to generate. Uses model defaults if not specified.
            temperature: Sampling temperature (0.0-1.0). Uses model defaults if not specified.
            top_p: Top-p sampling parameter. Defaults to 0.9.
            timeout: Timeout in seconds. Defaults to 30.
            model: Override the configured model for this request.

        Returns:
            ReplicateResponse with generated text and metadata.

        Raises:
            ReplicateAPIError: For API errors.
            ReplicateRateLimitError: When rate limited.
            ReplicateTimeoutError: When the request times out.
        """
        # Handle per-request model override
        if model:
            if isinstance(model, ModelType):
                current_model = model.value
                current_model_type = model
            else:
                current_model = model
                current_model_type = self._get_model_type(model)
        else:
            current_model = self.model
            current_model_type = self.model_type

        # Get model-specific defaults
        model_defaults = self.MODEL_DEFAULTS.get(current_model_type, {})

        # Apply defaults (parameter > model default > class default).
        # Explicit None checks ensure legitimate zero values such as
        # temperature=0.0 are not silently replaced by the defaults.
        if max_tokens is None:
            max_tokens = model_defaults.get("max_tokens", self.DEFAULT_MAX_TOKENS)
        if temperature is None:
            temperature = model_defaults.get("temperature", self.DEFAULT_TEMPERATURE)
        if top_p is None:
            top_p = self.DEFAULT_TOP_P
        if timeout is None:
            timeout = self.DEFAULT_TIMEOUT

        # Format the prompt based on model type
        is_claude = current_model_type in [
            ModelType.CLAUDE_HAIKU,
            ModelType.CLAUDE_SONNET,
            ModelType.CLAUDE_SONNET_4,
        ]

        if is_claude:
            input_params = self._build_claude_params(
                prompt, system_prompt, max_tokens, temperature, top_p
            )
        else:
            # Llama-style formatting
            formatted_prompt = self._format_llama_prompt(prompt, system_prompt)
            input_params = {
                "prompt": formatted_prompt,
                "max_tokens": max_tokens,
                "temperature": temperature,
                "top_p": top_p,
            }

        logger.debug(
            "Generating text with Replicate",
            model=current_model,
            max_tokens=max_tokens,
            temperature=temperature,
            is_claude=is_claude,
        )

        # Execute with retry logic
        start_time = time.time()
        output = self._execute_with_retry(current_model, input_params, timeout)
        generation_time = time.time() - start_time

        # Parse the response
        text = self._parse_response(output)

        # Estimate tokens (rough approximation: ~4 characters per token).
        # Input tokens are estimated from the actual prompt text sent.
        prompt_text = input_params.get("prompt", "")
        system_text = input_params.get("system_prompt", "")
        total_input_text = prompt_text + system_text
        tokens_input = len(total_input_text) // 4

        # Estimate output tokens from the response text
        tokens_output = len(text) // 4

        # Total is kept for backwards compatibility
        tokens_used = tokens_input + tokens_output
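
        # Worked example of the estimate: a 400-character formatted prompt
        # counts as ~100 input tokens and a 200-character completion as ~50
        # output tokens, giving tokens_used ~150. These are character-based
        # estimates, not billed token counts reported by Replicate.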

        logger.info(
            "Replicate generation complete",
            model=current_model,
            tokens_input=tokens_input,
            tokens_output=tokens_output,
            tokens_used=tokens_used,
            generation_time=f"{generation_time:.2f}s",
            response_length=len(text),
        )

        return ReplicateResponse(
            text=text.strip(),
            tokens_used=tokens_used,
            tokens_input=tokens_input,
            tokens_output=tokens_output,
            model=current_model,
            generation_time=generation_time,
        )

    def _build_claude_params(
        self,
        prompt: str,
        system_prompt: str | None,
        max_tokens: int,
        temperature: float,
        top_p: float
    ) -> dict[str, Any]:
        """
        Build input parameters for Claude models on Replicate.

        Args:
            prompt: User prompt.
            system_prompt: Optional system prompt.
            max_tokens: Maximum tokens to generate.
            temperature: Sampling temperature.
            top_p: Top-p sampling parameter.

        Returns:
            Dictionary of input parameters for the Replicate API.
        """
        params = {
            "prompt": prompt,
            "max_tokens": max_tokens,
            "temperature": temperature,
            "top_p": top_p,
        }

        if system_prompt:
            params["system_prompt"] = system_prompt

        return params
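
    # Example (illustrative): _build_claude_params("Hi", "Be brief", 512, 0.8, 0.9)
    # returns:
    #   {"prompt": "Hi", "max_tokens": 512, "temperature": 0.8,
    #    "top_p": 0.9, "system_prompt": "Be brief"}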

    def _format_llama_prompt(self, prompt: str, system_prompt: str | None = None) -> str:
        """
        Format a prompt for the Llama-3 Instruct model.

        Llama-3 Instruct uses a specific chat format with special tokens.

        Args:
            prompt: User prompt.
            system_prompt: Optional system prompt.

        Returns:
            Formatted prompt string.
        """
        parts = []

        if system_prompt:
            parts.append(f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{system_prompt}<|eot_id|>")
        else:
            parts.append("<|begin_of_text|>")

        parts.append(f"<|start_header_id|>user<|end_header_id|>\n\n{prompt}<|eot_id|>")
        parts.append("<|start_header_id|>assistant<|end_header_id|>\n\n")

        return "".join(parts)

    def _parse_response(self, output: Any) -> str:
        """
        Parse a response from the Replicate API.

        Handles both streaming iterators and direct string responses.

        Args:
            output: Raw output from the Replicate API.

        Returns:
            Parsed text string.
        """
        if hasattr(output, '__iter__') and not isinstance(output, str):
            return "".join(output)
        return str(output)
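
    # Example (illustrative): replicate.run() may stream a language model's
    # output as an iterator of chunks such as ["Hel", "lo"], which
    # _parse_response joins into "Hello"; a plain string or other scalar
    # output is passed through str() unchanged.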

    def _execute_with_retry(
        self,
        model: str,
        input_params: dict[str, Any],
        timeout: int
    ) -> Any:
        """
        Execute a Replicate API call with retry logic.

        Implements exponential backoff for rate limit errors. Note that
        `timeout` is accepted but not currently applied to the
        replicate.run() call.

        Args:
            model: Model identifier to run.
            input_params: Input parameters for the model.
            timeout: Timeout in seconds (currently not enforced; see above).

        Returns:
            API response output.

        Raises:
            ReplicateAPIError: For API errors after retries.
            ReplicateRateLimitError: When the rate limit persists after retries.
            ReplicateTimeoutError: When the request times out.
        """
        last_error = None
        retry_delay = self.INITIAL_RETRY_DELAY

        for attempt in range(self.MAX_RETRIES):
            try:
                return replicate.run(model, input=input_params)

            except replicate.exceptions.ReplicateError as e:
                error_message = str(e).lower()

                if "rate limit" in error_message or "429" in error_message:
                    last_error = ReplicateRateLimitError(f"Rate limited: {e}")

                    if attempt < self.MAX_RETRIES - 1:
                        logger.warning(
                            "Rate limited, retrying",
                            attempt=attempt + 1,
                            retry_delay=retry_delay,
                        )
                        time.sleep(retry_delay)
                        retry_delay *= 2
                        continue
                    raise last_error

                if "timeout" in error_message:
                    raise ReplicateTimeoutError(f"Request timed out: {e}")

                raise ReplicateAPIError(f"API error: {e}")

            except Exception as e:
                error_message = str(e).lower()

                if "timeout" in error_message:
                    raise ReplicateTimeoutError(f"Request timed out: {e}")

                raise ReplicateAPIError(f"Unexpected error: {e}")

        # Defensive fallback; every loop iteration returns or raises first
        if last_error:
            raise last_error
        raise ReplicateAPIError("Max retries exceeded")
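
    # With MAX_RETRIES = 3 and INITIAL_RETRY_DELAY = 1.0, a persistently
    # rate-limited call sleeps 1s after the first attempt and 2s after the
    # second, then the third attempt raises ReplicateRateLimitError.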

    def validate_api_key(self) -> bool:
        """
        Validate that the configured API key works.

        Makes a minimal API call to check the credentials.

        Returns:
            True if the API key is valid, False otherwise.
        """
        try:
            # Strip any ":version" suffix before looking up the model
            model_name = self.model.split(':')[0]
            model = replicate.models.get(model_name)
            return model is not None
        except Exception as e:
            logger.warning(
                "API key validation failed",
                error=str(e),
            )
            return False
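

# Minimal usage sketch (illustrative; assumes a valid REPLICATE_API_TOKEN is
# available via config or the environment):
#
#     client = ReplicateClient()  # defaults to Llama-3 8B Instruct
#     response = client.generate(
#         "Summarize the plot of Hamlet in one sentence.",
#         system_prompt="You are a concise literary assistant.",
#     )
#     print(response.text, response.tokens_output, response.generation_time)
#
# A paid-tier caller can select a Claude model either at construction time,
# ReplicateClient(model=ModelType.CLAUDE_HAIKU), or per request via
# generate(..., model=ModelType.CLAUDE_HAIKU).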