first commit

2025-11-24 23:10:55 -06:00
commit 8315fa51c9
279 changed files with 74600 additions and 0 deletions


@@ -0,0 +1,450 @@
"""
Replicate API client for AI model integration.
This module provides a client for interacting with the Replicate API
to generate text with models such as Llama 3 and the Claude family.
All AI generation goes through Replicate for unified billing and management.
"""
import os
import time
from dataclasses import dataclass
from enum import Enum
from typing import Any
import replicate
import structlog
from app.config import get_config
logger = structlog.get_logger(__name__)
class ModelType(str, Enum):
"""Supported model types on Replicate."""
# Free tier - Llama models
LLAMA_3_8B = "meta/meta-llama-3-8b-instruct"
# Paid tiers - Claude models via Replicate
CLAUDE_HAIKU = "anthropic/claude-3.5-haiku"
CLAUDE_SONNET = "anthropic/claude-3.5-sonnet"
    CLAUDE_SONNET_4_5 = "anthropic/claude-4.5-sonnet"
@dataclass
class ReplicateResponse:
"""Response from Replicate API generation."""
text: str
    tokens_used: int  # Deprecated: tokens_input + tokens_output; kept for backwards compatibility
tokens_input: int
tokens_output: int
model: str
generation_time: float
class ReplicateClientError(Exception):
"""Base exception for Replicate client errors."""
pass
class ReplicateAPIError(ReplicateClientError):
"""Error from Replicate API."""
pass
class ReplicateRateLimitError(ReplicateClientError):
"""Rate limit exceeded on Replicate API."""
pass
class ReplicateTimeoutError(ReplicateClientError):
"""Timeout waiting for Replicate response."""
pass
class ReplicateClient:
"""
Client for interacting with Replicate API.
    Supports multiple model families, including Llama 3 and Claude.
Implements retry logic with exponential backoff for rate limits.
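    Example (illustrative; assumes a Replicate API token is configured):
        client = ReplicateClient()
        response = client.generate("Say hello", max_tokens=32)
        print(response.text)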
"""
# Default model for free tier
DEFAULT_MODEL = ModelType.LLAMA_3_8B
# Retry configuration
MAX_RETRIES = 3
INITIAL_RETRY_DELAY = 1.0 # seconds
# Default generation parameters
DEFAULT_MAX_TOKENS = 256
DEFAULT_TEMPERATURE = 0.7
DEFAULT_TOP_P = 0.9
DEFAULT_TIMEOUT = 30 # seconds
# Model-specific defaults
MODEL_DEFAULTS = {
ModelType.LLAMA_3_8B: {
"max_tokens": 256,
"temperature": 0.7,
},
ModelType.CLAUDE_HAIKU: {
"max_tokens": 512,
"temperature": 0.8,
},
ModelType.CLAUDE_SONNET: {
"max_tokens": 1024,
"temperature": 0.9,
},
        ModelType.CLAUDE_SONNET_4_5: {
"max_tokens": 2048,
"temperature": 0.9,
},
}
def __init__(self, api_token: str | None = None, model: str | ModelType | None = None):
"""
Initialize the Replicate client.
Args:
api_token: Replicate API token. If not provided, reads from config.
model: Model identifier or ModelType enum. Defaults to Llama-3 8B Instruct.
Raises:
ReplicateClientError: If API token is not configured.
"""
config = get_config()
# Get API token from parameter or config
self.api_token = api_token or getattr(config, 'replicate_api_token', None)
if not self.api_token:
raise ReplicateClientError(
"Replicate API token not configured. "
"Set REPLICATE_API_TOKEN in environment or config."
)
# Get model from parameter, config, or default
if model is None:
            model = getattr(config, 'replicate_model', None) or self.DEFAULT_MODEL
# Convert string to ModelType if needed, or keep as string for custom models
if isinstance(model, ModelType):
self.model = model.value
self.model_type = model
elif isinstance(model, str):
# Try to match to ModelType
self.model = model
self.model_type = self._get_model_type(model)
else:
self.model = self.DEFAULT_MODEL.value
self.model_type = self.DEFAULT_MODEL
        # Set the API token for the replicate library
        os.environ['REPLICATE_API_TOKEN'] = self.api_token
logger.info(
"Replicate client initialized",
model=self.model,
model_type=self.model_type.name if self.model_type else "custom"
)
def _get_model_type(self, model_string: str) -> ModelType | None:
"""Get ModelType enum from model string."""
for model_type in ModelType:
if model_type.value == model_string:
return model_type
return None
def _is_claude_model(self) -> bool:
"""Check if current model is a Claude model."""
return self.model_type in [
ModelType.CLAUDE_HAIKU,
ModelType.CLAUDE_SONNET,
            ModelType.CLAUDE_SONNET_4_5
]
def generate(
self,
prompt: str,
system_prompt: str | None = None,
max_tokens: int | None = None,
temperature: float | None = None,
top_p: float | None = None,
timeout: int | None = None,
model: str | ModelType | None = None
) -> ReplicateResponse:
"""
Generate text using the configured model.
Args:
prompt: The user prompt to send to the model.
system_prompt: Optional system prompt for context setting.
max_tokens: Maximum tokens to generate. Uses model defaults if not specified.
temperature: Sampling temperature (0.0-1.0). Uses model defaults if not specified.
top_p: Top-p sampling parameter. Defaults to 0.9.
timeout: Timeout in seconds. Defaults to 30.
model: Override the default model for this request.
Returns:
ReplicateResponse with generated text and metadata.
Raises:
ReplicateAPIError: For API errors.
ReplicateRateLimitError: When rate limited.
ReplicateTimeoutError: When request times out.
"""
# Handle model override
if model:
if isinstance(model, ModelType):
current_model = model.value
current_model_type = model
else:
current_model = model
current_model_type = self._get_model_type(model)
else:
current_model = self.model
current_model_type = self.model_type
# Get model-specific defaults
model_defaults = self.MODEL_DEFAULTS.get(current_model_type, {})
# Apply defaults (parameter > model default > class default)
max_tokens = max_tokens or model_defaults.get("max_tokens", self.DEFAULT_MAX_TOKENS)
temperature = temperature or model_defaults.get("temperature", self.DEFAULT_TEMPERATURE)
top_p = top_p or self.DEFAULT_TOP_P
timeout = timeout or self.DEFAULT_TIMEOUT
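        # Illustration with the values above: a CLAUDE_HAIKU call with no
        # overrides resolves to max_tokens=512, temperature=0.8 (model
        # defaults), top_p=0.9 and timeout=30 (class defaults); an explicit
        # max_tokens=64 argument would win over both.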
# Format prompt based on model type
is_claude = current_model_type in [
ModelType.CLAUDE_HAIKU,
ModelType.CLAUDE_SONNET,
            ModelType.CLAUDE_SONNET_4_5
]
if is_claude:
input_params = self._build_claude_params(
prompt, system_prompt, max_tokens, temperature, top_p
)
else:
# Llama-style formatting
formatted_prompt = self._format_llama_prompt(prompt, system_prompt)
input_params = {
"prompt": formatted_prompt,
"max_tokens": max_tokens,
"temperature": temperature,
"top_p": top_p,
}
logger.debug(
"Generating text with Replicate",
model=current_model,
max_tokens=max_tokens,
temperature=temperature,
is_claude=is_claude
)
# Execute with retry logic
start_time = time.time()
output = self._execute_with_retry(current_model, input_params, timeout)
generation_time = time.time() - start_time
# Parse response
text = self._parse_response(output)
# Estimate tokens (rough approximation: ~4 chars per token)
# Calculate input tokens from the actual prompt sent
prompt_text = input_params.get("prompt", "")
system_text = input_params.get("system_prompt", "")
total_input_text = prompt_text + system_text
tokens_input = len(total_input_text) // 4
# Calculate output tokens from response
tokens_output = len(text) // 4
# Total for backwards compatibility
tokens_used = tokens_input + tokens_output
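        # e.g., a 400-character prompt and a 120-character completion are
        # estimated as 100 input and 30 output tokens (130 total)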
logger.info(
"Replicate generation complete",
model=current_model,
tokens_input=tokens_input,
tokens_output=tokens_output,
tokens_used=tokens_used,
generation_time=f"{generation_time:.2f}s",
response_length=len(text)
)
return ReplicateResponse(
text=text.strip(),
tokens_used=tokens_used,
tokens_input=tokens_input,
tokens_output=tokens_output,
model=current_model,
generation_time=generation_time
)
def _build_claude_params(
self,
prompt: str,
system_prompt: str | None,
max_tokens: int,
temperature: float,
top_p: float
) -> dict[str, Any]:
"""
Build input parameters for Claude models on Replicate.
Args:
prompt: User prompt.
system_prompt: Optional system prompt.
max_tokens: Maximum tokens to generate.
temperature: Sampling temperature.
top_p: Top-p sampling parameter.
Returns:
Dictionary of input parameters for Replicate API.
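        Example (illustrative):
            _build_claude_params("Hi", "Be brief", 512, 0.8, 0.9) returns
            {"prompt": "Hi", "max_tokens": 512, "temperature": 0.8,
             "top_p": 0.9, "system_prompt": "Be brief"}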
"""
params = {
"prompt": prompt,
"max_tokens": max_tokens,
"temperature": temperature,
"top_p": top_p,
}
if system_prompt:
params["system_prompt"] = system_prompt
return params
def _format_llama_prompt(self, prompt: str, system_prompt: str | None = None) -> str:
"""
Format prompt for Llama-3 Instruct model.
Llama-3 Instruct uses a specific format with special tokens.
Args:
prompt: User prompt.
system_prompt: Optional system prompt.
Returns:
Formatted prompt string.
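        Example (illustrative):
            _format_llama_prompt("Hi", "Be brief") yields
            "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nBe brief<|eot_id|>"
            "<|start_header_id|>user<|end_header_id|>\n\nHi<|eot_id|>"
            "<|start_header_id|>assistant<|end_header_id|>\n\n"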
"""
parts = []
if system_prompt:
parts.append(f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{system_prompt}<|eot_id|>")
else:
parts.append("<|begin_of_text|>")
parts.append(f"<|start_header_id|>user<|end_header_id|>\n\n{prompt}<|eot_id|>")
parts.append("<|start_header_id|>assistant<|end_header_id|>\n\n")
return "".join(parts)
def _parse_response(self, output: Any) -> str:
"""
Parse response from Replicate API.
Handles both streaming iterators and direct string responses.
Args:
output: Raw output from Replicate API.
Returns:
Parsed text string.
"""
        if hasattr(output, '__iter__') and not isinstance(output, str):
            # Streaming output yields chunks; coerce each to str defensively
            return "".join(str(chunk) for chunk in output)
return str(output)
def _execute_with_retry(
self,
model: str,
input_params: dict[str, Any],
timeout: int
) -> Any:
"""
Execute Replicate API call with retry logic.
Implements exponential backoff for rate limit errors.
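        With the class defaults above (MAX_RETRIES=3, INITIAL_RETRY_DELAY=1.0),
        a persistently rate-limited call sleeps 1s, then 2s, then raises
        ReplicateRateLimitError on the third failure.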
Args:
model: Model identifier to run.
input_params: Input parameters for the model.
timeout: Timeout in seconds.
Returns:
API response output.
Raises:
ReplicateAPIError: For API errors after retries.
ReplicateRateLimitError: When rate limit persists after retries.
ReplicateTimeoutError: When request times out.
"""
last_error = None
retry_delay = self.INITIAL_RETRY_DELAY
for attempt in range(self.MAX_RETRIES):
try:
                # Note: replicate.run() blocks until the prediction completes;
                # the timeout value is not passed through here, so timeouts are
                # only detected via the error handling below.
                output = replicate.run(
                    model,
                    input=input_params
                )
return output
except replicate.exceptions.ReplicateError as e:
error_message = str(e).lower()
if "rate limit" in error_message or "429" in error_message:
last_error = ReplicateRateLimitError(f"Rate limited: {e}")
if attempt < self.MAX_RETRIES - 1:
logger.warning(
"Rate limited, retrying",
attempt=attempt + 1,
retry_delay=retry_delay
)
time.sleep(retry_delay)
retry_delay *= 2
continue
else:
raise last_error
elif "timeout" in error_message:
raise ReplicateTimeoutError(f"Request timed out: {e}")
else:
raise ReplicateAPIError(f"API error: {e}")
except Exception as e:
error_message = str(e).lower()
if "timeout" in error_message:
raise ReplicateTimeoutError(f"Request timed out: {e}")
raise ReplicateAPIError(f"Unexpected error: {e}")
if last_error:
raise last_error
raise ReplicateAPIError("Max retries exceeded")
def validate_api_key(self) -> bool:
"""
        Validate the configured API token.
Makes a minimal API call to check credentials.
Returns:
True if API key is valid, False otherwise.
"""
try:
            # Strip any ":version" suffix; models.get expects "owner/name"
            model_name = self.model.split(':')[0]
model = replicate.models.get(model_name)
return model is not None
except Exception as e:
logger.warning(
"API key validation failed",
error=str(e)
)
return False
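if __name__ == "__main__":
    # Minimal smoke test: a sketch assuming REPLICATE_API_TOKEN is set in the
    # environment and the default Llama-3 model is available.
    client = ReplicateClient()
    if client.validate_api_key():
        result = client.generate("Reply with the word 'ok'.", max_tokens=8)
        print(f"{result.model}: {result.text} ({result.generation_time:.2f}s)")
    else:
        print("Replicate API token is missing or invalid")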