Files
Code_of_Conquest/api/tests/test_model_selector.py
2025-11-24 23:10:55 -06:00

295 lines
11 KiB
Python

"""
Unit tests for model selector module.
"""
import pytest
from app.ai import (
ModelSelector,
ModelConfig,
UserTier,
ContextType,
ModelType,
)
class TestModelSelector:
    """Tests for ModelSelector.

    Covers the tier -> model mapping, per-tier base token limits, per-context
    token scaling and temperature, ModelConfig properties, and the
    get_model_for_tier / get_tier_info / estimate_cost_per_request helpers.
    """

    def setup_method(self):
        """Create a fresh ModelSelector before each test."""
        self.selector = ModelSelector()

    def test_initialization(self):
        """ModelSelector can be constructed with no arguments."""
        assert self.selector is not None

    # --- Tier -> model mapping (default context) ---

    def test_free_tier_gets_llama(self):
        """Free tier should get Llama-3 8B."""
        config = self.selector.select_model(UserTier.FREE)
        assert config.model_type == ModelType.LLAMA_3_8B

    def test_basic_tier_gets_haiku(self):
        """Basic tier should get Claude Haiku."""
        config = self.selector.select_model(UserTier.BASIC)
        assert config.model_type == ModelType.CLAUDE_HAIKU

    def test_premium_tier_gets_sonnet(self):
        """Premium tier should get Claude Sonnet."""
        config = self.selector.select_model(UserTier.PREMIUM)
        assert config.model_type == ModelType.CLAUDE_SONNET

    def test_elite_tier_gets_opus(self):
        """Elite tier should get Claude Sonnet 4."""
        # NOTE(review): the method name still says "opus" but the expected
        # model is CLAUDE_SONNET_4. The name is kept unchanged so the test ID
        # stays stable; consider renaming it.
        config = self.selector.select_model(UserTier.ELITE)
        assert config.model_type == ModelType.CLAUDE_SONNET_4

    # --- Base token limits per tier (STORY_PROGRESSION = full allocation) ---

    def test_free_tier_token_limit(self):
        """Free tier should have 256 base tokens."""
        config = self.selector.select_model(UserTier.FREE, ContextType.STORY_PROGRESSION)
        assert config.max_tokens == 256

    def test_basic_tier_token_limit(self):
        """Basic tier should have 512 base tokens."""
        config = self.selector.select_model(UserTier.BASIC, ContextType.STORY_PROGRESSION)
        assert config.max_tokens == 512

    def test_premium_tier_token_limit(self):
        """Premium tier should have 1024 base tokens."""
        config = self.selector.select_model(UserTier.PREMIUM, ContextType.STORY_PROGRESSION)
        assert config.max_tokens == 1024

    def test_elite_tier_token_limit(self):
        """Elite tier should have 2048 base tokens."""
        config = self.selector.select_model(UserTier.ELITE, ContextType.STORY_PROGRESSION)
        assert config.max_tokens == 2048

    # --- Context-based token scaling (premium base = 1024) ---

    def test_story_progression_full_tokens(self):
        """Story progression should use the full token allocation."""
        config = self.selector.select_model(UserTier.PREMIUM, ContextType.STORY_PROGRESSION)
        # Full allocation = 1024 tokens for premium
        assert config.max_tokens == 1024

    def test_combat_narration_reduced_tokens(self):
        """Combat narration should use 75% of tokens."""
        config = self.selector.select_model(UserTier.PREMIUM, ContextType.COMBAT_NARRATION)
        # 75% of 1024 = 768
        assert config.max_tokens == 768

    def test_quest_selection_half_tokens(self):
        """Quest selection should use 50% of tokens."""
        config = self.selector.select_model(UserTier.PREMIUM, ContextType.QUEST_SELECTION)
        # 50% of 1024 = 512
        assert config.max_tokens == 512

    def test_npc_dialogue_reduced_tokens(self):
        """NPC dialogue should use 75% of tokens."""
        config = self.selector.select_model(UserTier.PREMIUM, ContextType.NPC_DIALOGUE)
        # 75% of 1024 = 768
        assert config.max_tokens == 768

    def test_simple_response_half_tokens(self):
        """Simple response should use 50% of tokens."""
        config = self.selector.select_model(UserTier.PREMIUM, ContextType.SIMPLE_RESPONSE)
        # 50% of 1024 = 512
        assert config.max_tokens == 512

    # --- Context-based temperature ---

    def test_story_progression_high_temperature(self):
        """Story progression should have high temperature (0.9)."""
        config = self.selector.select_model(UserTier.PREMIUM, ContextType.STORY_PROGRESSION)
        assert config.temperature == 0.9

    def test_combat_narration_medium_high_temperature(self):
        """Combat narration should have medium-high temperature (0.8)."""
        config = self.selector.select_model(UserTier.PREMIUM, ContextType.COMBAT_NARRATION)
        assert config.temperature == 0.8

    def test_quest_selection_low_temperature(self):
        """Quest selection should have low temperature (0.5)."""
        config = self.selector.select_model(UserTier.PREMIUM, ContextType.QUEST_SELECTION)
        assert config.temperature == 0.5

    def test_npc_dialogue_medium_temperature(self):
        """NPC dialogue should have medium temperature (0.85)."""
        config = self.selector.select_model(UserTier.PREMIUM, ContextType.NPC_DIALOGUE)
        assert config.temperature == 0.85

    def test_simple_response_balanced_temperature(self):
        """Simple response should have balanced temperature (0.7)."""
        config = self.selector.select_model(UserTier.PREMIUM, ContextType.SIMPLE_RESPONSE)
        assert config.temperature == 0.7

    # --- ModelConfig properties ---

    def test_model_config_model_property(self):
        """ModelConfig.model should return the model identifier string."""
        config = self.selector.select_model(UserTier.PREMIUM)
        assert config.model == "anthropic/claude-3.5-sonnet"

    # --- get_model_for_tier ---

    def test_get_model_for_tier_free(self):
        """get_model_for_tier should return the free-tier model."""
        model = self.selector.get_model_for_tier(UserTier.FREE)
        assert model == ModelType.LLAMA_3_8B

    def test_get_model_for_tier_elite(self):
        """get_model_for_tier should return the elite-tier model."""
        model = self.selector.get_model_for_tier(UserTier.ELITE)
        assert model == ModelType.CLAUDE_SONNET_4

    # --- get_tier_info ---

    def test_get_tier_info_structure(self):
        """get_tier_info should return every expected key."""
        info = self.selector.get_tier_info(UserTier.PREMIUM)
        for key in ("tier", "model", "model_name", "base_tokens", "quality"):
            assert key in info, f"missing key {key!r} in tier info"

    def test_get_tier_info_premium_values(self):
        """get_tier_info should return correct values for the premium tier."""
        info = self.selector.get_tier_info(UserTier.PREMIUM)
        assert info["tier"] == "premium"
        assert info["model"] == "anthropic/claude-3.5-sonnet"
        assert info["model_name"] == "Claude 3.5 Sonnet"
        assert info["base_tokens"] == 1024

    def test_get_tier_info_free_values(self):
        """get_tier_info should return correct values for the free tier."""
        info = self.selector.get_tier_info(UserTier.FREE)
        assert info["tier"] == "free"
        assert info["model_name"] == "Llama 3 8B"
        assert info["base_tokens"] == 256

    # --- estimate_cost_per_request ---

    def test_free_tier_zero_cost(self):
        """Free tier should have zero cost."""
        cost = self.selector.estimate_cost_per_request(UserTier.FREE)
        assert cost == 0.0

    def test_basic_tier_has_cost(self):
        """Basic tier should have non-zero cost."""
        cost = self.selector.estimate_cost_per_request(UserTier.BASIC)
        assert cost > 0

    def test_premium_tier_higher_cost(self):
        """Premium tier should cost more than basic."""
        basic_cost = self.selector.estimate_cost_per_request(UserTier.BASIC)
        premium_cost = self.selector.estimate_cost_per_request(UserTier.PREMIUM)
        assert premium_cost > basic_cost

    def test_elite_tier_highest_cost(self):
        """Elite tier should cost more than premium."""
        premium_cost = self.selector.estimate_cost_per_request(UserTier.PREMIUM)
        elite_cost = self.selector.estimate_cost_per_request(UserTier.ELITE)
        assert elite_cost > premium_cost

    # --- Exhaustive sweeps ---

    def test_all_tiers_return_valid_config(self):
        """Every tier should yield a well-formed ModelConfig (default context)."""
        for tier in UserTier:
            config = self.selector.select_model(tier)
            assert isinstance(config, ModelConfig), f"tier={tier!r}"
            # isinstance() instead of `in ModelType`: on Python < 3.12 `in`
            # raises TypeError for non-members, and on 3.12+ it also returns
            # True for a bare member *value*, which would hide a missing
            # enum conversion in select_model.
            assert isinstance(config.model_type, ModelType), f"tier={tier!r}"
            assert config.max_tokens > 0, f"tier={tier!r}"
            assert 0 <= config.temperature <= 1, f"tier={tier!r}"

    def test_all_contexts_return_valid_config(self):
        """Every context type should yield a well-formed ModelConfig (premium tier)."""
        for context in ContextType:
            config = self.selector.select_model(UserTier.PREMIUM, context)
            assert isinstance(config, ModelConfig), f"context={context!r}"
            assert isinstance(config.model_type, ModelType), f"context={context!r}"
            assert config.max_tokens > 0, f"context={context!r}"
            assert 0 <= config.temperature <= 1, f"context={context!r}"
