"""
|
|
Unit tests for model selector module.
|
|
"""
|
|
|
|
import pytest

from app.ai import (
    ContextType,
    ModelConfig,
    ModelSelector,
    ModelType,
    UserTier,
)
|
|
|
|
|
|
class TestModelSelector:
    """Tests for the ModelSelector class.

    Covers tier -> model mapping, per-tier token budgets, per-context
    token/temperature adjustments, the tier-info helper, and cost
    estimation ordering across tiers.
    """

    def setup_method(self):
        """Create a fresh ModelSelector before each test."""
        self.selector = ModelSelector()

    def test_initialization(self):
        """Test ModelSelector initializes correctly."""
        assert self.selector is not None

    # --- Tier -> model mapping -------------------------------------------

    def test_free_tier_gets_llama(self):
        """Free tier should get Llama-3 8B."""
        config = self.selector.select_model(UserTier.FREE)
        assert config.model_type == ModelType.LLAMA_3_8B

    def test_basic_tier_gets_haiku(self):
        """Basic tier should get Claude Haiku."""
        config = self.selector.select_model(UserTier.BASIC)
        assert config.model_type == ModelType.CLAUDE_HAIKU

    def test_premium_tier_gets_sonnet(self):
        """Premium tier should get Claude Sonnet."""
        config = self.selector.select_model(UserTier.PREMIUM)
        assert config.model_type == ModelType.CLAUDE_SONNET

    def test_elite_tier_gets_opus(self):
        """Elite tier should get Claude Sonnet 4.

        NOTE: the method name is stale (the elite model was evidently
        Opus at some point); it is kept so existing ``-k`` filters and
        CI references still match.
        """
        config = self.selector.select_model(UserTier.ELITE)
        assert config.model_type == ModelType.CLAUDE_SONNET_4

    # --- Token limits by tier (STORY_PROGRESSION = full allocation) ------

    def test_free_tier_token_limit(self):
        """Free tier should have 256 base tokens."""
        config = self.selector.select_model(UserTier.FREE, ContextType.STORY_PROGRESSION)
        assert config.max_tokens == 256

    def test_basic_tier_token_limit(self):
        """Basic tier should have 512 base tokens."""
        config = self.selector.select_model(UserTier.BASIC, ContextType.STORY_PROGRESSION)
        assert config.max_tokens == 512

    def test_premium_tier_token_limit(self):
        """Premium tier should have 1024 base tokens."""
        config = self.selector.select_model(UserTier.PREMIUM, ContextType.STORY_PROGRESSION)
        assert config.max_tokens == 1024

    def test_elite_tier_token_limit(self):
        """Elite tier should have 2048 base tokens."""
        config = self.selector.select_model(UserTier.ELITE, ContextType.STORY_PROGRESSION)
        assert config.max_tokens == 2048

    # --- Context-based token adjustments (premium base = 1024) -----------

    def test_story_progression_full_tokens(self):
        """Story progression should use full token allocation."""
        config = self.selector.select_model(
            UserTier.PREMIUM,
            ContextType.STORY_PROGRESSION,
        )
        # Full allocation = 1024 tokens for premium
        assert config.max_tokens == 1024

    def test_combat_narration_reduced_tokens(self):
        """Combat narration should use 75% of tokens."""
        config = self.selector.select_model(
            UserTier.PREMIUM,
            ContextType.COMBAT_NARRATION,
        )
        # 75% of 1024 = 768
        assert config.max_tokens == 768

    def test_quest_selection_half_tokens(self):
        """Quest selection should use 50% of tokens."""
        config = self.selector.select_model(
            UserTier.PREMIUM,
            ContextType.QUEST_SELECTION,
        )
        # 50% of 1024 = 512
        assert config.max_tokens == 512

    def test_npc_dialogue_reduced_tokens(self):
        """NPC dialogue should use 75% of tokens."""
        config = self.selector.select_model(
            UserTier.PREMIUM,
            ContextType.NPC_DIALOGUE,
        )
        # 75% of 1024 = 768
        assert config.max_tokens == 768

    def test_simple_response_half_tokens(self):
        """Simple response should use 50% of tokens."""
        config = self.selector.select_model(
            UserTier.PREMIUM,
            ContextType.SIMPLE_RESPONSE,
        )
        # 50% of 1024 = 512
        assert config.max_tokens == 512

    # --- Context-based temperature settings ------------------------------

    def test_story_progression_high_temperature(self):
        """Story progression should have high temperature (0.9)."""
        config = self.selector.select_model(
            UserTier.PREMIUM,
            ContextType.STORY_PROGRESSION,
        )
        assert config.temperature == 0.9

    def test_combat_narration_medium_high_temperature(self):
        """Combat narration should have medium-high temperature (0.8)."""
        config = self.selector.select_model(
            UserTier.PREMIUM,
            ContextType.COMBAT_NARRATION,
        )
        assert config.temperature == 0.8

    def test_quest_selection_low_temperature(self):
        """Quest selection should have low temperature (0.5)."""
        config = self.selector.select_model(
            UserTier.PREMIUM,
            ContextType.QUEST_SELECTION,
        )
        assert config.temperature == 0.5

    def test_npc_dialogue_medium_temperature(self):
        """NPC dialogue should have medium temperature (0.85)."""
        config = self.selector.select_model(
            UserTier.PREMIUM,
            ContextType.NPC_DIALOGUE,
        )
        assert config.temperature == 0.85

    def test_simple_response_balanced_temperature(self):
        """Simple response should have balanced temperature (0.7)."""
        config = self.selector.select_model(
            UserTier.PREMIUM,
            ContextType.SIMPLE_RESPONSE,
        )
        assert config.temperature == 0.7

    # --- ModelConfig properties ------------------------------------------

    def test_model_config_model_property(self):
        """ModelConfig.model should return model identifier string."""
        config = self.selector.select_model(UserTier.PREMIUM)
        assert config.model == "anthropic/claude-3.5-sonnet"

    # --- get_model_for_tier method ---------------------------------------

    def test_get_model_for_tier_free(self):
        """get_model_for_tier should return correct model for free tier."""
        model = self.selector.get_model_for_tier(UserTier.FREE)
        assert model == ModelType.LLAMA_3_8B

    def test_get_model_for_tier_elite(self):
        """get_model_for_tier should return correct model for elite tier."""
        model = self.selector.get_model_for_tier(UserTier.ELITE)
        assert model == ModelType.CLAUDE_SONNET_4

    # --- get_tier_info method --------------------------------------------

    def test_get_tier_info_structure(self):
        """get_tier_info should return complete tier information."""
        info = self.selector.get_tier_info(UserTier.PREMIUM)

        assert "tier" in info
        assert "model" in info
        assert "model_name" in info
        assert "base_tokens" in info
        assert "quality" in info

    def test_get_tier_info_premium_values(self):
        """get_tier_info should return correct values for premium tier."""
        info = self.selector.get_tier_info(UserTier.PREMIUM)

        assert info["tier"] == "premium"
        assert info["model"] == "anthropic/claude-3.5-sonnet"
        assert info["model_name"] == "Claude 3.5 Sonnet"
        assert info["base_tokens"] == 1024

    def test_get_tier_info_free_values(self):
        """get_tier_info should return correct values for free tier."""
        info = self.selector.get_tier_info(UserTier.FREE)

        assert info["tier"] == "free"
        assert info["model_name"] == "Llama 3 8B"
        assert info["base_tokens"] == 256

    # --- estimate_cost_per_request method --------------------------------

    def test_free_tier_zero_cost(self):
        """Free tier should have zero cost."""
        cost = self.selector.estimate_cost_per_request(UserTier.FREE)
        assert cost == 0.0

    def test_basic_tier_has_cost(self):
        """Basic tier should have non-zero cost."""
        cost = self.selector.estimate_cost_per_request(UserTier.BASIC)
        assert cost > 0

    def test_premium_tier_higher_cost(self):
        """Premium tier should have higher cost than basic."""
        basic_cost = self.selector.estimate_cost_per_request(UserTier.BASIC)
        premium_cost = self.selector.estimate_cost_per_request(UserTier.PREMIUM)
        assert premium_cost > basic_cost

    def test_elite_tier_highest_cost(self):
        """Elite tier should have highest cost."""
        premium_cost = self.selector.estimate_cost_per_request(UserTier.PREMIUM)
        elite_cost = self.selector.estimate_cost_per_request(UserTier.ELITE)
        assert elite_cost > premium_cost

    # --- Exhaustive tier / context sweeps --------------------------------

    def test_all_tiers_return_valid_config(self):
        """All tiers should return valid ModelConfig objects."""
        for tier in UserTier:
            config = self.selector.select_model(tier)
            assert isinstance(config, ModelConfig)
            assert config.model_type in ModelType
            assert config.max_tokens > 0
            assert 0 <= config.temperature <= 1

    def test_all_contexts_return_valid_config(self):
        """All context types should return valid ModelConfig objects."""
        for context in ContextType:
            config = self.selector.select_model(UserTier.PREMIUM, context)
            assert isinstance(config, ModelConfig)
            assert config.max_tokens > 0
            assert 0 <= config.temperature <= 1
