"""
Model selector for tier-based AI model routing.

This module provides intelligent model selection based on user subscription tiers
and context types to optimize cost and quality.
"""
|
|
|
|
from dataclasses import dataclass
|
|
from enum import Enum, auto
|
|
|
|
import structlog
|
|
|
|
from app.ai.replicate_client import ModelType
|
|
|
|
logger = structlog.get_logger(__name__)
|
|
|
|
|
|
class UserTier(str, Enum):
    """User subscription tiers.

    Inherits from ``str`` so members compare equal to their plain string
    values and serialize directly (``.value`` is used in logging and
    tier-info payloads below).
    """

    FREE = "free"        # routed to Llama 3 8B (zero cost)
    BASIC = "basic"      # routed to Claude Haiku
    PREMIUM = "premium"  # routed to Claude 3.5 Sonnet
    ELITE = "elite"      # routed to Claude Sonnet 4
|
|
|
|
|
|
class ContextType(str, Enum):
    """Types of AI generation contexts.

    Each context selects a temperature and a token-budget multiplier in
    ``ModelSelector`` (creative contexts run hotter and longer; selection
    and simple replies run cooler and shorter).
    """

    STORY_PROGRESSION = "story_progression"  # full-length creative narration
    COMBAT_NARRATION = "combat_narration"    # shorter, punchier output
    QUEST_SELECTION = "quest_selection"      # brief, more deterministic
    NPC_DIALOGUE = "npc_dialogue"            # conversational tone
    SIMPLE_RESPONSE = "simple_response"      # default/balanced context
|
|
|
|
|
|
@dataclass
class ModelConfig:
    """Configuration for a selected model.

    Produced by ``ModelSelector.select_model``; bundles the chosen model
    with the generation parameters to use for one request.
    """

    model_type: ModelType  # which backend model to call
    max_tokens: int        # output-token budget (tier base x context multiplier)
    temperature: float     # sampling temperature chosen for the context

    @property
    def model(self) -> str:
        """Get the model identifier string (the ModelType enum's value)."""
        return self.model_type.value
