Files
Code_of_Conquest/api/app/ai/model_selector.py
2025-11-24 23:10:55 -06:00

227 lines
7.0 KiB
Python

"""
Model selector for tier-based AI model routing.
This module provides intelligent model selection based on user subscription tiers
and context types to optimize cost and quality.
"""
from dataclasses import dataclass
from enum import Enum, auto
import structlog
from app.ai.replicate_client import ModelType
logger = structlog.get_logger(__name__)
class UserTier(str, Enum):
    """User subscription tiers.

    Mixes in ``str`` so members compare equal to their plain string
    values and serialize directly (e.g. in JSON payloads or DB rows).
    """

    FREE = "free"        # no-cost tier
    BASIC = "basic"      # entry-level paid tier
    PREMIUM = "premium"  # mid-level paid tier
    ELITE = "elite"      # top paid tier
class ContextType(str, Enum):
    """Types of AI generation contexts.

    Mixes in ``str`` so members compare equal to their plain string
    values. Each context maps to a temperature and token multiplier in
    ``ModelSelector``.
    """

    STORY_PROGRESSION = "story_progression"  # long-form narrative beats
    COMBAT_NARRATION = "combat_narration"    # action scene descriptions
    QUEST_SELECTION = "quest_selection"      # short, structured choices
    NPC_DIALOGUE = "npc_dialogue"            # character conversation
    SIMPLE_RESPONSE = "simple_response"      # default / fallback context
@dataclass
class ModelConfig:
    """Configuration for a selected model.

    Produced by ``ModelSelector.select_model`` and consumed by the
    request layer when invoking the model.
    """

    model_type: ModelType  # which backing model to invoke
    max_tokens: int        # output token cap for the request
    temperature: float     # sampling temperature for generation

    @property
    def model(self) -> str:
        """Get the model identifier string (the ``ModelType`` value)."""
        return self.model_type.value