|
|
|
class TestUserTierEnum:
    """Tests for UserTier enum."""

    def test_tier_values(self):
        """Test UserTier enum values are correct strings."""
        # Table-driven check: each member must carry its lowercase name.
        expected_values = {
            UserTier.FREE: "free",
            UserTier.BASIC: "basic",
            UserTier.PREMIUM: "premium",
            UserTier.ELITE: "elite",
        }
        for member, raw in expected_values.items():
            assert member.value == raw

    def test_tier_string_conversion(self):
        """Test UserTier can be converted to string."""
        assert str(UserTier.FREE) == "UserTier.FREE"
|
|
|
|
class TestContextTypeEnum:
    """Tests for ContextType enum."""

    def test_context_values(self):
        """Test ContextType enum values are correct strings."""
        # Table-driven check: each member maps to its snake_case string.
        expected_values = {
            ContextType.STORY_PROGRESSION: "story_progression",
            ContextType.COMBAT_NARRATION: "combat_narration",
            ContextType.QUEST_SELECTION: "quest_selection",
            ContextType.NPC_DIALOGUE: "npc_dialogue",
            ContextType.SIMPLE_RESPONSE: "simple_response",
        }
        for member, raw in expected_values.items():
            assert member.value == raw
|
|
|
|
class TestModelConfig:
    """Tests for ModelConfig dataclass."""

    def test_model_config_creation(self):
        """Test ModelConfig can be created with valid data."""
        cfg = ModelConfig(
            model_type=ModelType.CLAUDE_SONNET,
            max_tokens=1024,
            temperature=0.9,
        )

        # All constructor arguments must round-trip through the fields.
        assert (cfg.model_type, cfg.max_tokens, cfg.temperature) == (
            ModelType.CLAUDE_SONNET,
            1024,
            0.9,
        )

    def test_model_property(self):
        """Test model property returns model identifier."""
        cfg = ModelConfig(
            model_type=ModelType.LLAMA_3_8B,
            max_tokens=256,
            temperature=0.7,
        )

        assert cfg.model == "meta/meta-llama-3-8b-instruct"