first commit

api/app/ai/model_selector.py (new file, 226 lines)

@@ -0,0 +1,226 @@
"""
Model selector for tier-based AI model routing.

This module provides intelligent model selection based on user subscription tiers
and context types to optimize cost and quality.
"""

from dataclasses import dataclass
from enum import Enum

import structlog

from app.ai.replicate_client import ModelType

logger = structlog.get_logger(__name__)


class UserTier(str, Enum):
    """User subscription tiers."""
    FREE = "free"
    BASIC = "basic"
    PREMIUM = "premium"
    ELITE = "elite"


class ContextType(str, Enum):
    """Types of AI generation contexts."""
    STORY_PROGRESSION = "story_progression"
    COMBAT_NARRATION = "combat_narration"
    QUEST_SELECTION = "quest_selection"
    NPC_DIALOGUE = "npc_dialogue"
    SIMPLE_RESPONSE = "simple_response"


@dataclass
class ModelConfig:
    """Configuration for a selected model."""
    model_type: ModelType
    max_tokens: int
    temperature: float

    @property
    def model(self) -> str:
        """Get the model identifier string."""
        return self.model_type.value
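
# A small illustration of ModelConfig in isolation (the model string is the
# one shown in the select_model doctest below; treat it as an example value):
#   config = ModelConfig(ModelType.CLAUDE_SONNET, max_tokens=1024, temperature=0.9)
#   config.model  -> "anthropic/claude-3.5-sonnet"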


class ModelSelector:
    """
    Selects appropriate AI models based on user tier and context.

    This class implements tier-based routing to ensure:
    - Free users get Llama-3 (no cost)
    - Basic users get Claude Haiku (low cost)
    - Premium users get Claude Sonnet (medium cost)
    - Elite users get Claude Sonnet 4 (high cost)

    Context-specific optimizations adjust token limits and temperature
    for different use cases.
    """

    # Tier to model mapping
    TIER_MODELS = {
        UserTier.FREE: ModelType.LLAMA_3_8B,
        UserTier.BASIC: ModelType.CLAUDE_HAIKU,
        UserTier.PREMIUM: ModelType.CLAUDE_SONNET,
        UserTier.ELITE: ModelType.CLAUDE_SONNET_4,
    }

    # Base token limits by tier
    BASE_TOKEN_LIMITS = {
        UserTier.FREE: 256,
        UserTier.BASIC: 512,
        UserTier.PREMIUM: 1024,
        UserTier.ELITE: 2048,
    }

    # Temperature settings by context type
    CONTEXT_TEMPERATURES = {
        ContextType.STORY_PROGRESSION: 0.9,  # Creative, varied
        ContextType.COMBAT_NARRATION: 0.8,   # Exciting but coherent
        ContextType.QUEST_SELECTION: 0.5,    # More deterministic
        ContextType.NPC_DIALOGUE: 0.85,      # Natural conversation
        ContextType.SIMPLE_RESPONSE: 0.7,    # Balanced
    }

    # Token multipliers by context (relative to base)
    CONTEXT_TOKEN_MULTIPLIERS = {
        ContextType.STORY_PROGRESSION: 1.0,  # Full allocation
        ContextType.COMBAT_NARRATION: 0.75,  # Shorter, punchier
        ContextType.QUEST_SELECTION: 0.5,    # Brief selection
        ContextType.NPC_DIALOGUE: 0.75,      # Conversational
        ContextType.SIMPLE_RESPONSE: 0.5,    # Quick responses
    }
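
    # Worked effective budgets, derived purely from the tables above:
    #   FREE  + QUEST_SELECTION   -> int(256 * 0.5)  = 128 tokens,  temperature 0.5
    #   ELITE + STORY_PROGRESSION -> int(2048 * 1.0) = 2048 tokens, temperature 0.9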

    def __init__(self):
        """Initialize the model selector."""
        logger.info("ModelSelector initialized")

    def select_model(
        self,
        user_tier: UserTier,
        context_type: ContextType = ContextType.SIMPLE_RESPONSE
    ) -> ModelConfig:
        """
        Select the appropriate model configuration for a user and context.

        Args:
            user_tier: The user's subscription tier.
            context_type: The type of content being generated.

        Returns:
            ModelConfig with model type, token limit, and temperature.

        Example:
            >>> selector = ModelSelector()
            >>> config = selector.select_model(UserTier.PREMIUM, ContextType.STORY_PROGRESSION)
            >>> config.model_type
            <ModelType.CLAUDE_SONNET: 'anthropic/claude-3.5-sonnet'>
        """
        # Get model for tier
        model_type = self.TIER_MODELS[user_tier]

        # Calculate max tokens
        base_tokens = self.BASE_TOKEN_LIMITS[user_tier]
        multiplier = self.CONTEXT_TOKEN_MULTIPLIERS.get(context_type, 1.0)
        max_tokens = int(base_tokens * multiplier)

        # Get temperature for context
        temperature = self.CONTEXT_TEMPERATURES.get(context_type, 0.7)

        config = ModelConfig(
            model_type=model_type,
            max_tokens=max_tokens,
            temperature=temperature
        )

        logger.debug(
            "Model selected",
            user_tier=user_tier.value,
            context_type=context_type.value,
            model=model_type.value,
            max_tokens=max_tokens,
            temperature=temperature
        )

        return config
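
    # Another worked selection, again using only the class tables:
    #   select_model(UserTier.BASIC, ContextType.NPC_DIALOGUE) yields
    #   max_tokens = int(512 * 0.75) = 384 and temperature = 0.85.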

    def get_model_for_tier(self, user_tier: UserTier) -> ModelType:
        """
        Get the default model for a user tier.

        Args:
            user_tier: The user's subscription tier.

        Returns:
            The ModelType for this tier.
        """
        return self.TIER_MODELS[user_tier]

    def get_tier_info(self, user_tier: UserTier) -> dict:
        """
        Get information about a tier's AI capabilities.

        Args:
            user_tier: The user's subscription tier.

        Returns:
            Dictionary with tier information.
        """
        model_type = self.TIER_MODELS[user_tier]

        # Map models to friendly names
        model_names = {
            ModelType.LLAMA_3_8B: "Llama 3 8B",
            ModelType.CLAUDE_HAIKU: "Claude 3 Haiku",
            ModelType.CLAUDE_SONNET: "Claude 3.5 Sonnet",
            ModelType.CLAUDE_SONNET_4: "Claude Sonnet 4",
        }

        # Model quality descriptions
        quality_descriptions = {
            ModelType.LLAMA_3_8B: "Good quality, optimized for speed",
            ModelType.CLAUDE_HAIKU: "High quality, fast responses",
            ModelType.CLAUDE_SONNET: "Excellent quality, detailed narratives",
            ModelType.CLAUDE_SONNET_4: "Best quality, most creative and nuanced",
        }

        return {
            "tier": user_tier.value,
            "model": model_type.value,
            "model_name": model_names.get(model_type, model_type.value),
            "base_tokens": self.BASE_TOKEN_LIMITS[user_tier],
            "quality": quality_descriptions.get(model_type, "Standard quality"),
        }
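
    # Sample get_tier_info(UserTier.PREMIUM) result (the model string comes
    # from the select_model doctest; the rest follows from the tables above):
    #   {
    #       "tier": "premium",
    #       "model": "anthropic/claude-3.5-sonnet",
    #       "model_name": "Claude 3.5 Sonnet",
    #       "base_tokens": 1024,
    #       "quality": "Excellent quality, detailed narratives",
    #   }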

    def estimate_cost_per_request(self, user_tier: UserTier) -> float:
        """
        Estimate the cost per AI request for a tier.

        Args:
            user_tier: The user's subscription tier.

        Returns:
            Estimated cost in USD per request.

        Note:
            These are rough estimates based on typical usage.
            Actual costs depend on input/output tokens.
        """
        # Approximate cost per 1K tokens (input + output average)
        COST_PER_1K_TOKENS = {
            ModelType.LLAMA_3_8B: 0.0,         # Free tier
            ModelType.CLAUDE_HAIKU: 0.001,     # $0.25/1M input, $1.25/1M output
            ModelType.CLAUDE_SONNET: 0.006,    # $3/1M input, $15/1M output
            ModelType.CLAUDE_SONNET_4: 0.015,  # Claude Sonnet 4 pricing
        }

        model_type = self.TIER_MODELS[user_tier]
        base_tokens = self.BASE_TOKEN_LIMITS[user_tier]
        cost_per_1k = COST_PER_1K_TOKENS.get(model_type, 0.0)

        # Estimate: base tokens for output + ~50% for input tokens
        estimated_tokens = base_tokens * 1.5

        return (estimated_tokens / 1000) * cost_per_1k
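
    # Worked cost estimate for UserTier.PREMIUM, using only the numbers above:
    #   estimated_tokens = 1024 * 1.5 = 1536
    #   cost = (1536 / 1000) * 0.006 = 0.009216, i.e. roughly $0.0092 per request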