""" Unit tests for model selector module. """ import pytest from app.ai import ( ModelSelector, ModelConfig, UserTier, ContextType, ModelType, ) class TestModelSelector: """Tests for ModelSelector class.""" def setup_method(self): """Set up test fixtures.""" self.selector = ModelSelector() def test_initialization(self): """Test ModelSelector initializes correctly.""" assert self.selector is not None # Test tier to model mapping def test_free_tier_gets_llama(self): """Free tier should get Llama-3 8B.""" config = self.selector.select_model(UserTier.FREE) assert config.model_type == ModelType.LLAMA_3_8B def test_basic_tier_gets_haiku(self): """Basic tier should get Claude Haiku.""" config = self.selector.select_model(UserTier.BASIC) assert config.model_type == ModelType.CLAUDE_HAIKU def test_premium_tier_gets_sonnet(self): """Premium tier should get Claude Sonnet.""" config = self.selector.select_model(UserTier.PREMIUM) assert config.model_type == ModelType.CLAUDE_SONNET def test_elite_tier_gets_opus(self): """Elite tier should get Claude Opus.""" config = self.selector.select_model(UserTier.ELITE) assert config.model_type == ModelType.CLAUDE_SONNET_4 # Test token limits by tier (using STORY_PROGRESSION for full allocation) def test_free_tier_token_limit(self): """Free tier should have 256 base tokens.""" config = self.selector.select_model(UserTier.FREE, ContextType.STORY_PROGRESSION) assert config.max_tokens == 256 def test_basic_tier_token_limit(self): """Basic tier should have 512 base tokens.""" config = self.selector.select_model(UserTier.BASIC, ContextType.STORY_PROGRESSION) assert config.max_tokens == 512 def test_premium_tier_token_limit(self): """Premium tier should have 1024 base tokens.""" config = self.selector.select_model(UserTier.PREMIUM, ContextType.STORY_PROGRESSION) assert config.max_tokens == 1024 def test_elite_tier_token_limit(self): """Elite tier should have 2048 base tokens.""" config = self.selector.select_model(UserTier.ELITE, ContextType.STORY_PROGRESSION) assert config.max_tokens == 2048 # Test context-based token adjustments def test_story_progression_full_tokens(self): """Story progression should use full token allocation.""" config = self.selector.select_model( UserTier.PREMIUM, ContextType.STORY_PROGRESSION ) # Full allocation = 1024 tokens for premium assert config.max_tokens == 1024 def test_combat_narration_reduced_tokens(self): """Combat narration should use 75% of tokens.""" config = self.selector.select_model( UserTier.PREMIUM, ContextType.COMBAT_NARRATION ) # 75% of 1024 = 768 assert config.max_tokens == 768 def test_quest_selection_half_tokens(self): """Quest selection should use 50% of tokens.""" config = self.selector.select_model( UserTier.PREMIUM, ContextType.QUEST_SELECTION ) # 50% of 1024 = 512 assert config.max_tokens == 512 def test_npc_dialogue_reduced_tokens(self): """NPC dialogue should use 75% of tokens.""" config = self.selector.select_model( UserTier.PREMIUM, ContextType.NPC_DIALOGUE ) # 75% of 1024 = 768 assert config.max_tokens == 768 def test_simple_response_half_tokens(self): """Simple response should use 50% of tokens.""" config = self.selector.select_model( UserTier.PREMIUM, ContextType.SIMPLE_RESPONSE ) # 50% of 1024 = 512 assert config.max_tokens == 512 # Test context-based temperature settings def test_story_progression_high_temperature(self): """Story progression should have high temperature (0.9).""" config = self.selector.select_model( UserTier.PREMIUM, ContextType.STORY_PROGRESSION ) assert config.temperature == 0.9 def test_combat_narration_medium_high_temperature(self): """Combat narration should have medium-high temperature (0.8).""" config = self.selector.select_model( UserTier.PREMIUM, ContextType.COMBAT_NARRATION ) assert config.temperature == 0.8 def test_quest_selection_low_temperature(self): """Quest selection should have low temperature (0.5).""" config = self.selector.select_model( UserTier.PREMIUM, ContextType.QUEST_SELECTION ) assert config.temperature == 0.5 def test_npc_dialogue_medium_temperature(self): """NPC dialogue should have medium temperature (0.85).""" config = self.selector.select_model( UserTier.PREMIUM, ContextType.NPC_DIALOGUE ) assert config.temperature == 0.85 def test_simple_response_balanced_temperature(self): """Simple response should have balanced temperature (0.7).""" config = self.selector.select_model( UserTier.PREMIUM, ContextType.SIMPLE_RESPONSE ) assert config.temperature == 0.7 # Test ModelConfig properties def test_model_config_model_property(self): """ModelConfig.model should return model identifier string.""" config = self.selector.select_model(UserTier.PREMIUM) assert config.model == "anthropic/claude-3.5-sonnet" # Test get_model_for_tier method def test_get_model_for_tier_free(self): """get_model_for_tier should return correct model for free tier.""" model = self.selector.get_model_for_tier(UserTier.FREE) assert model == ModelType.LLAMA_3_8B def test_get_model_for_tier_elite(self): """get_model_for_tier should return correct model for elite tier.""" model = self.selector.get_model_for_tier(UserTier.ELITE) assert model == ModelType.CLAUDE_SONNET_4 # Test get_tier_info method def test_get_tier_info_structure(self): """get_tier_info should return complete tier information.""" info = self.selector.get_tier_info(UserTier.PREMIUM) assert "tier" in info assert "model" in info assert "model_name" in info assert "base_tokens" in info assert "quality" in info def test_get_tier_info_premium_values(self): """get_tier_info should return correct values for premium tier.""" info = self.selector.get_tier_info(UserTier.PREMIUM) assert info["tier"] == "premium" assert info["model"] == "anthropic/claude-3.5-sonnet" assert info["model_name"] == "Claude 3.5 Sonnet" assert info["base_tokens"] == 1024 def test_get_tier_info_free_values(self): """get_tier_info should return correct values for free tier.""" info = self.selector.get_tier_info(UserTier.FREE) assert info["tier"] == "free" assert info["model_name"] == "Llama 3 8B" assert info["base_tokens"] == 256 # Test estimate_cost_per_request method def test_free_tier_zero_cost(self): """Free tier should have zero cost.""" cost = self.selector.estimate_cost_per_request(UserTier.FREE) assert cost == 0.0 def test_basic_tier_has_cost(self): """Basic tier should have non-zero cost.""" cost = self.selector.estimate_cost_per_request(UserTier.BASIC) assert cost > 0 def test_premium_tier_higher_cost(self): """Premium tier should have higher cost than basic.""" basic_cost = self.selector.estimate_cost_per_request(UserTier.BASIC) premium_cost = self.selector.estimate_cost_per_request(UserTier.PREMIUM) assert premium_cost > basic_cost def test_elite_tier_highest_cost(self): """Elite tier should have highest cost.""" premium_cost = self.selector.estimate_cost_per_request(UserTier.PREMIUM) elite_cost = self.selector.estimate_cost_per_request(UserTier.ELITE) assert elite_cost > premium_cost # Test all tier combinations def test_all_tiers_return_valid_config(self): """All tiers should return valid ModelConfig objects.""" for tier in UserTier: config = self.selector.select_model(tier) assert isinstance(config, ModelConfig) assert config.model_type in ModelType assert config.max_tokens > 0 assert 0 <= config.temperature <= 1 # Test all context combinations def test_all_contexts_return_valid_config(self): """All context types should return valid ModelConfig objects.""" for context in ContextType: config = self.selector.select_model(UserTier.PREMIUM, context) assert isinstance(config, ModelConfig) assert config.max_tokens > 0 assert 0 <= config.temperature <= 1 class TestUserTierEnum: """Tests for UserTier enum.""" def test_tier_values(self): """Test UserTier enum values are correct strings.""" assert UserTier.FREE.value == "free" assert UserTier.BASIC.value == "basic" assert UserTier.PREMIUM.value == "premium" assert UserTier.ELITE.value == "elite" def test_tier_string_conversion(self): """Test UserTier can be converted to string.""" assert str(UserTier.FREE) == "UserTier.FREE" class TestContextTypeEnum: """Tests for ContextType enum.""" def test_context_values(self): """Test ContextType enum values are correct strings.""" assert ContextType.STORY_PROGRESSION.value == "story_progression" assert ContextType.COMBAT_NARRATION.value == "combat_narration" assert ContextType.QUEST_SELECTION.value == "quest_selection" assert ContextType.NPC_DIALOGUE.value == "npc_dialogue" assert ContextType.SIMPLE_RESPONSE.value == "simple_response" class TestModelConfig: """Tests for ModelConfig dataclass.""" def test_model_config_creation(self): """Test ModelConfig can be created with valid data.""" config = ModelConfig( model_type=ModelType.CLAUDE_SONNET, max_tokens=1024, temperature=0.9 ) assert config.model_type == ModelType.CLAUDE_SONNET assert config.max_tokens == 1024 assert config.temperature == 0.9 def test_model_property(self): """Test model property returns model identifier.""" config = ModelConfig( model_type=ModelType.LLAMA_3_8B, max_tokens=256, temperature=0.7 ) assert config.model == "meta/meta-llama-3-8b-instruct"