class ModelSelector:
    """
    Selects appropriate AI models based on user tier and context.

    This class implements tier-based routing to ensure:
    - Free users get Llama-3 8B (no cost)
    - Basic users get Claude 3 Haiku (low cost)
    - Premium users get Claude 3.5 Sonnet (medium cost)
    - Elite users get Claude Sonnet 4 (high cost)

    (Note: an earlier revision of this docstring said Elite maps to
    Claude Opus; ``TIER_MODELS`` has always been the source of truth.)

    Context-specific optimizations adjust token limits and temperature
    for different use cases.
    """

    # Tier to model mapping — the single source of truth for routing.
    TIER_MODELS = {
        UserTier.FREE: ModelType.LLAMA_3_8B,
        UserTier.BASIC: ModelType.CLAUDE_HAIKU,
        UserTier.PREMIUM: ModelType.CLAUDE_SONNET,
        UserTier.ELITE: ModelType.CLAUDE_SONNET_4,
    }

    # Base token limits by tier (doubles at each step up).
    BASE_TOKEN_LIMITS = {
        UserTier.FREE: 256,
        UserTier.BASIC: 512,
        UserTier.PREMIUM: 1024,
        UserTier.ELITE: 2048,
    }

    # Temperature settings by context type.
    CONTEXT_TEMPERATURES = {
        ContextType.STORY_PROGRESSION: 0.9,   # Creative, varied
        ContextType.COMBAT_NARRATION: 0.8,    # Exciting but coherent
        ContextType.QUEST_SELECTION: 0.5,     # More deterministic
        ContextType.NPC_DIALOGUE: 0.85,       # Natural conversation
        ContextType.SIMPLE_RESPONSE: 0.7,     # Balanced
    }

    # Token multipliers by context (relative to the tier's base limit).
    CONTEXT_TOKEN_MULTIPLIERS = {
        ContextType.STORY_PROGRESSION: 1.0,   # Full allocation
        ContextType.COMBAT_NARRATION: 0.75,   # Shorter, punchier
        ContextType.QUEST_SELECTION: 0.5,     # Brief selection
        ContextType.NPC_DIALOGUE: 0.75,       # Conversational
        ContextType.SIMPLE_RESPONSE: 0.5,     # Quick responses
    }

    # Approximate blended cost per 1K tokens (input + output average).
    # Hoisted to class level so it is not rebuilt on every estimate call.
    COST_PER_1K_TOKENS = {
        ModelType.LLAMA_3_8B: 0.0,        # Free tier
        ModelType.CLAUDE_HAIKU: 0.001,    # $0.25/1M input, $1.25/1M output
        ModelType.CLAUDE_SONNET: 0.006,   # $3/1M input, $15/1M output
        ModelType.CLAUDE_SONNET_4: 0.015, # Claude Sonnet 4 pricing
    }

    # Friendly display names per model (hoisted from get_tier_info so the
    # dict is built once, not per call).
    MODEL_NAMES = {
        ModelType.LLAMA_3_8B: "Llama 3 8B",
        ModelType.CLAUDE_HAIKU: "Claude 3 Haiku",
        ModelType.CLAUDE_SONNET: "Claude 3.5 Sonnet",
        ModelType.CLAUDE_SONNET_4: "Claude Sonnet 4",
    }

    # Model quality descriptions (hoisted from get_tier_info).
    QUALITY_DESCRIPTIONS = {
        ModelType.LLAMA_3_8B: "Good quality, optimized for speed",
        ModelType.CLAUDE_HAIKU: "High quality, fast responses",
        ModelType.CLAUDE_SONNET: "Excellent quality, detailed narratives",
        ModelType.CLAUDE_SONNET_4: "Best quality, most creative and nuanced",
    }

    def __init__(self) -> None:
        """Initialize the model selector (stateless; logs for traceability)."""
        logger.info("ModelSelector initialized")

    def select_model(
        self,
        user_tier: UserTier,
        context_type: ContextType = ContextType.SIMPLE_RESPONSE
    ) -> ModelConfig:
        """
        Select the appropriate model configuration for a user and context.

        Args:
            user_tier: The user's subscription tier.
            context_type: The type of content being generated.

        Returns:
            ModelConfig with model type, token limit, and temperature.

        Raises:
            KeyError: If ``user_tier`` is not a known ``UserTier``.

        Example:
            >>> selector = ModelSelector()
            >>> config = selector.select_model(UserTier.PREMIUM, ContextType.STORY_PROGRESSION)
            >>> config.model_type
            <ModelType.CLAUDE_SONNET: 'anthropic/claude-3.5-sonnet'>
        """
        # Get model for tier (KeyError on unknown tier is intentional —
        # callers pass validated UserTier values).
        model_type = self.TIER_MODELS[user_tier]

        # Scale the tier's base allocation by the context multiplier;
        # unknown contexts fall back to the full allocation.
        base_tokens = self.BASE_TOKEN_LIMITS[user_tier]
        multiplier = self.CONTEXT_TOKEN_MULTIPLIERS.get(context_type, 1.0)
        max_tokens = int(base_tokens * multiplier)

        # Unknown contexts fall back to a balanced temperature.
        temperature = self.CONTEXT_TEMPERATURES.get(context_type, 0.7)

        config = ModelConfig(
            model_type=model_type,
            max_tokens=max_tokens,
            temperature=temperature
        )

        logger.debug(
            "Model selected",
            user_tier=user_tier.value,
            context_type=context_type.value,
            model=model_type.value,
            max_tokens=max_tokens,
            temperature=temperature
        )
        return config

    def get_model_for_tier(self, user_tier: UserTier) -> ModelType:
        """
        Get the default model for a user tier.

        Args:
            user_tier: The user's subscription tier.

        Returns:
            The ModelType for this tier.

        Raises:
            KeyError: If ``user_tier`` is not a known ``UserTier``.
        """
        return self.TIER_MODELS[user_tier]

    def get_tier_info(self, user_tier: UserTier) -> dict:
        """
        Get information about a tier's AI capabilities.

        Args:
            user_tier: The user's subscription tier.

        Returns:
            Dictionary with keys ``tier``, ``model``, ``model_name``,
            ``base_tokens``, and ``quality``.
        """
        model_type = self.TIER_MODELS[user_tier]
        return {
            "tier": user_tier.value,
            "model": model_type.value,
            # Fall back to the raw identifier for models with no
            # registered friendly name.
            "model_name": self.MODEL_NAMES.get(model_type, model_type.value),
            "base_tokens": self.BASE_TOKEN_LIMITS[user_tier],
            "quality": self.QUALITY_DESCRIPTIONS.get(model_type, "Standard quality"),
        }

    def estimate_cost_per_request(self, user_tier: UserTier) -> float:
        """
        Estimate the cost per AI request for a tier.

        Args:
            user_tier: The user's subscription tier.

        Returns:
            Estimated cost in USD per request.

        Note:
            These are rough estimates based on typical usage.
            Actual costs depend on input/output tokens.
        """
        model_type = self.TIER_MODELS[user_tier]
        base_tokens = self.BASE_TOKEN_LIMITS[user_tier]
        # Unknown models are treated as free rather than raising.
        cost_per_1k = self.COST_PER_1K_TOKENS.get(model_type, 0.0)

        # Estimate: base tokens for output + ~50% for input tokens.
        estimated_tokens = base_tokens * 1.5
        return (estimated_tokens / 1000) * cost_per_1k