""" Model selector for tier-based AI model routing. This module provides intelligent model selection based on user subscription tiers and context types to optimize cost and quality. """ from dataclasses import dataclass from enum import Enum, auto import structlog from app.ai.replicate_client import ModelType logger = structlog.get_logger(__name__) class UserTier(str, Enum): """User subscription tiers.""" FREE = "free" BASIC = "basic" PREMIUM = "premium" ELITE = "elite" class ContextType(str, Enum): """Types of AI generation contexts.""" STORY_PROGRESSION = "story_progression" COMBAT_NARRATION = "combat_narration" QUEST_SELECTION = "quest_selection" NPC_DIALOGUE = "npc_dialogue" SIMPLE_RESPONSE = "simple_response" @dataclass class ModelConfig: """Configuration for a selected model.""" model_type: ModelType max_tokens: int temperature: float @property def model(self) -> str: """Get the model identifier string.""" return self.model_type.value class ModelSelector: """ Selects appropriate AI models based on user tier and context. This class implements tier-based routing to ensure: - Free users get Llama-3 (no cost) - Basic users get Claude Haiku (low cost) - Premium users get Claude Sonnet (medium cost) - Elite users get Claude Opus (high cost) Context-specific optimizations adjust token limits and temperature for different use cases. """ # Tier to model mapping TIER_MODELS = { UserTier.FREE: ModelType.LLAMA_3_8B, UserTier.BASIC: ModelType.CLAUDE_HAIKU, UserTier.PREMIUM: ModelType.CLAUDE_SONNET, UserTier.ELITE: ModelType.CLAUDE_SONNET_4, } # Base token limits by tier BASE_TOKEN_LIMITS = { UserTier.FREE: 256, UserTier.BASIC: 512, UserTier.PREMIUM: 1024, UserTier.ELITE: 2048, } # Temperature settings by context type CONTEXT_TEMPERATURES = { ContextType.STORY_PROGRESSION: 0.9, # Creative, varied ContextType.COMBAT_NARRATION: 0.8, # Exciting but coherent ContextType.QUEST_SELECTION: 0.5, # More deterministic ContextType.NPC_DIALOGUE: 0.85, # Natural conversation ContextType.SIMPLE_RESPONSE: 0.7, # Balanced } # Token multipliers by context (relative to base) CONTEXT_TOKEN_MULTIPLIERS = { ContextType.STORY_PROGRESSION: 1.0, # Full allocation ContextType.COMBAT_NARRATION: 0.75, # Shorter, punchier ContextType.QUEST_SELECTION: 0.5, # Brief selection ContextType.NPC_DIALOGUE: 0.75, # Conversational ContextType.SIMPLE_RESPONSE: 0.5, # Quick responses } def __init__(self): """Initialize the model selector.""" logger.info("ModelSelector initialized") def select_model( self, user_tier: UserTier, context_type: ContextType = ContextType.SIMPLE_RESPONSE ) -> ModelConfig: """ Select the appropriate model configuration for a user and context. Args: user_tier: The user's subscription tier. context_type: The type of content being generated. Returns: ModelConfig with model type, token limit, and temperature. 
        Example:
            >>> selector = ModelSelector()
            >>> config = selector.select_model(
            ...     UserTier.PREMIUM, ContextType.STORY_PROGRESSION
            ... )
            >>> config.model_type is ModelType.CLAUDE_SONNET
            True
            >>> config.max_tokens
            1024
        """
        # Get model for tier
        model_type = self.TIER_MODELS[user_tier]

        # Calculate max tokens
        base_tokens = self.BASE_TOKEN_LIMITS[user_tier]
        multiplier = self.CONTEXT_TOKEN_MULTIPLIERS.get(context_type, 1.0)
        max_tokens = int(base_tokens * multiplier)

        # Get temperature for context
        temperature = self.CONTEXT_TEMPERATURES.get(context_type, 0.7)

        config = ModelConfig(
            model_type=model_type,
            max_tokens=max_tokens,
            temperature=temperature,
        )

        logger.debug(
            "Model selected",
            user_tier=user_tier.value,
            context_type=context_type.value,
            model=model_type.value,
            max_tokens=max_tokens,
            temperature=temperature,
        )

        return config

    def get_model_for_tier(self, user_tier: UserTier) -> ModelType:
        """
        Get the default model for a user tier.

        Args:
            user_tier: The user's subscription tier.

        Returns:
            The ModelType for this tier.
        """
        return self.TIER_MODELS[user_tier]

    def get_tier_info(self, user_tier: UserTier) -> dict:
        """
        Get information about a tier's AI capabilities.

        Args:
            user_tier: The user's subscription tier.

        Returns:
            Dictionary with tier information.
        """
        model_type = self.TIER_MODELS[user_tier]

        # Map models to friendly names
        model_names = {
            ModelType.LLAMA_3_8B: "Llama 3 8B",
            ModelType.CLAUDE_HAIKU: "Claude 3 Haiku",
            ModelType.CLAUDE_SONNET: "Claude 3.5 Sonnet",
            ModelType.CLAUDE_SONNET_4: "Claude Sonnet 4",
        }

        # Model quality descriptions
        quality_descriptions = {
            ModelType.LLAMA_3_8B: "Good quality, optimized for speed",
            ModelType.CLAUDE_HAIKU: "High quality, fast responses",
            ModelType.CLAUDE_SONNET: "Excellent quality, detailed narratives",
            ModelType.CLAUDE_SONNET_4: "Best quality, most creative and nuanced",
        }

        return {
            "tier": user_tier.value,
            "model": model_type.value,
            "model_name": model_names.get(model_type, model_type.value),
            "base_tokens": self.BASE_TOKEN_LIMITS[user_tier],
            "quality": quality_descriptions.get(model_type, "Standard quality"),
        }

    def estimate_cost_per_request(self, user_tier: UserTier) -> float:
        """
        Estimate the cost per AI request for a tier.

        Args:
            user_tier: The user's subscription tier.

        Returns:
            Estimated cost in USD per request.

        Note:
            These are rough estimates based on typical usage.
            Actual costs depend on input/output tokens.
        """
        # Approximate cost per 1K tokens (input + output average)
        COST_PER_1K_TOKENS = {
            ModelType.LLAMA_3_8B: 0.0,         # Free tier
            ModelType.CLAUDE_HAIKU: 0.001,     # $0.25/1M input, $1.25/1M output
            ModelType.CLAUDE_SONNET: 0.006,    # $3/1M input, $15/1M output
            ModelType.CLAUDE_SONNET_4: 0.015,  # Claude Sonnet 4 pricing
        }

        model_type = self.TIER_MODELS[user_tier]
        base_tokens = self.BASE_TOKEN_LIMITS[user_tier]
        cost_per_1k = COST_PER_1K_TOKENS.get(model_type, 0.0)

        # Estimate: base tokens for output + ~50% for input tokens
        estimated_tokens = base_tokens * 1.5

        return (estimated_tokens / 1000) * cost_per_1k
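

# --- Usage sketch -------------------------------------------------------
# A minimal, illustrative demo of the routing logic above; it is not part
# of the production API. It assumes the ModelType enum imported from
# app.ai.replicate_client defines the members referenced in TIER_MODELS,
# and it runs only when the module is executed directly.
if __name__ == "__main__":
    selector = ModelSelector()

    # Walk every tier/context pair and print the resulting configuration
    # alongside the rough per-request cost estimate.
    for tier in UserTier:
        info = selector.get_tier_info(tier)
        cost = selector.estimate_cost_per_request(tier)
        print(f"{tier.value}: {info['model_name']} (~${cost:.4f}/request)")
        for context in ContextType:
            config = selector.select_model(tier, context)
            print(
                f"  {context.value}: max_tokens={config.max_tokens}, "
                f"temperature={config.temperature}"
            )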