|
|
|
|
|
|
class ModelSelector:
    """
    Selects appropriate AI models based on user tier and context.

    This class implements tier-based routing to ensure:
    - Free users get Llama 3 8B (no cost)
    - Basic users get Claude 3 Haiku (low cost)
    - Premium users get Claude 3.5 Sonnet (medium cost)
    - Elite users get Claude Sonnet 4 (highest cost)

    Context-specific optimizations adjust token limits and temperature
    for different use cases.
    """

    # Tier to model mapping
    TIER_MODELS = {
        UserTier.FREE: ModelType.LLAMA_3_8B,
        UserTier.BASIC: ModelType.CLAUDE_HAIKU,
        UserTier.PREMIUM: ModelType.CLAUDE_SONNET,
        UserTier.ELITE: ModelType.CLAUDE_SONNET_4,
    }

    # Base token limits by tier (scaled per context below)
    BASE_TOKEN_LIMITS = {
        UserTier.FREE: 256,
        UserTier.BASIC: 512,
        UserTier.PREMIUM: 1024,
        UserTier.ELITE: 2048,
    }

    # Temperature settings by context type
    CONTEXT_TEMPERATURES = {
        ContextType.STORY_PROGRESSION: 0.9,   # Creative, varied
        ContextType.COMBAT_NARRATION: 0.8,    # Exciting but coherent
        ContextType.QUEST_SELECTION: 0.5,     # More deterministic
        ContextType.NPC_DIALOGUE: 0.85,       # Natural conversation
        ContextType.SIMPLE_RESPONSE: 0.7,     # Balanced
    }

    # Token multipliers by context (relative to base)
    CONTEXT_TOKEN_MULTIPLIERS = {
        ContextType.STORY_PROGRESSION: 1.0,   # Full allocation
        ContextType.COMBAT_NARRATION: 0.75,   # Shorter, punchier
        ContextType.QUEST_SELECTION: 0.5,     # Brief selection
        ContextType.NPC_DIALOGUE: 0.75,       # Conversational
        ContextType.SIMPLE_RESPONSE: 0.5,     # Quick responses
    }

    # Friendly display names per model (hoisted from get_tier_info so the
    # dict is built once at class creation, not on every call).
    MODEL_DISPLAY_NAMES = {
        ModelType.LLAMA_3_8B: "Llama 3 8B",
        ModelType.CLAUDE_HAIKU: "Claude 3 Haiku",
        ModelType.CLAUDE_SONNET: "Claude 3.5 Sonnet",
        ModelType.CLAUDE_SONNET_4: "Claude Sonnet 4",
    }

    # Model quality descriptions (hoisted from get_tier_info, same reason).
    MODEL_QUALITY_DESCRIPTIONS = {
        ModelType.LLAMA_3_8B: "Good quality, optimized for speed",
        ModelType.CLAUDE_HAIKU: "High quality, fast responses",
        ModelType.CLAUDE_SONNET: "Excellent quality, detailed narratives",
        ModelType.CLAUDE_SONNET_4: "Best quality, most creative and nuanced",
    }

    # Approximate cost per 1K tokens (input + output average), in USD.
    # Hoisted from estimate_cost_per_request so it is not rebuilt per call.
    COST_PER_1K_TOKENS = {
        ModelType.LLAMA_3_8B: 0.0,         # Free tier
        ModelType.CLAUDE_HAIKU: 0.001,     # $0.25/1M input, $1.25/1M output
        ModelType.CLAUDE_SONNET: 0.006,    # $3/1M input, $15/1M output
        ModelType.CLAUDE_SONNET_4: 0.015,  # Claude Sonnet 4 pricing
    }

    def __init__(self):
        """Initialize the model selector."""
        logger.info("ModelSelector initialized")

    def select_model(
        self,
        user_tier: UserTier,
        context_type: ContextType = ContextType.SIMPLE_RESPONSE
    ) -> ModelConfig:
        """
        Select the appropriate model configuration for a user and context.

        Args:
            user_tier: The user's subscription tier.
            context_type: The type of content being generated.

        Returns:
            ModelConfig with model type, token limit, and temperature.

        Example:
            >>> selector = ModelSelector()
            >>> config = selector.select_model(UserTier.PREMIUM, ContextType.STORY_PROGRESSION)
            >>> config.model_type
            <ModelType.CLAUDE_SONNET: 'anthropic/claude-3.5-sonnet'>
        """
        # The model is determined solely by tier.
        model_type = self.TIER_MODELS[user_tier]

        # Scale the tier's base token budget by the context multiplier;
        # unknown contexts fall back to the full allocation.
        base_tokens = self.BASE_TOKEN_LIMITS[user_tier]
        multiplier = self.CONTEXT_TOKEN_MULTIPLIERS.get(context_type, 1.0)
        max_tokens = int(base_tokens * multiplier)

        # Unknown contexts get a balanced default temperature.
        temperature = self.CONTEXT_TEMPERATURES.get(context_type, 0.7)

        config = ModelConfig(
            model_type=model_type,
            max_tokens=max_tokens,
            temperature=temperature
        )

        logger.debug(
            "Model selected",
            user_tier=user_tier.value,
            context_type=context_type.value,
            model=model_type.value,
            max_tokens=max_tokens,
            temperature=temperature
        )

        return config

    def get_model_for_tier(self, user_tier: UserTier) -> ModelType:
        """
        Get the default model for a user tier.

        Args:
            user_tier: The user's subscription tier.

        Returns:
            The ModelType for this tier.
        """
        return self.TIER_MODELS[user_tier]

    def get_tier_info(self, user_tier: UserTier) -> dict:
        """
        Get information about a tier's AI capabilities.

        Args:
            user_tier: The user's subscription tier.

        Returns:
            Dictionary with tier information: "tier", "model", "model_name",
            "base_tokens", and "quality".
        """
        model_type = self.TIER_MODELS[user_tier]

        return {
            "tier": user_tier.value,
            "model": model_type.value,
            "model_name": self.MODEL_DISPLAY_NAMES.get(model_type, model_type.value),
            "base_tokens": self.BASE_TOKEN_LIMITS[user_tier],
            "quality": self.MODEL_QUALITY_DESCRIPTIONS.get(model_type, "Standard quality"),
        }

    def estimate_cost_per_request(self, user_tier: UserTier) -> float:
        """
        Estimate the cost per AI request for a tier.

        Args:
            user_tier: The user's subscription tier.

        Returns:
            Estimated cost in USD per request.

        Note:
            These are rough estimates based on typical usage.
            Actual costs depend on input/output tokens.
        """
        model_type = self.TIER_MODELS[user_tier]
        base_tokens = self.BASE_TOKEN_LIMITS[user_tier]
        cost_per_1k = self.COST_PER_1K_TOKENS.get(model_type, 0.0)

        # Estimate: base tokens for output + ~50% for input tokens
        estimated_tokens = base_tokens * 1.5

        return (estimated_tokens / 1000) * cost_per_1k
|