class TestUserTierEnum:
    """Checks on the UserTier enumeration."""

    def test_tier_values(self):
        """Each UserTier member carries its lowercase string value."""
        expected_pairs = (
            (UserTier.FREE, "free"),
            (UserTier.BASIC, "basic"),
            (UserTier.PREMIUM, "premium"),
            (UserTier.ELITE, "elite"),
        )
        for member, text in expected_pairs:
            assert member.value == text

    def test_tier_string_conversion(self):
        """str() on a UserTier member yields the qualified 'UserTier.NAME' form."""
        # NOTE(review): this pins Enum.__str__ formatting; it would fail if
        # UserTier were ever switched to a StrEnum, whose str() is the value.
        rendered = str(UserTier.FREE)
        assert rendered == "UserTier.FREE"
class TestContextTypeEnum:
    """Checks on the ContextType enumeration."""

    def test_context_values(self):
        """Each ContextType member carries its snake_case string value."""
        expected_pairs = (
            (ContextType.STORY_PROGRESSION, "story_progression"),
            (ContextType.COMBAT_NARRATION, "combat_narration"),
            (ContextType.QUEST_SELECTION, "quest_selection"),
            (ContextType.NPC_DIALOGUE, "npc_dialogue"),
            (ContextType.SIMPLE_RESPONSE, "simple_response"),
        )
        for member, text in expected_pairs:
            assert member.value == text
class TestModelConfig:
    """Checks on constructing ModelConfig directly."""

    def test_model_config_creation(self):
        """Keyword arguments given to ModelConfig are stored on the instance."""
        field_values = {
            "model_type": ModelType.CLAUDE_SONNET,
            "max_tokens": 1024,
            "temperature": 0.9,
        }
        cfg = ModelConfig(**field_values)
        for field_name, expected in field_values.items():
            assert getattr(cfg, field_name) == expected

    def test_model_property(self):
        """The model property maps LLAMA_3_8B to its provider identifier."""
        cfg = ModelConfig(
            model_type=ModelType.LLAMA_3_8B,
            max_tokens=256,
            temperature=0.7,
        )
        identifier = cfg.model
        assert identifier == "meta/meta-llama-3-8b-instruct"