first commit
This commit is contained in:
3
api/tests/__init__.py
Normal file
3
api/tests/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
"""
|
||||
Unit tests for Code of Conquest.
|
||||
"""
|
||||
311
api/tests/test_action_prompt.py
Normal file
311
api/tests/test_action_prompt.py
Normal file
@@ -0,0 +1,311 @@
|
||||
"""
|
||||
Tests for ActionPrompt model
|
||||
|
||||
Tests the action prompt availability logic, tier filtering,
|
||||
location filtering, and serialization.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from app.models.action_prompt import (
|
||||
ActionPrompt,
|
||||
ActionCategory,
|
||||
LocationType,
|
||||
)
|
||||
from app.ai.model_selector import UserTier
|
||||
|
||||
|
||||
class TestActionPrompt:
|
||||
"""Tests for ActionPrompt dataclass."""
|
||||
|
||||
@pytest.fixture
|
||||
def free_action(self):
|
||||
"""Create a free tier action available in towns."""
|
||||
return ActionPrompt(
|
||||
prompt_id="ask_locals",
|
||||
category=ActionCategory.ASK_QUESTION,
|
||||
display_text="Ask locals for information",
|
||||
description="Talk to NPCs to learn about quests and rumors",
|
||||
tier_required=UserTier.FREE,
|
||||
context_filter=[LocationType.TOWN, LocationType.TAVERN],
|
||||
dm_prompt_template="The player asks locals about {{ topic }}.",
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def premium_action(self):
|
||||
"""Create a premium tier action available anywhere."""
|
||||
return ActionPrompt(
|
||||
prompt_id="investigate",
|
||||
category=ActionCategory.GATHER_INFO,
|
||||
display_text="Investigate suspicious activity",
|
||||
description="Look for clues and hidden details",
|
||||
tier_required=UserTier.PREMIUM,
|
||||
context_filter=[LocationType.ANY],
|
||||
dm_prompt_template="The player investigates the area.",
|
||||
icon="magnifying_glass",
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def elite_action(self):
|
||||
"""Create an elite tier action for libraries."""
|
||||
return ActionPrompt(
|
||||
prompt_id="consult_texts",
|
||||
category=ActionCategory.SPECIAL,
|
||||
display_text="Consult ancient texts",
|
||||
description="Study rare manuscripts for hidden knowledge",
|
||||
tier_required=UserTier.ELITE,
|
||||
context_filter=[LocationType.LIBRARY, LocationType.TOWN],
|
||||
dm_prompt_template="The player studies ancient texts.",
|
||||
cooldown_turns=3,
|
||||
)
|
||||
|
||||
# Availability tests
|
||||
|
||||
def test_free_action_available_to_free_user(self, free_action):
|
||||
"""Free action should be available to free tier users."""
|
||||
assert free_action.is_available(UserTier.FREE, LocationType.TOWN) is True
|
||||
|
||||
def test_free_action_available_to_premium_user(self, free_action):
|
||||
"""Free action should be available to higher tier users."""
|
||||
assert free_action.is_available(UserTier.PREMIUM, LocationType.TOWN) is True
|
||||
assert free_action.is_available(UserTier.ELITE, LocationType.TOWN) is True
|
||||
|
||||
def test_premium_action_not_available_to_free_user(self, premium_action):
|
||||
"""Premium action should not be available to free tier users."""
|
||||
assert premium_action.is_available(UserTier.FREE, LocationType.TOWN) is False
|
||||
|
||||
def test_premium_action_available_to_premium_user(self, premium_action):
|
||||
"""Premium action should be available to premium tier users."""
|
||||
assert premium_action.is_available(UserTier.PREMIUM, LocationType.TOWN) is True
|
||||
|
||||
def test_elite_action_not_available_to_premium_user(self, elite_action):
|
||||
"""Elite action should not be available to premium tier users."""
|
||||
assert elite_action.is_available(UserTier.PREMIUM, LocationType.LIBRARY) is False
|
||||
|
||||
def test_elite_action_available_to_elite_user(self, elite_action):
|
||||
"""Elite action should be available to elite tier users."""
|
||||
assert elite_action.is_available(UserTier.ELITE, LocationType.LIBRARY) is True
|
||||
|
||||
# Location filtering tests
|
||||
|
||||
def test_action_available_in_matching_location(self, free_action):
|
||||
"""Action should be available in matching locations."""
|
||||
assert free_action.is_available(UserTier.FREE, LocationType.TOWN) is True
|
||||
assert free_action.is_available(UserTier.FREE, LocationType.TAVERN) is True
|
||||
|
||||
def test_action_not_available_in_non_matching_location(self, free_action):
|
||||
"""Action should not be available in non-matching locations."""
|
||||
assert free_action.is_available(UserTier.FREE, LocationType.WILDERNESS) is False
|
||||
assert free_action.is_available(UserTier.FREE, LocationType.DUNGEON) is False
|
||||
|
||||
def test_any_location_matches_all(self, premium_action):
|
||||
"""Action with ANY location should be available everywhere."""
|
||||
assert premium_action.is_available(UserTier.PREMIUM, LocationType.TOWN) is True
|
||||
assert premium_action.is_available(UserTier.PREMIUM, LocationType.WILDERNESS) is True
|
||||
assert premium_action.is_available(UserTier.PREMIUM, LocationType.DUNGEON) is True
|
||||
assert premium_action.is_available(UserTier.PREMIUM, LocationType.LIBRARY) is True
|
||||
|
||||
def test_both_tier_and_location_must_match(self, free_action):
|
||||
"""Both tier and location requirements must be met."""
|
||||
# Wrong location, right tier
|
||||
assert free_action.is_available(UserTier.ELITE, LocationType.DUNGEON) is False
|
||||
|
||||
# Lock status tests
|
||||
|
||||
def test_is_locked_for_lower_tier(self, premium_action):
|
||||
"""Action should be locked for lower tier users."""
|
||||
assert premium_action.is_locked(UserTier.FREE) is True
|
||||
assert premium_action.is_locked(UserTier.BASIC) is True
|
||||
|
||||
def test_is_not_locked_for_sufficient_tier(self, premium_action):
|
||||
"""Action should not be locked for sufficient tier users."""
|
||||
assert premium_action.is_locked(UserTier.PREMIUM) is False
|
||||
assert premium_action.is_locked(UserTier.ELITE) is False
|
||||
|
||||
def test_get_lock_reason_returns_message(self, premium_action):
|
||||
"""Lock reason should explain tier requirement."""
|
||||
reason = premium_action.get_lock_reason(UserTier.FREE)
|
||||
assert reason is not None
|
||||
assert "Premium" in reason
|
||||
|
||||
def test_get_lock_reason_returns_none_when_unlocked(self, premium_action):
|
||||
"""Lock reason should be None when action is unlocked."""
|
||||
reason = premium_action.get_lock_reason(UserTier.PREMIUM)
|
||||
assert reason is None
|
||||
|
||||
# Tier hierarchy tests
|
||||
|
||||
def test_tier_hierarchy_free_to_elite(self):
|
||||
"""Test full tier hierarchy from FREE to ELITE."""
|
||||
action = ActionPrompt(
|
||||
prompt_id="test",
|
||||
category=ActionCategory.EXPLORE,
|
||||
display_text="Test",
|
||||
description="Test action",
|
||||
tier_required=UserTier.BASIC,
|
||||
context_filter=[LocationType.ANY],
|
||||
dm_prompt_template="Test",
|
||||
)
|
||||
|
||||
# FREE < BASIC (should fail)
|
||||
assert action.is_available(UserTier.FREE, LocationType.TOWN) is False
|
||||
|
||||
# BASIC >= BASIC (should pass)
|
||||
assert action.is_available(UserTier.BASIC, LocationType.TOWN) is True
|
||||
|
||||
# PREMIUM > BASIC (should pass)
|
||||
assert action.is_available(UserTier.PREMIUM, LocationType.TOWN) is True
|
||||
|
||||
# ELITE > BASIC (should pass)
|
||||
assert action.is_available(UserTier.ELITE, LocationType.TOWN) is True
|
||||
|
||||
# Serialization tests
|
||||
|
||||
def test_to_dict(self, free_action):
|
||||
"""Test serialization to dictionary."""
|
||||
data = free_action.to_dict()
|
||||
|
||||
assert data["prompt_id"] == "ask_locals"
|
||||
assert data["category"] == "ask_question"
|
||||
assert data["display_text"] == "Ask locals for information"
|
||||
assert data["tier_required"] == "free"
|
||||
assert data["context_filter"] == ["town", "tavern"]
|
||||
assert "dm_prompt_template" in data
|
||||
|
||||
def test_from_dict(self):
|
||||
"""Test deserialization from dictionary."""
|
||||
data = {
|
||||
"prompt_id": "explore_area",
|
||||
"category": "explore",
|
||||
"display_text": "Explore the area",
|
||||
"description": "Look around for points of interest",
|
||||
"tier_required": "free",
|
||||
"context_filter": ["wilderness", "dungeon"],
|
||||
"dm_prompt_template": "The player explores {{ location }}.",
|
||||
"icon": "compass",
|
||||
"cooldown_turns": 2,
|
||||
}
|
||||
|
||||
action = ActionPrompt.from_dict(data)
|
||||
|
||||
assert action.prompt_id == "explore_area"
|
||||
assert action.category == ActionCategory.EXPLORE
|
||||
assert action.tier_required == UserTier.FREE
|
||||
assert LocationType.WILDERNESS in action.context_filter
|
||||
assert action.icon == "compass"
|
||||
assert action.cooldown_turns == 2
|
||||
|
||||
def test_round_trip_serialization(self, free_action):
|
||||
"""Test that to_dict and from_dict are inverse operations."""
|
||||
data = free_action.to_dict()
|
||||
restored = ActionPrompt.from_dict(data)
|
||||
|
||||
assert restored.prompt_id == free_action.prompt_id
|
||||
assert restored.category == free_action.category
|
||||
assert restored.display_text == free_action.display_text
|
||||
assert restored.tier_required == free_action.tier_required
|
||||
assert restored.context_filter == free_action.context_filter
|
||||
|
||||
def test_from_dict_invalid_category(self):
|
||||
"""Test error handling for invalid category."""
|
||||
data = {
|
||||
"prompt_id": "test",
|
||||
"category": "invalid_category",
|
||||
"display_text": "Test",
|
||||
"description": "Test",
|
||||
"tier_required": "free",
|
||||
"context_filter": ["any"],
|
||||
"dm_prompt_template": "Test",
|
||||
}
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
ActionPrompt.from_dict(data)
|
||||
|
||||
assert "Invalid action category" in str(exc_info.value)
|
||||
|
||||
def test_from_dict_invalid_tier(self):
|
||||
"""Test error handling for invalid tier."""
|
||||
data = {
|
||||
"prompt_id": "test",
|
||||
"category": "explore",
|
||||
"display_text": "Test",
|
||||
"description": "Test",
|
||||
"tier_required": "super_premium",
|
||||
"context_filter": ["any"],
|
||||
"dm_prompt_template": "Test",
|
||||
}
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
ActionPrompt.from_dict(data)
|
||||
|
||||
assert "Invalid user tier" in str(exc_info.value)
|
||||
|
||||
def test_from_dict_invalid_location(self):
|
||||
"""Test error handling for invalid location type."""
|
||||
data = {
|
||||
"prompt_id": "test",
|
||||
"category": "explore",
|
||||
"display_text": "Test",
|
||||
"description": "Test",
|
||||
"tier_required": "free",
|
||||
"context_filter": ["invalid_location"],
|
||||
"dm_prompt_template": "Test",
|
||||
}
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
ActionPrompt.from_dict(data)
|
||||
|
||||
assert "Invalid location type" in str(exc_info.value)
|
||||
|
||||
# Optional fields tests
|
||||
|
||||
def test_optional_icon(self, free_action, premium_action):
|
||||
"""Test that icon is optional."""
|
||||
assert free_action.icon is None
|
||||
assert premium_action.icon == "magnifying_glass"
|
||||
|
||||
def test_default_cooldown(self, free_action, elite_action):
|
||||
"""Test default and custom cooldown values."""
|
||||
assert free_action.cooldown_turns == 0
|
||||
assert elite_action.cooldown_turns == 3
|
||||
|
||||
# Repr test
|
||||
|
||||
def test_repr(self, free_action):
|
||||
"""Test string representation."""
|
||||
repr_str = repr(free_action)
|
||||
assert "ask_locals" in repr_str
|
||||
assert "ask_question" in repr_str
|
||||
assert "free" in repr_str
|
||||
|
||||
|
||||
class TestActionCategory:
|
||||
"""Tests for ActionCategory enum."""
|
||||
|
||||
def test_all_categories_defined(self):
|
||||
"""Verify all expected categories exist."""
|
||||
categories = [cat.value for cat in ActionCategory]
|
||||
|
||||
assert "ask_question" in categories
|
||||
assert "travel" in categories
|
||||
assert "gather_info" in categories
|
||||
assert "rest" in categories
|
||||
assert "interact" in categories
|
||||
assert "explore" in categories
|
||||
assert "special" in categories
|
||||
|
||||
|
||||
class TestLocationType:
|
||||
"""Tests for LocationType enum."""
|
||||
|
||||
def test_all_location_types_defined(self):
|
||||
"""Verify all expected location types exist."""
|
||||
locations = [loc.value for loc in LocationType]
|
||||
|
||||
assert "town" in locations
|
||||
assert "tavern" in locations
|
||||
assert "wilderness" in locations
|
||||
assert "dungeon" in locations
|
||||
assert "safe_area" in locations
|
||||
assert "library" in locations
|
||||
assert "any" in locations
|
||||
314
api/tests/test_action_prompt_loader.py
Normal file
314
api/tests/test_action_prompt_loader.py
Normal file
@@ -0,0 +1,314 @@
|
||||
"""
|
||||
Tests for ActionPromptLoader service
|
||||
|
||||
Tests loading from YAML, filtering by tier and location,
|
||||
and error handling.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import tempfile
|
||||
import os
|
||||
|
||||
from app.services.action_prompt_loader import (
|
||||
ActionPromptLoader,
|
||||
ActionPromptLoaderError,
|
||||
ActionPromptNotFoundError,
|
||||
)
|
||||
from app.models.action_prompt import LocationType
|
||||
from app.ai.model_selector import UserTier
|
||||
|
||||
|
||||
class TestActionPromptLoader:
|
||||
"""Tests for ActionPromptLoader service."""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_singleton(self):
|
||||
"""Reset singleton before each test."""
|
||||
ActionPromptLoader.reset_instance()
|
||||
yield
|
||||
ActionPromptLoader.reset_instance()
|
||||
|
||||
@pytest.fixture
|
||||
def sample_yaml(self):
|
||||
"""Create a sample YAML file for testing."""
|
||||
content = """
|
||||
action_prompts:
|
||||
- prompt_id: test_free
|
||||
category: explore
|
||||
display_text: Free Action
|
||||
description: Available to all
|
||||
tier_required: free
|
||||
context_filter: [town, tavern]
|
||||
dm_prompt_template: Test template
|
||||
|
||||
- prompt_id: test_premium
|
||||
category: gather_info
|
||||
display_text: Premium Action
|
||||
description: Premium only
|
||||
tier_required: premium
|
||||
context_filter: [any]
|
||||
dm_prompt_template: Premium template
|
||||
|
||||
- prompt_id: test_elite
|
||||
category: special
|
||||
display_text: Elite Action
|
||||
description: Elite only
|
||||
tier_required: elite
|
||||
context_filter: [library]
|
||||
dm_prompt_template: Elite template
|
||||
"""
|
||||
with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', delete=False) as f:
|
||||
f.write(content)
|
||||
filepath = f.name
|
||||
|
||||
yield filepath
|
||||
os.unlink(filepath)
|
||||
|
||||
@pytest.fixture
|
||||
def loader(self, sample_yaml):
|
||||
"""Create a loader with sample data."""
|
||||
loader = ActionPromptLoader()
|
||||
loader.load_from_yaml(sample_yaml)
|
||||
return loader
|
||||
|
||||
# Loading tests
|
||||
|
||||
def test_load_from_yaml(self, sample_yaml):
|
||||
"""Test loading prompts from YAML file."""
|
||||
loader = ActionPromptLoader()
|
||||
count = loader.load_from_yaml(sample_yaml)
|
||||
|
||||
assert count == 3
|
||||
assert loader.is_loaded()
|
||||
|
||||
def test_load_file_not_found(self):
|
||||
"""Test error when file doesn't exist."""
|
||||
loader = ActionPromptLoader()
|
||||
|
||||
with pytest.raises(ActionPromptLoaderError) as exc_info:
|
||||
loader.load_from_yaml("/nonexistent/path.yaml")
|
||||
|
||||
assert "not found" in str(exc_info.value)
|
||||
|
||||
def test_load_invalid_yaml(self):
|
||||
"""Test error when YAML is malformed."""
|
||||
with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', delete=False) as f:
|
||||
f.write("invalid: yaml: content: [")
|
||||
filepath = f.name
|
||||
|
||||
try:
|
||||
loader = ActionPromptLoader()
|
||||
with pytest.raises(ActionPromptLoaderError) as exc_info:
|
||||
loader.load_from_yaml(filepath)
|
||||
|
||||
assert "Invalid YAML" in str(exc_info.value)
|
||||
finally:
|
||||
os.unlink(filepath)
|
||||
|
||||
def test_load_missing_key(self):
|
||||
"""Test error when action_prompts key is missing."""
|
||||
with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', delete=False) as f:
|
||||
f.write("other_key: value")
|
||||
filepath = f.name
|
||||
|
||||
try:
|
||||
loader = ActionPromptLoader()
|
||||
with pytest.raises(ActionPromptLoaderError) as exc_info:
|
||||
loader.load_from_yaml(filepath)
|
||||
|
||||
assert "Missing 'action_prompts'" in str(exc_info.value)
|
||||
finally:
|
||||
os.unlink(filepath)
|
||||
|
||||
# Get methods tests
|
||||
|
||||
def test_get_all_actions(self, loader):
|
||||
"""Test getting all loaded actions."""
|
||||
actions = loader.get_all_actions()
|
||||
|
||||
assert len(actions) == 3
|
||||
prompt_ids = [a.prompt_id for a in actions]
|
||||
assert "test_free" in prompt_ids
|
||||
assert "test_premium" in prompt_ids
|
||||
assert "test_elite" in prompt_ids
|
||||
|
||||
def test_get_action_by_id(self, loader):
|
||||
"""Test getting action by ID."""
|
||||
action = loader.get_action_by_id("test_free")
|
||||
|
||||
assert action.prompt_id == "test_free"
|
||||
assert action.display_text == "Free Action"
|
||||
|
||||
def test_get_action_by_id_not_found(self, loader):
|
||||
"""Test error when action ID not found."""
|
||||
with pytest.raises(ActionPromptNotFoundError):
|
||||
loader.get_action_by_id("nonexistent")
|
||||
|
||||
# Filtering tests
|
||||
|
||||
def test_get_available_actions_free_tier(self, loader):
|
||||
"""Test filtering for free tier user."""
|
||||
actions = loader.get_available_actions(UserTier.FREE, LocationType.TOWN)
|
||||
|
||||
assert len(actions) == 1
|
||||
assert actions[0].prompt_id == "test_free"
|
||||
|
||||
def test_get_available_actions_premium_tier(self, loader):
|
||||
"""Test filtering for premium tier user."""
|
||||
actions = loader.get_available_actions(UserTier.PREMIUM, LocationType.TOWN)
|
||||
|
||||
# Premium gets: test_free (town) + test_premium (any)
|
||||
assert len(actions) == 2
|
||||
prompt_ids = [a.prompt_id for a in actions]
|
||||
assert "test_free" in prompt_ids
|
||||
assert "test_premium" in prompt_ids
|
||||
|
||||
def test_get_available_actions_elite_tier(self, loader):
|
||||
"""Test filtering for elite tier user."""
|
||||
actions = loader.get_available_actions(UserTier.ELITE, LocationType.LIBRARY)
|
||||
|
||||
# Elite in library gets: test_premium (any) + test_elite (library)
|
||||
assert len(actions) == 2
|
||||
prompt_ids = [a.prompt_id for a in actions]
|
||||
assert "test_premium" in prompt_ids
|
||||
assert "test_elite" in prompt_ids
|
||||
|
||||
def test_get_available_actions_location_filter(self, loader):
|
||||
"""Test that location filtering works correctly."""
|
||||
# In town, no elite actions available
|
||||
actions = loader.get_available_actions(UserTier.ELITE, LocationType.TOWN)
|
||||
prompt_ids = [a.prompt_id for a in actions]
|
||||
|
||||
assert "test_elite" not in prompt_ids # Only in library
|
||||
|
||||
def test_get_actions_by_tier(self, loader):
|
||||
"""Test getting actions by tier without location filter."""
|
||||
free_actions = loader.get_actions_by_tier(UserTier.FREE)
|
||||
premium_actions = loader.get_actions_by_tier(UserTier.PREMIUM)
|
||||
elite_actions = loader.get_actions_by_tier(UserTier.ELITE)
|
||||
|
||||
assert len(free_actions) == 1
|
||||
assert len(premium_actions) == 2
|
||||
assert len(elite_actions) == 3
|
||||
|
||||
def test_get_actions_by_category(self, loader):
|
||||
"""Test getting actions by category."""
|
||||
explore_actions = loader.get_actions_by_category("explore")
|
||||
special_actions = loader.get_actions_by_category("special")
|
||||
|
||||
assert len(explore_actions) == 1
|
||||
assert explore_actions[0].prompt_id == "test_free"
|
||||
|
||||
assert len(special_actions) == 1
|
||||
assert special_actions[0].prompt_id == "test_elite"
|
||||
|
||||
def test_get_locked_actions(self, loader):
|
||||
"""Test getting locked actions for upgrade prompts."""
|
||||
# Free user in library sees elite action as locked
|
||||
locked = loader.get_locked_actions(UserTier.FREE, LocationType.LIBRARY)
|
||||
|
||||
# test_premium (any) and test_elite (library) are locked for free
|
||||
assert len(locked) == 2
|
||||
prompt_ids = [a.prompt_id for a in locked]
|
||||
assert "test_premium" in prompt_ids
|
||||
assert "test_elite" in prompt_ids
|
||||
|
||||
# Singleton and reload tests
|
||||
|
||||
def test_singleton_pattern(self, sample_yaml):
|
||||
"""Test that loader is singleton."""
|
||||
loader1 = ActionPromptLoader()
|
||||
loader1.load_from_yaml(sample_yaml)
|
||||
|
||||
loader2 = ActionPromptLoader()
|
||||
|
||||
assert loader1 is loader2
|
||||
assert loader2.is_loaded()
|
||||
|
||||
def test_reload(self, sample_yaml):
|
||||
"""Test reloading prompts."""
|
||||
loader = ActionPromptLoader()
|
||||
loader.load_from_yaml(sample_yaml)
|
||||
|
||||
# Modify and reload
|
||||
count = loader.reload(sample_yaml)
|
||||
assert count == 3
|
||||
|
||||
def test_get_prompt_count(self, loader):
|
||||
"""Test getting prompt count."""
|
||||
assert loader.get_prompt_count() == 3
|
||||
|
||||
|
||||
class TestActionPromptLoaderIntegration:
|
||||
"""Integration tests with actual YAML file."""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_singleton(self):
|
||||
"""Reset singleton before each test."""
|
||||
ActionPromptLoader.reset_instance()
|
||||
yield
|
||||
ActionPromptLoader.reset_instance()
|
||||
|
||||
def test_load_actual_yaml(self):
|
||||
"""Test loading the actual action_prompts.yaml file."""
|
||||
loader = ActionPromptLoader()
|
||||
filepath = os.path.join(
|
||||
os.path.dirname(__file__),
|
||||
'..', 'app', 'data', 'action_prompts.yaml'
|
||||
)
|
||||
|
||||
if os.path.exists(filepath):
|
||||
count = loader.load_from_yaml(filepath)
|
||||
|
||||
# Should have 10 actions
|
||||
assert count == 10
|
||||
|
||||
# Verify tier distribution
|
||||
free_actions = loader.get_actions_by_tier(UserTier.FREE)
|
||||
premium_actions = loader.get_actions_by_tier(UserTier.PREMIUM)
|
||||
elite_actions = loader.get_actions_by_tier(UserTier.ELITE)
|
||||
|
||||
assert len(free_actions) == 4 # Only free tier
|
||||
assert len(premium_actions) == 7 # Free + premium
|
||||
assert len(elite_actions) == 10 # All
|
||||
|
||||
def test_free_tier_town_actions(self):
|
||||
"""Test free tier actions in town location."""
|
||||
loader = ActionPromptLoader()
|
||||
filepath = os.path.join(
|
||||
os.path.dirname(__file__),
|
||||
'..', 'app', 'data', 'action_prompts.yaml'
|
||||
)
|
||||
|
||||
if os.path.exists(filepath):
|
||||
loader.load_from_yaml(filepath)
|
||||
|
||||
actions = loader.get_available_actions(UserTier.FREE, LocationType.TOWN)
|
||||
|
||||
# Free user in town should have:
|
||||
# - ask_locals (town/tavern)
|
||||
# - search_supplies (any)
|
||||
# - rest_recover (town/tavern/safe_area)
|
||||
prompt_ids = [a.prompt_id for a in actions]
|
||||
assert "ask_locals" in prompt_ids
|
||||
assert "search_supplies" in prompt_ids
|
||||
assert "rest_recover" in prompt_ids
|
||||
|
||||
def test_premium_tier_wilderness_actions(self):
|
||||
"""Test premium tier actions in wilderness location."""
|
||||
loader = ActionPromptLoader()
|
||||
filepath = os.path.join(
|
||||
os.path.dirname(__file__),
|
||||
'..', 'app', 'data', 'action_prompts.yaml'
|
||||
)
|
||||
|
||||
if os.path.exists(filepath):
|
||||
loader.load_from_yaml(filepath)
|
||||
|
||||
actions = loader.get_available_actions(UserTier.PREMIUM, LocationType.WILDERNESS)
|
||||
|
||||
prompt_ids = [a.prompt_id for a in actions]
|
||||
# Should include wilderness actions
|
||||
assert "explore_area" in prompt_ids
|
||||
assert "make_camp" in prompt_ids
|
||||
assert "search_supplies" in prompt_ids
|
||||
571
api/tests/test_ai_tasks.py
Normal file
571
api/tests/test_ai_tasks.py
Normal file
@@ -0,0 +1,571 @@
|
||||
"""
|
||||
Unit tests for AI Task Jobs.
|
||||
|
||||
These tests verify the job enqueueing, status tracking,
|
||||
and result storage functionality.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock, patch, MagicMock
|
||||
import json
|
||||
|
||||
from app.tasks.ai_tasks import (
|
||||
enqueue_ai_task,
|
||||
process_ai_task,
|
||||
get_job_status,
|
||||
get_job_result,
|
||||
JobStatus,
|
||||
TaskType,
|
||||
TaskPriority,
|
||||
_store_job_status,
|
||||
_update_job_status,
|
||||
_store_job_result,
|
||||
)
|
||||
|
||||
|
||||
class TestEnqueueAITask:
|
||||
"""Test AI task enqueueing."""
|
||||
|
||||
@patch('app.tasks.ai_tasks.get_queue')
|
||||
@patch('app.tasks.ai_tasks._store_job_status')
|
||||
def test_enqueue_narrative_task(self, mock_store_status, mock_get_queue):
|
||||
"""Test enqueueing a narrative task."""
|
||||
# Setup mock queue
|
||||
mock_queue = MagicMock()
|
||||
mock_job = MagicMock()
|
||||
mock_job.id = "test_job_123"
|
||||
mock_queue.enqueue.return_value = mock_job
|
||||
mock_get_queue.return_value = mock_queue
|
||||
|
||||
# Enqueue task
|
||||
result = enqueue_ai_task(
|
||||
task_type="narrative",
|
||||
user_id="user_123",
|
||||
context={"action": "explore"},
|
||||
priority="normal",
|
||||
)
|
||||
|
||||
# Verify
|
||||
assert "job_id" in result
|
||||
assert result["status"] == "queued"
|
||||
mock_queue.enqueue.assert_called_once()
|
||||
mock_store_status.assert_called_once()
|
||||
|
||||
@patch('app.tasks.ai_tasks.get_queue')
|
||||
@patch('app.tasks.ai_tasks._store_job_status')
|
||||
def test_enqueue_high_priority_task(self, mock_store_status, mock_get_queue):
|
||||
"""Test high priority task goes to front of queue."""
|
||||
mock_queue = MagicMock()
|
||||
mock_job = MagicMock()
|
||||
mock_queue.enqueue.return_value = mock_job
|
||||
mock_get_queue.return_value = mock_queue
|
||||
|
||||
enqueue_ai_task(
|
||||
task_type="narrative",
|
||||
user_id="user_123",
|
||||
context={},
|
||||
priority="high",
|
||||
)
|
||||
|
||||
# Verify at_front=True for high priority
|
||||
call_kwargs = mock_queue.enqueue.call_args[1]
|
||||
assert call_kwargs["at_front"] is True
|
||||
|
||||
@patch('app.tasks.ai_tasks.get_queue')
|
||||
@patch('app.tasks.ai_tasks._store_job_status')
|
||||
def test_enqueue_with_session_and_character(self, mock_store_status, mock_get_queue):
|
||||
"""Test enqueueing with session and character IDs."""
|
||||
mock_queue = MagicMock()
|
||||
mock_job = MagicMock()
|
||||
mock_queue.enqueue.return_value = mock_job
|
||||
mock_get_queue.return_value = mock_queue
|
||||
|
||||
result = enqueue_ai_task(
|
||||
task_type="combat",
|
||||
user_id="user_123",
|
||||
context={"enemy": "goblin"},
|
||||
session_id="sess_456",
|
||||
character_id="char_789",
|
||||
)
|
||||
|
||||
assert "job_id" in result
|
||||
|
||||
# Verify kwargs passed to enqueue include session and character
|
||||
call_kwargs = mock_queue.enqueue.call_args[1]
|
||||
job_kwargs = call_kwargs["kwargs"]
|
||||
assert job_kwargs["session_id"] == "sess_456"
|
||||
assert job_kwargs["character_id"] == "char_789"
|
||||
|
||||
def test_enqueue_invalid_task_type(self):
|
||||
"""Test enqueueing with invalid task type raises error."""
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
enqueue_ai_task(
|
||||
task_type="invalid_type",
|
||||
user_id="user_123",
|
||||
context={},
|
||||
)
|
||||
|
||||
assert "Invalid task_type" in str(exc_info.value)
|
||||
|
||||
def test_enqueue_invalid_priority(self):
|
||||
"""Test enqueueing with invalid priority raises error."""
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
enqueue_ai_task(
|
||||
task_type="narrative",
|
||||
user_id="user_123",
|
||||
context={},
|
||||
priority="invalid_priority",
|
||||
)
|
||||
|
||||
assert "Invalid priority" in str(exc_info.value)
|
||||
|
||||
|
||||
class TestProcessAITask:
|
||||
"""Test AI task processing with mocked NarrativeGenerator."""
|
||||
|
||||
@pytest.fixture
|
||||
def sample_character(self):
|
||||
"""Sample character data for tests."""
|
||||
return {
|
||||
"name": "Aldric",
|
||||
"level": 3,
|
||||
"player_class": "Fighter",
|
||||
"current_hp": 25,
|
||||
"max_hp": 30,
|
||||
"stats": {"strength": 16, "dexterity": 12},
|
||||
"skills": [],
|
||||
"effects": []
|
||||
}
|
||||
|
||||
@pytest.fixture
|
||||
def sample_game_state(self):
|
||||
"""Sample game state for tests."""
|
||||
return {
|
||||
"current_location": "Tavern",
|
||||
"location_type": "TAVERN",
|
||||
"discovered_locations": [],
|
||||
"active_quests": []
|
||||
}
|
||||
|
||||
@pytest.fixture
|
||||
def mock_narrative_response(self):
|
||||
"""Mock NarrativeResponse for tests."""
|
||||
from app.ai.narrative_generator import NarrativeResponse
|
||||
return NarrativeResponse(
|
||||
narrative="You enter the tavern...",
|
||||
tokens_used=150,
|
||||
model="meta/meta-llama-3-8b-instruct",
|
||||
context_type="story_progression",
|
||||
generation_time=2.5
|
||||
)
|
||||
|
||||
@patch('app.tasks.ai_tasks._get_user_tier')
|
||||
@patch('app.tasks.ai_tasks._update_game_session')
|
||||
@patch('app.tasks.ai_tasks.NarrativeGenerator')
|
||||
@patch('app.tasks.ai_tasks._update_job_status')
|
||||
@patch('app.tasks.ai_tasks._store_job_result')
|
||||
def test_process_narrative_task(
|
||||
self, mock_store_result, mock_update_status, mock_generator_class,
|
||||
mock_update_session, mock_get_tier, sample_character, sample_game_state, mock_narrative_response
|
||||
):
|
||||
"""Test processing a narrative task with NarrativeGenerator."""
|
||||
from app.ai.model_selector import UserTier
|
||||
|
||||
# Setup mocks
|
||||
mock_get_tier.return_value = UserTier.FREE
|
||||
mock_generator = MagicMock()
|
||||
mock_generator.generate_story_response.return_value = mock_narrative_response
|
||||
mock_generator_class.return_value = mock_generator
|
||||
|
||||
# Process task
|
||||
result = process_ai_task(
|
||||
task_type="narrative",
|
||||
user_id="user_123",
|
||||
context={
|
||||
"action": "I explore the tavern",
|
||||
"character": sample_character,
|
||||
"game_state": sample_game_state
|
||||
},
|
||||
job_id="job_123",
|
||||
session_id="sess_456",
|
||||
character_id="char_789"
|
||||
)
|
||||
|
||||
# Verify result structure
|
||||
assert "narrative" in result
|
||||
assert result["narrative"] == "You enter the tavern..."
|
||||
assert result["tokens_used"] == 150
|
||||
assert result["model"] == "meta/meta-llama-3-8b-instruct"
|
||||
|
||||
# Verify generator was called
|
||||
mock_generator.generate_story_response.assert_called_once()
|
||||
|
||||
# Verify session update was called
|
||||
mock_update_session.assert_called_once()
|
||||
|
||||
@patch('app.tasks.ai_tasks._get_user_tier')
|
||||
@patch('app.tasks.ai_tasks.NarrativeGenerator')
|
||||
@patch('app.tasks.ai_tasks._update_job_status')
|
||||
@patch('app.tasks.ai_tasks._store_job_result')
|
||||
def test_process_combat_task(
|
||||
self, mock_store_result, mock_update_status, mock_generator_class,
|
||||
mock_get_tier, sample_character, mock_narrative_response
|
||||
):
|
||||
"""Test processing a combat task."""
|
||||
from app.ai.model_selector import UserTier
|
||||
|
||||
mock_get_tier.return_value = UserTier.BASIC
|
||||
mock_generator = MagicMock()
|
||||
mock_generator.generate_combat_narration.return_value = mock_narrative_response
|
||||
mock_generator_class.return_value = mock_generator
|
||||
|
||||
result = process_ai_task(
|
||||
task_type="combat",
|
||||
user_id="user_123",
|
||||
context={
|
||||
"character": sample_character,
|
||||
"combat_state": {"round_number": 1, "enemies": [], "current_turn": "player"},
|
||||
"action": "swings sword",
|
||||
"action_result": {"hit": True, "damage": 10}
|
||||
},
|
||||
job_id="job_123",
|
||||
)
|
||||
|
||||
assert "combat_narrative" in result
|
||||
mock_generator.generate_combat_narration.assert_called_once()
|
||||
|
||||
@patch('app.tasks.ai_tasks._get_user_tier')
|
||||
@patch('app.tasks.ai_tasks.NarrativeGenerator')
|
||||
@patch('app.tasks.ai_tasks._update_job_status')
|
||||
@patch('app.tasks.ai_tasks._store_job_result')
|
||||
def test_process_quest_selection_task(
|
||||
self, mock_store_result, mock_update_status, mock_generator_class,
|
||||
mock_get_tier, sample_character, sample_game_state
|
||||
):
|
||||
"""Test processing a quest selection task."""
|
||||
from app.ai.model_selector import UserTier
|
||||
|
||||
mock_get_tier.return_value = UserTier.FREE
|
||||
mock_generator = MagicMock()
|
||||
mock_generator.generate_quest_selection.return_value = "goblin_cave"
|
||||
mock_generator_class.return_value = mock_generator
|
||||
|
||||
result = process_ai_task(
|
||||
task_type="quest_selection",
|
||||
user_id="user_123",
|
||||
context={
|
||||
"character": sample_character,
|
||||
"eligible_quests": [{"quest_id": "goblin_cave"}],
|
||||
"game_context": sample_game_state
|
||||
},
|
||||
job_id="job_123",
|
||||
)
|
||||
|
||||
assert result["selected_quest_id"] == "goblin_cave"
|
||||
mock_generator.generate_quest_selection.assert_called_once()
|
||||
|
||||
@patch('app.tasks.ai_tasks._get_user_tier')
|
||||
@patch('app.tasks.ai_tasks.NarrativeGenerator')
|
||||
@patch('app.tasks.ai_tasks._update_job_status')
|
||||
@patch('app.tasks.ai_tasks._store_job_result')
|
||||
def test_process_npc_dialogue_task(
|
||||
self, mock_store_result, mock_update_status, mock_generator_class,
|
||||
mock_get_tier, sample_character, sample_game_state, mock_narrative_response
|
||||
):
|
||||
"""Test processing an NPC dialogue task."""
|
||||
from app.ai.model_selector import UserTier
|
||||
|
||||
mock_get_tier.return_value = UserTier.PREMIUM
|
||||
mock_generator = MagicMock()
|
||||
mock_generator.generate_npc_dialogue.return_value = mock_narrative_response
|
||||
mock_generator_class.return_value = mock_generator
|
||||
|
||||
result = process_ai_task(
|
||||
task_type="npc_dialogue",
|
||||
user_id="user_123",
|
||||
context={
|
||||
"character": sample_character,
|
||||
"npc": {"name": "Innkeeper", "role": "Barkeeper"},
|
||||
"conversation_topic": "What's the news?",
|
||||
"game_state": sample_game_state
|
||||
},
|
||||
job_id="job_123",
|
||||
)
|
||||
|
||||
assert "dialogue" in result
|
||||
mock_generator.generate_npc_dialogue.assert_called_once()
|
||||
|
||||
@patch('app.tasks.ai_tasks._update_job_status')
|
||||
def test_process_task_failure_invalid_type(self, mock_update_status):
|
||||
"""Test task processing failure with invalid task type."""
|
||||
with pytest.raises(ValueError):
|
||||
process_ai_task(
|
||||
task_type="invalid",
|
||||
user_id="user_123",
|
||||
context={},
|
||||
job_id="job_123",
|
||||
)
|
||||
|
||||
# Verify status was updated to PROCESSING then FAILED
|
||||
calls = mock_update_status.call_args_list
|
||||
assert any(call[0][1] == JobStatus.FAILED for call in calls)
|
||||
|
||||
@patch('app.tasks.ai_tasks._get_user_tier')
|
||||
@patch('app.tasks.ai_tasks._update_job_status')
|
||||
def test_process_task_missing_context_fields(self, mock_update_status, mock_get_tier):
|
||||
"""Test task processing failure with missing required context fields."""
|
||||
from app.ai.model_selector import UserTier
|
||||
mock_get_tier.return_value = UserTier.FREE
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
process_ai_task(
|
||||
task_type="narrative",
|
||||
user_id="user_123",
|
||||
context={"action": "test"}, # Missing character and game_state
|
||||
job_id="job_123",
|
||||
)
|
||||
|
||||
assert "Missing required context field" in str(exc_info.value)
|
||||
|
||||
|
||||
class TestJobStatus:
|
||||
"""Test job status retrieval."""
|
||||
|
||||
@patch('app.tasks.ai_tasks.RedisService')
|
||||
def test_get_job_status_from_cache(self, mock_redis_class):
|
||||
"""Test getting job status from Redis cache."""
|
||||
mock_redis = MagicMock()
|
||||
mock_redis.get_json.return_value = {
|
||||
"job_id": "job_123",
|
||||
"status": "completed",
|
||||
"created_at": "2025-11-21T10:00:00Z",
|
||||
}
|
||||
mock_redis_class.return_value = mock_redis
|
||||
|
||||
result = get_job_status("job_123")
|
||||
|
||||
assert result["job_id"] == "job_123"
|
||||
assert result["status"] == "completed"
|
||||
|
||||
@patch('app.tasks.ai_tasks.RedisService')
|
||||
@patch('app.tasks.ai_tasks.get_redis_connection')
|
||||
@patch('app.tasks.ai_tasks.Job')
|
||||
def test_get_job_status_from_rq(self, mock_job_class, mock_get_conn, mock_redis_class):
|
||||
"""Test getting job status from RQ when not in cache."""
|
||||
# No cached status
|
||||
mock_redis = MagicMock()
|
||||
mock_redis.get_json.return_value = None
|
||||
mock_redis_class.return_value = mock_redis
|
||||
|
||||
# RQ job
|
||||
mock_job = MagicMock()
|
||||
mock_job.is_finished = True
|
||||
mock_job.is_failed = False
|
||||
mock_job.is_started = False
|
||||
mock_job.created_at = None
|
||||
mock_job.started_at = None
|
||||
mock_job.ended_at = None
|
||||
mock_job_class.fetch.return_value = mock_job
|
||||
|
||||
result = get_job_status("job_123")
|
||||
|
||||
assert result["status"] == "completed"
|
||||
|
||||
@patch('app.tasks.ai_tasks.RedisService')
|
||||
@patch('app.tasks.ai_tasks.get_redis_connection')
|
||||
@patch('app.tasks.ai_tasks.Job')
|
||||
def test_get_job_status_not_found(self, mock_job_class, mock_get_conn, mock_redis_class):
|
||||
"""Test getting status of non-existent job."""
|
||||
mock_redis = MagicMock()
|
||||
mock_redis.get_json.return_value = None
|
||||
mock_redis_class.return_value = mock_redis
|
||||
|
||||
mock_job_class.fetch.side_effect = Exception("Job not found")
|
||||
|
||||
result = get_job_status("nonexistent_job")
|
||||
|
||||
assert result["status"] == "unknown"
|
||||
assert "error" in result
|
||||
|
||||
|
||||
class TestJobResult:
|
||||
"""Test job result retrieval."""
|
||||
|
||||
@patch('app.tasks.ai_tasks.RedisService')
|
||||
def test_get_job_result_from_cache(self, mock_redis_class):
|
||||
"""Test getting job result from Redis cache."""
|
||||
expected_result = {
|
||||
"narrative": "You enter the tavern...",
|
||||
"tokens_used": 450,
|
||||
}
|
||||
|
||||
mock_redis = MagicMock()
|
||||
mock_redis.get_json.return_value = expected_result
|
||||
mock_redis_class.return_value = mock_redis
|
||||
|
||||
result = get_job_result("job_123")
|
||||
|
||||
assert result == expected_result
|
||||
|
||||
@patch('app.tasks.ai_tasks.RedisService')
|
||||
@patch('app.tasks.ai_tasks.get_redis_connection')
|
||||
@patch('app.tasks.ai_tasks.Job')
|
||||
def test_get_job_result_from_rq(self, mock_job_class, mock_get_conn, mock_redis_class):
|
||||
"""Test getting job result from RQ when not in cache."""
|
||||
expected_result = {"narrative": "You find a sword..."}
|
||||
|
||||
mock_redis = MagicMock()
|
||||
mock_redis.get_json.return_value = None
|
||||
mock_redis_class.return_value = mock_redis
|
||||
|
||||
mock_job = MagicMock()
|
||||
mock_job.is_finished = True
|
||||
mock_job.result = expected_result
|
||||
mock_job_class.fetch.return_value = mock_job
|
||||
|
||||
result = get_job_result("job_123")
|
||||
|
||||
assert result == expected_result
|
||||
|
||||
@patch('app.tasks.ai_tasks.RedisService')
|
||||
@patch('app.tasks.ai_tasks.get_redis_connection')
|
||||
@patch('app.tasks.ai_tasks.Job')
|
||||
def test_get_job_result_not_found(self, mock_job_class, mock_get_conn, mock_redis_class):
|
||||
"""Test getting result of non-existent job."""
|
||||
mock_redis = MagicMock()
|
||||
mock_redis.get_json.return_value = None
|
||||
mock_redis_class.return_value = mock_redis
|
||||
|
||||
mock_job_class.fetch.side_effect = Exception("Job not found")
|
||||
|
||||
result = get_job_result("nonexistent_job")
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestStatusStorage:
|
||||
"""Test job status storage functions."""
|
||||
|
||||
@patch('app.tasks.ai_tasks.RedisService')
|
||||
def test_store_job_status(self, mock_redis_class):
|
||||
"""Test storing initial job status."""
|
||||
mock_redis = MagicMock()
|
||||
mock_redis_class.return_value = mock_redis
|
||||
|
||||
_store_job_status(
|
||||
job_id="job_123",
|
||||
status=JobStatus.QUEUED,
|
||||
task_type="narrative",
|
||||
user_id="user_456",
|
||||
)
|
||||
|
||||
mock_redis.set_json.assert_called_once()
|
||||
call_args = mock_redis.set_json.call_args
|
||||
stored_data = call_args[0][1]
|
||||
|
||||
assert stored_data["job_id"] == "job_123"
|
||||
assert stored_data["status"] == "queued"
|
||||
assert stored_data["task_type"] == "narrative"
|
||||
assert stored_data["user_id"] == "user_456"
|
||||
|
||||
@patch('app.tasks.ai_tasks.RedisService')
|
||||
def test_update_job_status_to_processing(self, mock_redis_class):
|
||||
"""Test updating job status to processing."""
|
||||
mock_redis = MagicMock()
|
||||
mock_redis.get_json.return_value = {
|
||||
"job_id": "job_123",
|
||||
"status": "queued",
|
||||
"created_at": "2025-11-21T10:00:00Z",
|
||||
}
|
||||
mock_redis_class.return_value = mock_redis
|
||||
|
||||
_update_job_status("job_123", JobStatus.PROCESSING)
|
||||
|
||||
mock_redis.set_json.assert_called_once()
|
||||
call_args = mock_redis.set_json.call_args
|
||||
updated_data = call_args[0][1]
|
||||
|
||||
assert updated_data["status"] == "processing"
|
||||
assert updated_data["started_at"] is not None
|
||||
|
||||
@patch('app.tasks.ai_tasks.RedisService')
|
||||
def test_update_job_status_to_completed(self, mock_redis_class):
|
||||
"""Test updating job status to completed with result."""
|
||||
mock_redis = MagicMock()
|
||||
mock_redis.get_json.return_value = {
|
||||
"job_id": "job_123",
|
||||
"status": "processing",
|
||||
}
|
||||
mock_redis_class.return_value = mock_redis
|
||||
|
||||
result_data = {"narrative": "Test result"}
|
||||
_update_job_status("job_123", JobStatus.COMPLETED, result=result_data)
|
||||
|
||||
call_args = mock_redis.set_json.call_args
|
||||
updated_data = call_args[0][1]
|
||||
|
||||
assert updated_data["status"] == "completed"
|
||||
assert updated_data["completed_at"] is not None
|
||||
assert updated_data["result"] == result_data
|
||||
|
||||
@patch('app.tasks.ai_tasks.RedisService')
|
||||
def test_update_job_status_to_failed(self, mock_redis_class):
|
||||
"""Test updating job status to failed with error."""
|
||||
mock_redis = MagicMock()
|
||||
mock_redis.get_json.return_value = {
|
||||
"job_id": "job_123",
|
||||
"status": "processing",
|
||||
}
|
||||
mock_redis_class.return_value = mock_redis
|
||||
|
||||
_update_job_status("job_123", JobStatus.FAILED, error="Something went wrong")
|
||||
|
||||
call_args = mock_redis.set_json.call_args
|
||||
updated_data = call_args[0][1]
|
||||
|
||||
assert updated_data["status"] == "failed"
|
||||
assert updated_data["error"] == "Something went wrong"
|
||||
|
||||
@patch('app.tasks.ai_tasks.RedisService')
|
||||
def test_store_job_result(self, mock_redis_class):
|
||||
"""Test storing job result."""
|
||||
mock_redis = MagicMock()
|
||||
mock_redis_class.return_value = mock_redis
|
||||
|
||||
result_data = {
|
||||
"narrative": "You discover a hidden passage...",
|
||||
"tokens_used": 500,
|
||||
}
|
||||
|
||||
_store_job_result("job_123", result_data)
|
||||
|
||||
mock_redis.set_json.assert_called_once()
|
||||
call_args = mock_redis.set_json.call_args
|
||||
|
||||
assert call_args[0][0] == "ai_job_result:job_123"
|
||||
assert call_args[0][1] == result_data
|
||||
assert call_args[1]["ttl"] == 3600 # 1 hour
|
||||
|
||||
|
||||
class TestTaskTypes:
|
||||
"""Test task type and priority enums."""
|
||||
|
||||
def test_task_type_values(self):
|
||||
"""Test all task types are defined."""
|
||||
assert TaskType.NARRATIVE.value == "narrative"
|
||||
assert TaskType.COMBAT.value == "combat"
|
||||
assert TaskType.QUEST_SELECTION.value == "quest_selection"
|
||||
assert TaskType.NPC_DIALOGUE.value == "npc_dialogue"
|
||||
|
||||
def test_task_priority_values(self):
|
||||
"""Test all priorities are defined."""
|
||||
assert TaskPriority.LOW.value == "low"
|
||||
assert TaskPriority.NORMAL.value == "normal"
|
||||
assert TaskPriority.HIGH.value == "high"
|
||||
|
||||
def test_job_status_values(self):
|
||||
"""Test all job statuses are defined."""
|
||||
assert JobStatus.QUEUED.value == "queued"
|
||||
assert JobStatus.PROCESSING.value == "processing"
|
||||
assert JobStatus.COMPLETED.value == "completed"
|
||||
assert JobStatus.FAILED.value == "failed"
|
||||
579
api/tests/test_api_characters_integration.py
Normal file
579
api/tests/test_api_characters_integration.py
Normal file
@@ -0,0 +1,579 @@
|
||||
"""
|
||||
Integration tests for Character API endpoints.
|
||||
|
||||
These tests verify the complete character management API flow including:
|
||||
- List characters
|
||||
- Get character details
|
||||
- Create character (with tier limits)
|
||||
- Delete character
|
||||
- Unlock skills
|
||||
- Respec skills
|
||||
- Get classes and origins (reference data)
|
||||
|
||||
Tests use Flask test client with mocked authentication and database layers.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import json
|
||||
from unittest.mock import Mock, patch, MagicMock
|
||||
from datetime import datetime
|
||||
|
||||
from app.models.character import Character
|
||||
from app.models.stats import Stats
|
||||
from app.models.skills import PlayerClass, SkillTree, SkillNode
|
||||
from app.models.origins import Origin, StartingLocation, StartingBonus
|
||||
from app.services.character_service import (
|
||||
CharacterLimitExceeded,
|
||||
CharacterNotFound,
|
||||
SkillUnlockError,
|
||||
InsufficientGold
|
||||
)
|
||||
from app.services.database_service import DatabaseDocument
|
||||
|
||||
|
||||
class TestCharacterAPIIntegration:
|
||||
"""Integration tests for Character API endpoints."""
|
||||
|
||||
@pytest.fixture
|
||||
def app(self):
|
||||
"""Create Flask application for testing."""
|
||||
# Import here to avoid circular imports
|
||||
from app import create_app
|
||||
app = create_app()
|
||||
app.config['TESTING'] = True
|
||||
return app
|
||||
|
||||
@pytest.fixture
|
||||
def client(self, app):
|
||||
"""Create test client."""
|
||||
return app.test_client()
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_auth_decorator(self):
|
||||
"""Mock the require_auth decorator to bypass authentication."""
|
||||
def pass_through_decorator(func):
|
||||
"""Decorator that does nothing - passes through to the function."""
|
||||
return func
|
||||
|
||||
with patch('app.api.characters.require_auth', side_effect=pass_through_decorator):
|
||||
yield
|
||||
|
||||
@pytest.fixture
|
||||
def mock_user(self):
|
||||
"""Create a mock authenticated user."""
|
||||
user = Mock()
|
||||
user.id = "test_user_123"
|
||||
user.email = "test@example.com"
|
||||
user.name = "Test User"
|
||||
user.tier = "free"
|
||||
user.email_verified = True
|
||||
return user
|
||||
|
||||
@pytest.fixture
|
||||
def sample_class(self):
|
||||
"""Create a sample player class."""
|
||||
base_stats = Stats(strength=14, dexterity=10, constitution=14,
|
||||
intelligence=8, wisdom=10, charisma=9)
|
||||
|
||||
skill_nodes = [
|
||||
SkillNode(
|
||||
skill_id="shield_wall",
|
||||
name="Shield Wall",
|
||||
description="Increase armor by 5",
|
||||
tier=1,
|
||||
prerequisites=[],
|
||||
effects={"armor": 5}
|
||||
),
|
||||
SkillNode(
|
||||
skill_id="toughness",
|
||||
name="Toughness",
|
||||
description="Increase HP by 10",
|
||||
tier=2,
|
||||
prerequisites=["shield_wall"],
|
||||
effects={"hit_points": 10}
|
||||
)
|
||||
]
|
||||
|
||||
skill_tree = SkillTree(
|
||||
tree_id="shield_bearer",
|
||||
name="Shield Bearer",
|
||||
description="Defensive techniques",
|
||||
nodes=skill_nodes
|
||||
)
|
||||
|
||||
return PlayerClass(
|
||||
class_id="vanguard",
|
||||
name="Vanguard",
|
||||
description="Armored warrior",
|
||||
base_stats=base_stats,
|
||||
skill_trees=[skill_tree],
|
||||
starting_equipment=["Rusty Sword", "Tattered Cloth Armor"],
|
||||
starting_abilities=[]
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def sample_origin(self):
|
||||
"""Create a sample origin."""
|
||||
starting_location = StartingLocation(
|
||||
id="forgotten_crypt",
|
||||
name="The Forgotten Crypt",
|
||||
region="The Deadlands",
|
||||
description="A crumbling stone tomb beneath a dead forest"
|
||||
)
|
||||
|
||||
starting_bonus = StartingBonus(
|
||||
type="stat",
|
||||
value={"constitution": 1}
|
||||
)
|
||||
|
||||
return Origin(
|
||||
id="soul_revenant",
|
||||
name="The Soul Revenant",
|
||||
description="You died. That much you remember...",
|
||||
starting_location=starting_location,
|
||||
narrative_hooks=["Who brought you back?"],
|
||||
starting_bonus=starting_bonus
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def sample_character(self, sample_class, sample_origin):
|
||||
"""Create a sample character."""
|
||||
return Character(
|
||||
character_id="char_123",
|
||||
user_id="test_user_123",
|
||||
name="Thorin Ironforge",
|
||||
player_class=sample_class,
|
||||
origin=sample_origin,
|
||||
level=5,
|
||||
experience=2400,
|
||||
base_stats=Stats(strength=16, dexterity=10, constitution=14,
|
||||
intelligence=8, wisdom=12, charisma=10),
|
||||
unlocked_skills=["shield_wall"],
|
||||
inventory=[],
|
||||
equipped={},
|
||||
gold=500,
|
||||
active_quests=[],
|
||||
discovered_locations=["forgotten_crypt"],
|
||||
current_location="forgotten_crypt"
|
||||
)
|
||||
|
||||
# ===== LIST CHARACTERS =====
|
||||
|
||||
@patch('app.api.characters.get_current_user')
|
||||
@patch('app.api.characters.get_character_service')
|
||||
def test_list_characters_success(self, mock_get_service, mock_get_user,
|
||||
client, mock_user, sample_character):
|
||||
"""Test listing user's characters."""
|
||||
# Setup
|
||||
mock_get_user.return_value = mock_user
|
||||
mock_service = Mock()
|
||||
mock_service.get_user_characters.return_value = [sample_character]
|
||||
mock_get_service.return_value = mock_service
|
||||
|
||||
# Execute
|
||||
response = client.get('/api/v1/characters')
|
||||
|
||||
# Verify
|
||||
assert response.status_code == 200
|
||||
data = json.loads(response.data)
|
||||
assert data['status'] == 200
|
||||
assert len(data['result']['characters']) == 1
|
||||
assert data['result']['characters'][0]['name'] == "Thorin Ironforge"
|
||||
assert data['result']['characters'][0]['class'] == "vanguard"
|
||||
assert data['result']['count'] == 1
|
||||
assert data['result']['tier'] == "free"
|
||||
assert data['result']['limit'] == 1
|
||||
|
||||
@patch('app.api.characters.get_current_user')
|
||||
def test_list_characters_unauthorized(self, mock_get_user, client):
|
||||
"""Test listing characters without authentication."""
|
||||
# Setup - simulate no user logged in
|
||||
mock_get_user.return_value = None
|
||||
|
||||
# Execute
|
||||
response = client.get('/api/v1/characters')
|
||||
|
||||
# Verify
|
||||
assert response.status_code == 401
|
||||
|
||||
# ===== GET CHARACTER =====
|
||||
|
||||
@patch('app.api.characters.get_current_user')
|
||||
@patch('app.api.characters.get_character_service')
|
||||
def test_get_character_success(self, mock_get_service, mock_get_user,
|
||||
client, mock_user, sample_character):
|
||||
"""Test getting a single character."""
|
||||
# Setup
|
||||
mock_get_user.return_value = mock_user
|
||||
mock_service = Mock()
|
||||
mock_service.get_character.return_value = sample_character
|
||||
mock_get_service.return_value = mock_service
|
||||
|
||||
# Execute
|
||||
response = client.get('/api/v1/characters/char_123')
|
||||
|
||||
# Verify
|
||||
assert response.status_code == 200
|
||||
data = json.loads(response.data)
|
||||
assert data['status'] == 200
|
||||
assert data['result']['name'] == "Thorin Ironforge"
|
||||
assert data['result']['level'] == 5
|
||||
assert data['result']['gold'] == 500
|
||||
|
||||
@patch('app.api.characters.get_current_user')
|
||||
@patch('app.api.characters.get_character_service')
|
||||
def test_get_character_not_found(self, mock_get_service, mock_get_user,
|
||||
client, mock_user):
|
||||
"""Test getting a non-existent character."""
|
||||
# Setup
|
||||
mock_get_user.return_value = mock_user
|
||||
mock_service = Mock()
|
||||
mock_service.get_character.side_effect = CharacterNotFound("Character not found")
|
||||
mock_get_service.return_value = mock_service
|
||||
|
||||
# Execute
|
||||
response = client.get('/api/v1/characters/char_999')
|
||||
|
||||
# Verify
|
||||
assert response.status_code == 404
|
||||
data = json.loads(response.data)
|
||||
assert 'error' in data
|
||||
|
||||
# ===== CREATE CHARACTER =====
|
||||
|
||||
@patch('app.api.characters.get_current_user')
|
||||
@patch('app.api.characters.get_character_service')
|
||||
def test_create_character_success(self, mock_get_service, mock_get_user,
|
||||
client, mock_user, sample_character):
|
||||
"""Test creating a new character."""
|
||||
# Setup
|
||||
mock_get_user.return_value = mock_user
|
||||
mock_service = Mock()
|
||||
mock_service.create_character.return_value = sample_character
|
||||
mock_get_service.return_value = mock_service
|
||||
|
||||
# Execute
|
||||
response = client.post('/api/v1/characters',
|
||||
json={
|
||||
'name': 'Thorin Ironforge',
|
||||
'class_id': 'vanguard',
|
||||
'origin_id': 'soul_revenant'
|
||||
})
|
||||
|
||||
# Verify
|
||||
assert response.status_code == 201
|
||||
data = json.loads(response.data)
|
||||
assert data['status'] == 201
|
||||
assert data['result']['name'] == "Thorin Ironforge"
|
||||
assert data['result']['class'] == "vanguard"
|
||||
assert data['result']['origin'] == "soul_revenant"
|
||||
|
||||
@patch('app.api.characters.get_current_user')
|
||||
@patch('app.api.characters.get_character_service')
|
||||
def test_create_character_limit_exceeded(self, mock_get_service, mock_get_user,
|
||||
client, mock_user):
|
||||
"""Test creating a character when tier limit is reached."""
|
||||
# Setup
|
||||
mock_get_user.return_value = mock_user
|
||||
mock_service = Mock()
|
||||
mock_service.create_character.side_effect = CharacterLimitExceeded(
|
||||
"Character limit reached for free tier (1/1)"
|
||||
)
|
||||
mock_get_service.return_value = mock_service
|
||||
|
||||
# Execute
|
||||
response = client.post('/api/v1/characters',
|
||||
json={
|
||||
'name': 'Thorin Ironforge',
|
||||
'class_id': 'vanguard',
|
||||
'origin_id': 'soul_revenant'
|
||||
})
|
||||
|
||||
# Verify
|
||||
assert response.status_code == 400
|
||||
data = json.loads(response.data)
|
||||
assert 'error' in data
|
||||
assert data['error']['code'] == 'CHARACTER_LIMIT_EXCEEDED'
|
||||
|
||||
def test_create_character_validation_errors(self, client, mock_user):
|
||||
"""Test character creation with invalid input."""
|
||||
with patch('app.api.characters.get_current_user', return_value=mock_user):
|
||||
# Test missing name
|
||||
response = client.post('/api/v1/characters',
|
||||
json={
|
||||
'class_id': 'vanguard',
|
||||
'origin_id': 'soul_revenant'
|
||||
})
|
||||
assert response.status_code == 400
|
||||
|
||||
# Test invalid class_id
|
||||
response = client.post('/api/v1/characters',
|
||||
json={
|
||||
'name': 'Test',
|
||||
'class_id': 'invalid_class',
|
||||
'origin_id': 'soul_revenant'
|
||||
})
|
||||
assert response.status_code == 400
|
||||
|
||||
# Test invalid origin_id
|
||||
response = client.post('/api/v1/characters',
|
||||
json={
|
||||
'name': 'Test',
|
||||
'class_id': 'vanguard',
|
||||
'origin_id': 'invalid_origin'
|
||||
})
|
||||
assert response.status_code == 400
|
||||
|
||||
# Test name too short
|
||||
response = client.post('/api/v1/characters',
|
||||
json={
|
||||
'name': 'T',
|
||||
'class_id': 'vanguard',
|
||||
'origin_id': 'soul_revenant'
|
||||
})
|
||||
assert response.status_code == 400
|
||||
|
||||
# ===== DELETE CHARACTER =====
|
||||
|
||||
@patch('app.api.characters.get_current_user')
|
||||
@patch('app.api.characters.get_character_service')
|
||||
def test_delete_character_success(self, mock_get_service, mock_get_user,
|
||||
client, mock_user):
|
||||
"""Test deleting a character."""
|
||||
# Setup
|
||||
mock_get_user.return_value = mock_user
|
||||
mock_service = Mock()
|
||||
mock_service.delete_character.return_value = True
|
||||
mock_get_service.return_value = mock_service
|
||||
|
||||
# Execute
|
||||
response = client.delete('/api/v1/characters/char_123')
|
||||
|
||||
# Verify
|
||||
assert response.status_code == 200
|
||||
data = json.loads(response.data)
|
||||
assert data['status'] == 200
|
||||
assert 'deleted successfully' in data['result']['message']
|
||||
|
||||
@patch('app.api.characters.get_current_user')
|
||||
@patch('app.api.characters.get_character_service')
|
||||
def test_delete_character_not_found(self, mock_get_service, mock_get_user,
|
||||
client, mock_user):
|
||||
"""Test deleting a non-existent character."""
|
||||
# Setup
|
||||
mock_get_user.return_value = mock_user
|
||||
mock_service = Mock()
|
||||
mock_service.delete_character.side_effect = CharacterNotFound("Character not found")
|
||||
mock_get_service.return_value = mock_service
|
||||
|
||||
# Execute
|
||||
response = client.delete('/api/v1/characters/char_999')
|
||||
|
||||
# Verify
|
||||
assert response.status_code == 404
|
||||
|
||||
# ===== UNLOCK SKILL =====
|
||||
|
||||
@patch('app.api.characters.get_current_user')
|
||||
@patch('app.api.characters.get_character_service')
|
||||
def test_unlock_skill_success(self, mock_get_service, mock_get_user,
|
||||
client, mock_user, sample_character):
|
||||
"""Test unlocking a skill."""
|
||||
# Setup
|
||||
mock_get_user.return_value = mock_user
|
||||
sample_character.unlocked_skills = ["shield_wall", "toughness"]
|
||||
mock_service = Mock()
|
||||
mock_service.unlock_skill.return_value = sample_character
|
||||
mock_get_service.return_value = mock_service
|
||||
|
||||
# Execute
|
||||
response = client.post('/api/v1/characters/char_123/skills/unlock',
|
||||
json={'skill_id': 'toughness'})
|
||||
|
||||
# Verify
|
||||
assert response.status_code == 200
|
||||
data = json.loads(response.data)
|
||||
assert data['status'] == 200
|
||||
assert data['result']['skill_id'] == 'toughness'
|
||||
assert 'toughness' in data['result']['unlocked_skills']
|
||||
|
||||
@patch('app.api.characters.get_current_user')
|
||||
@patch('app.api.characters.get_character_service')
|
||||
def test_unlock_skill_prerequisites_not_met(self, mock_get_service, mock_get_user,
|
||||
client, mock_user):
|
||||
"""Test unlocking a skill without meeting prerequisites."""
|
||||
# Setup
|
||||
mock_get_user.return_value = mock_user
|
||||
mock_service = Mock()
|
||||
mock_service.unlock_skill.side_effect = SkillUnlockError(
|
||||
"Prerequisite not met: shield_wall required for toughness"
|
||||
)
|
||||
mock_get_service.return_value = mock_service
|
||||
|
||||
# Execute
|
||||
response = client.post('/api/v1/characters/char_123/skills/unlock',
|
||||
json={'skill_id': 'toughness'})
|
||||
|
||||
# Verify
|
||||
assert response.status_code == 400
|
||||
data = json.loads(response.data)
|
||||
assert 'error' in data
|
||||
assert data['error']['code'] == 'SKILL_UNLOCK_ERROR'
|
||||
|
||||
@patch('app.api.characters.get_current_user')
|
||||
@patch('app.api.characters.get_character_service')
|
||||
def test_unlock_skill_no_points(self, mock_get_service, mock_get_user,
|
||||
client, mock_user):
|
||||
"""Test unlocking a skill without available skill points."""
|
||||
# Setup
|
||||
mock_get_user.return_value = mock_user
|
||||
mock_service = Mock()
|
||||
mock_service.unlock_skill.side_effect = SkillUnlockError(
|
||||
"No skill points available (Level 1, 1 skills unlocked)"
|
||||
)
|
||||
mock_get_service.return_value = mock_service
|
||||
|
||||
# Execute
|
||||
response = client.post('/api/v1/characters/char_123/skills/unlock',
|
||||
json={'skill_id': 'shield_wall'})
|
||||
|
||||
# Verify
|
||||
assert response.status_code == 400
|
||||
data = json.loads(response.data)
|
||||
assert 'error' in data
|
||||
|
||||
# ===== RESPEC SKILLS =====
|
||||
|
||||
@patch('app.api.characters.get_current_user')
|
||||
@patch('app.api.characters.get_character_service')
|
||||
def test_respec_skills_success(self, mock_get_service, mock_get_user,
|
||||
client, mock_user, sample_character):
|
||||
"""Test respeccing character skills."""
|
||||
# Setup
|
||||
mock_get_user.return_value = mock_user
|
||||
|
||||
# Get character returns current state
|
||||
char_before = sample_character
|
||||
char_before.unlocked_skills = ["shield_wall"]
|
||||
char_before.gold = 500
|
||||
|
||||
# After respec, skills cleared and gold reduced
|
||||
char_after = sample_character
|
||||
char_after.unlocked_skills = []
|
||||
char_after.gold = 0 # 500 - (5 * 100)
|
||||
|
||||
mock_service = Mock()
|
||||
mock_service.get_character.return_value = char_before
|
||||
mock_service.respec_skills.return_value = char_after
|
||||
mock_get_service.return_value = mock_service
|
||||
|
||||
# Execute
|
||||
response = client.post('/api/v1/characters/char_123/skills/respec')
|
||||
|
||||
# Verify
|
||||
assert response.status_code == 200
|
||||
data = json.loads(response.data)
|
||||
assert data['status'] == 200
|
||||
assert data['result']['cost'] == 500 # level 5 * 100
|
||||
assert data['result']['remaining_gold'] == 0
|
||||
assert data['result']['available_points'] == 5
|
||||
|
||||
@patch('app.api.characters.get_current_user')
|
||||
@patch('app.api.characters.get_character_service')
|
||||
def test_respec_skills_insufficient_gold(self, mock_get_service, mock_get_user,
|
||||
client, mock_user, sample_character):
|
||||
"""Test respeccing without enough gold."""
|
||||
# Setup
|
||||
mock_get_user.return_value = mock_user
|
||||
sample_character.gold = 100 # Not enough for level 5 respec (needs 500)
|
||||
|
||||
mock_service = Mock()
|
||||
mock_service.get_character.return_value = sample_character
|
||||
mock_service.respec_skills.side_effect = InsufficientGold(
|
||||
"Insufficient gold for respec. Cost: 500, Available: 100"
|
||||
)
|
||||
mock_get_service.return_value = mock_service
|
||||
|
||||
# Execute
|
||||
response = client.post('/api/v1/characters/char_123/skills/respec')
|
||||
|
||||
# Verify
|
||||
assert response.status_code == 400
|
||||
data = json.loads(response.data)
|
||||
assert 'error' in data
|
||||
assert data['error']['code'] == 'INSUFFICIENT_GOLD'
|
||||
|
||||
# ===== CLASSES ENDPOINTS (REFERENCE DATA) =====
|
||||
|
||||
@patch('app.api.characters.get_class_loader')
|
||||
def test_list_classes(self, mock_get_loader, client, sample_class):
|
||||
"""Test listing all character classes."""
|
||||
# Setup
|
||||
mock_loader = Mock()
|
||||
mock_loader.get_all_class_ids.return_value = ['vanguard', 'arcanist']
|
||||
mock_loader.load_class.return_value = sample_class
|
||||
mock_get_loader.return_value = mock_loader
|
||||
|
||||
# Execute
|
||||
response = client.get('/api/v1/classes')
|
||||
|
||||
# Verify
|
||||
assert response.status_code == 200
|
||||
data = json.loads(response.data)
|
||||
assert data['status'] == 200
|
||||
assert len(data['result']['classes']) == 2
|
||||
assert data['result']['count'] == 2
|
||||
|
||||
@patch('app.api.characters.get_class_loader')
|
||||
def test_get_class_details(self, mock_get_loader, client, sample_class):
|
||||
"""Test getting details of a specific class."""
|
||||
# Setup
|
||||
mock_loader = Mock()
|
||||
mock_loader.load_class.return_value = sample_class
|
||||
mock_get_loader.return_value = mock_loader
|
||||
|
||||
# Execute
|
||||
response = client.get('/api/v1/classes/vanguard')
|
||||
|
||||
# Verify
|
||||
assert response.status_code == 200
|
||||
data = json.loads(response.data)
|
||||
assert data['status'] == 200
|
||||
assert data['result']['class_id'] == 'vanguard'
|
||||
assert data['result']['name'] == 'Vanguard'
|
||||
assert len(data['result']['skill_trees']) == 1
|
||||
|
||||
@patch('app.api.characters.get_class_loader')
|
||||
def test_get_class_not_found(self, mock_get_loader, client):
|
||||
"""Test getting a non-existent class."""
|
||||
# Setup
|
||||
mock_loader = Mock()
|
||||
mock_loader.load_class.return_value = None
|
||||
mock_get_loader.return_value = mock_loader
|
||||
|
||||
# Execute
|
||||
response = client.get('/api/v1/classes/invalid_class')
|
||||
|
||||
# Verify
|
||||
assert response.status_code == 404
|
||||
|
||||
# ===== ORIGINS ENDPOINTS (REFERENCE DATA) =====
|
||||
|
||||
@patch('app.api.characters.get_origin_service')
|
||||
def test_list_origins(self, mock_get_service, client, sample_origin):
|
||||
"""Test listing all character origins."""
|
||||
# Setup
|
||||
mock_service = Mock()
|
||||
mock_service.get_all_origin_ids.return_value = ['soul_revenant', 'memory_thief']
|
||||
mock_service.load_origin.return_value = sample_origin
|
||||
mock_get_service.return_value = mock_service
|
||||
|
||||
# Execute
|
||||
response = client.get('/api/v1/origins')
|
||||
|
||||
# Verify
|
||||
assert response.status_code == 200
|
||||
data = json.loads(response.data)
|
||||
assert data['status'] == 200
|
||||
assert len(data['result']['origins']) == 2
|
||||
assert data['result']['count'] == 2
|
||||
454
api/tests/test_character.py
Normal file
454
api/tests/test_character.py
Normal file
@@ -0,0 +1,454 @@
|
||||
"""
|
||||
Unit tests for Character dataclass.
|
||||
|
||||
Tests the critical get_effective_stats() method which combines all stat modifiers,
|
||||
as well as inventory, equipment, experience, and serialization.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from app.models.character import Character
|
||||
from app.models.stats import Stats
|
||||
from app.models.items import Item
|
||||
from app.models.effects import Effect
|
||||
from app.models.skills import PlayerClass, SkillTree, SkillNode
|
||||
from app.models.origins import Origin, StartingLocation, StartingBonus
|
||||
from app.models.enums import ItemType, EffectType, StatType, DamageType
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def basic_player_class():
|
||||
"""Create a basic player class for testing."""
|
||||
base_stats = Stats(strength=12, dexterity=10, constitution=14, intelligence=8, wisdom=10, charisma=11)
|
||||
|
||||
# Create a simple skill tree
|
||||
skill_tree = SkillTree(
|
||||
tree_id="warrior_offense",
|
||||
name="Warrior Offense",
|
||||
description="Offensive combat skills",
|
||||
nodes=[
|
||||
SkillNode(
|
||||
skill_id="power_strike",
|
||||
name="Power Strike",
|
||||
description="+5 Strength",
|
||||
tier=1,
|
||||
effects={"strength": 5},
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
return PlayerClass(
|
||||
class_id="warrior",
|
||||
name="Warrior",
|
||||
description="Strong melee fighter",
|
||||
base_stats=base_stats,
|
||||
skill_trees=[skill_tree],
|
||||
starting_equipment=["basic_sword"],
|
||||
starting_abilities=["basic_attack"],
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def basic_origin():
|
||||
"""Create a basic origin for testing."""
|
||||
starting_location = StartingLocation(
|
||||
id="test_location",
|
||||
name="Test Village",
|
||||
region="Test Region",
|
||||
description="A simple test location"
|
||||
)
|
||||
|
||||
starting_bonus = StartingBonus(
|
||||
trait="Test Trait",
|
||||
description="A test trait for testing",
|
||||
effect="+1 to all stats"
|
||||
)
|
||||
|
||||
return Origin(
|
||||
id="test_origin",
|
||||
name="Test Origin",
|
||||
description="A test origin for character testing",
|
||||
starting_location=starting_location,
|
||||
narrative_hooks=["Test hook 1", "Test hook 2"],
|
||||
starting_bonus=starting_bonus
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def basic_character(basic_player_class, basic_origin):
|
||||
"""Create a basic character for testing."""
|
||||
return Character(
|
||||
character_id="char_001",
|
||||
user_id="user_001",
|
||||
name="Test Hero",
|
||||
player_class=basic_player_class,
|
||||
origin=basic_origin,
|
||||
level=1,
|
||||
experience=0,
|
||||
base_stats=basic_player_class.base_stats.copy(),
|
||||
)
|
||||
|
||||
|
||||
def test_character_creation(basic_character):
|
||||
"""Test creating a Character instance."""
|
||||
assert basic_character.character_id == "char_001"
|
||||
assert basic_character.user_id == "user_001"
|
||||
assert basic_character.name == "Test Hero"
|
||||
assert basic_character.level == 1
|
||||
assert basic_character.experience == 0
|
||||
assert basic_character.gold == 0
|
||||
|
||||
|
||||
def test_get_effective_stats_base_only(basic_character):
|
||||
"""Test get_effective_stats() with only base stats (no modifiers)."""
|
||||
effective = basic_character.get_effective_stats()
|
||||
|
||||
# Should match base stats exactly
|
||||
assert effective.strength == 12
|
||||
assert effective.dexterity == 10
|
||||
assert effective.constitution == 14
|
||||
assert effective.intelligence == 8
|
||||
assert effective.wisdom == 10
|
||||
assert effective.charisma == 11
|
||||
|
||||
|
||||
def test_get_effective_stats_with_equipment(basic_character):
|
||||
"""Test get_effective_stats() with equipped items."""
|
||||
# Create a weapon with +5 strength
|
||||
weapon = Item(
|
||||
item_id="iron_sword",
|
||||
name="Iron Sword",
|
||||
item_type=ItemType.WEAPON,
|
||||
description="A sturdy iron sword",
|
||||
stat_bonuses={"strength": 5},
|
||||
damage=10,
|
||||
)
|
||||
|
||||
basic_character.equipped["weapon"] = weapon
|
||||
effective = basic_character.get_effective_stats()
|
||||
|
||||
# Strength should be base (12) + weapon (5) = 17
|
||||
assert effective.strength == 17
|
||||
assert effective.dexterity == 10 # Unchanged
|
||||
|
||||
|
||||
def test_get_effective_stats_with_skill_bonuses(basic_character):
|
||||
"""Test get_effective_stats() with skill tree bonuses."""
|
||||
# Unlock the "power_strike" skill which gives +5 strength
|
||||
basic_character.unlocked_skills.append("power_strike")
|
||||
|
||||
effective = basic_character.get_effective_stats()
|
||||
|
||||
# Strength should be base (12) + skill (5) = 17
|
||||
assert effective.strength == 17
|
||||
|
||||
|
||||
def test_get_effective_stats_with_all_modifiers(basic_character):
|
||||
"""Test get_effective_stats() with equipment + skills + active effects."""
|
||||
# Add equipment: +5 strength
|
||||
weapon = Item(
|
||||
item_id="iron_sword",
|
||||
name="Iron Sword",
|
||||
item_type=ItemType.WEAPON,
|
||||
description="A sturdy iron sword",
|
||||
stat_bonuses={"strength": 5},
|
||||
damage=10,
|
||||
)
|
||||
basic_character.equipped["weapon"] = weapon
|
||||
|
||||
# Unlock skill: +5 strength
|
||||
basic_character.unlocked_skills.append("power_strike")
|
||||
|
||||
# Add buff effect: +3 strength
|
||||
buff = Effect(
|
||||
effect_id="str_buff",
|
||||
name="Strength Boost",
|
||||
effect_type=EffectType.BUFF,
|
||||
duration=3,
|
||||
power=3,
|
||||
stat_affected=StatType.STRENGTH,
|
||||
)
|
||||
|
||||
effective = basic_character.get_effective_stats([buff])
|
||||
|
||||
# Total strength: 12 (base) + 5 (weapon) + 5 (skill) + 3 (buff) = 25
|
||||
assert effective.strength == 25
|
||||
|
||||
|
||||
def test_get_effective_stats_with_debuff(basic_character):
|
||||
"""Test get_effective_stats() with a debuff."""
|
||||
debuff = Effect(
|
||||
effect_id="weakened",
|
||||
name="Weakened",
|
||||
effect_type=EffectType.DEBUFF,
|
||||
duration=2,
|
||||
power=5,
|
||||
stat_affected=StatType.STRENGTH,
|
||||
)
|
||||
|
||||
effective = basic_character.get_effective_stats([debuff])
|
||||
|
||||
# Strength: 12 (base) - 5 (debuff) = 7
|
||||
assert effective.strength == 7
|
||||
|
||||
|
||||
def test_get_effective_stats_debuff_minimum(basic_character):
|
||||
"""Test that debuffs cannot reduce stats below 1."""
|
||||
# Massive debuff
|
||||
debuff = Effect(
|
||||
effect_id="weakened",
|
||||
name="Weakened",
|
||||
effect_type=EffectType.DEBUFF,
|
||||
duration=2,
|
||||
power=20, # More than base strength
|
||||
stat_affected=StatType.STRENGTH,
|
||||
)
|
||||
|
||||
effective = basic_character.get_effective_stats([debuff])
|
||||
|
||||
# Should be clamped at 1, not 0 or negative
|
||||
assert effective.strength == 1
|
||||
|
||||
|
||||
def test_inventory_management(basic_character):
|
||||
"""Test adding and removing items from inventory."""
|
||||
item = Item(
|
||||
item_id="potion",
|
||||
name="Health Potion",
|
||||
item_type=ItemType.CONSUMABLE,
|
||||
description="Restores 50 HP",
|
||||
value=25,
|
||||
)
|
||||
|
||||
# Add item
|
||||
basic_character.add_item(item)
|
||||
assert len(basic_character.inventory) == 1
|
||||
assert basic_character.inventory[0].item_id == "potion"
|
||||
|
||||
# Remove item
|
||||
removed = basic_character.remove_item("potion")
|
||||
assert removed.item_id == "potion"
|
||||
assert len(basic_character.inventory) == 0
|
||||
|
||||
|
||||
def test_equip_and_unequip(basic_character):
|
||||
"""Test equipping and unequipping items."""
|
||||
weapon = Item(
|
||||
item_id="iron_sword",
|
||||
name="Iron Sword",
|
||||
item_type=ItemType.WEAPON,
|
||||
description="A sturdy sword",
|
||||
stat_bonuses={"strength": 5},
|
||||
damage=10,
|
||||
)
|
||||
|
||||
# Add to inventory first
|
||||
basic_character.add_item(weapon)
|
||||
assert len(basic_character.inventory) == 1
|
||||
|
||||
# Equip weapon
|
||||
basic_character.equip_item(weapon, "weapon")
|
||||
assert "weapon" in basic_character.equipped
|
||||
assert basic_character.equipped["weapon"].item_id == "iron_sword"
|
||||
assert len(basic_character.inventory) == 0 # Removed from inventory
|
||||
|
||||
# Unequip weapon
|
||||
unequipped = basic_character.unequip_item("weapon")
|
||||
assert unequipped.item_id == "iron_sword"
|
||||
assert "weapon" not in basic_character.equipped
|
||||
assert len(basic_character.inventory) == 1 # Back in inventory
|
||||
|
||||
|
||||
def test_equip_replaces_existing(basic_character):
|
||||
"""Test that equipping a new item replaces the old one."""
|
||||
weapon1 = Item(
|
||||
item_id="iron_sword",
|
||||
name="Iron Sword",
|
||||
item_type=ItemType.WEAPON,
|
||||
description="A sturdy sword",
|
||||
damage=10,
|
||||
)
|
||||
|
||||
weapon2 = Item(
|
||||
item_id="steel_sword",
|
||||
name="Steel Sword",
|
||||
item_type=ItemType.WEAPON,
|
||||
description="A better sword",
|
||||
damage=15,
|
||||
)
|
||||
|
||||
# Equip first weapon
|
||||
basic_character.add_item(weapon1)
|
||||
basic_character.equip_item(weapon1, "weapon")
|
||||
|
||||
# Equip second weapon
|
||||
basic_character.add_item(weapon2)
|
||||
previous = basic_character.equip_item(weapon2, "weapon")
|
||||
|
||||
assert previous.item_id == "iron_sword" # Old weapon returned
|
||||
assert basic_character.equipped["weapon"].item_id == "steel_sword"
|
||||
assert len(basic_character.inventory) == 1 # Old weapon back in inventory
|
||||
|
||||
|
||||
def test_gold_management(basic_character):
|
||||
"""Test adding and removing gold."""
|
||||
assert basic_character.gold == 0
|
||||
|
||||
# Add gold
|
||||
basic_character.add_gold(100)
|
||||
assert basic_character.gold == 100
|
||||
|
||||
# Remove gold
|
||||
success = basic_character.remove_gold(50)
|
||||
assert success == True
|
||||
assert basic_character.gold == 50
|
||||
|
||||
# Try to remove more than available
|
||||
success = basic_character.remove_gold(100)
|
||||
assert success == False
|
||||
assert basic_character.gold == 50 # Unchanged
|
||||
|
||||
|
||||
def test_can_afford(basic_character):
|
||||
"""Test can_afford() method."""
|
||||
basic_character.gold = 100
|
||||
|
||||
assert basic_character.can_afford(50) == True
|
||||
assert basic_character.can_afford(100) == True
|
||||
assert basic_character.can_afford(101) == False
|
||||
|
||||
|
||||
def test_add_experience_no_level_up(basic_character):
|
||||
"""Test adding experience without leveling up."""
|
||||
leveled_up = basic_character.add_experience(50)
|
||||
|
||||
assert leveled_up == False
|
||||
assert basic_character.level == 1
|
||||
assert basic_character.experience == 50
|
||||
|
||||
|
||||
def test_add_experience_with_level_up(basic_character):
|
||||
"""Test adding enough experience to level up."""
|
||||
# Level 1 requires 100 XP for level 2
|
||||
leveled_up = basic_character.add_experience(100)
|
||||
|
||||
assert leveled_up == True
|
||||
assert basic_character.level == 2
|
||||
assert basic_character.experience == 0 # Reset
|
||||
|
||||
|
||||
def test_add_experience_with_overflow(basic_character):
|
||||
"""Test leveling up with overflow experience."""
|
||||
# Level 1 requires 100 XP, give 150
|
||||
leveled_up = basic_character.add_experience(150)
|
||||
|
||||
assert leveled_up == True
|
||||
assert basic_character.level == 2
|
||||
assert basic_character.experience == 50 # Overflow
|
||||
|
||||
|
||||
def test_xp_calculation(basic_origin):
|
||||
"""Test XP required for each level."""
|
||||
char = Character(
|
||||
character_id="test",
|
||||
user_id="user",
|
||||
name="Test",
|
||||
player_class=PlayerClass(
|
||||
class_id="test",
|
||||
name="Test",
|
||||
description="Test",
|
||||
base_stats=Stats(),
|
||||
),
|
||||
origin=basic_origin,
|
||||
)
|
||||
|
||||
# Formula: 100 * (level ^ 1.5)
|
||||
assert char._calculate_xp_for_next_level() == 100 # Level 1→2
|
||||
|
||||
char.level = 2
|
||||
assert char._calculate_xp_for_next_level() == 282 # Level 2→3
|
||||
|
||||
char.level = 3
|
||||
assert char._calculate_xp_for_next_level() == 519 # Level 3→4
|
||||
|
||||
|
||||
def test_get_unlocked_abilities(basic_character):
|
||||
"""Test getting abilities from class + unlocked skills."""
|
||||
# Should have starting abilities
|
||||
abilities = basic_character.get_unlocked_abilities()
|
||||
assert "basic_attack" in abilities
|
||||
|
||||
# TODO: When skills unlock abilities, test that here
|
||||
|
||||
|
||||
def test_character_serialization(basic_character):
|
||||
"""Test character to_dict() serialization."""
|
||||
# Add some data
|
||||
basic_character.gold = 500
|
||||
basic_character.level = 3
|
||||
basic_character.experience = 100
|
||||
|
||||
data = basic_character.to_dict()
|
||||
|
||||
assert data["character_id"] == "char_001"
|
||||
assert data["user_id"] == "user_001"
|
||||
assert data["name"] == "Test Hero"
|
||||
assert data["level"] == 3
|
||||
assert data["experience"] == 100
|
||||
assert data["gold"] == 500
|
||||
|
||||
|
||||
def test_character_deserialization(basic_player_class, basic_origin):
|
||||
"""Test character from_dict() deserialization."""
|
||||
data = {
|
||||
"character_id": "char_002",
|
||||
"user_id": "user_002",
|
||||
"name": "Restored Hero",
|
||||
"player_class": basic_player_class.to_dict(),
|
||||
"origin": basic_origin.to_dict(),
|
||||
"level": 5,
|
||||
"experience": 200,
|
||||
"base_stats": Stats(strength=15).to_dict(),
|
||||
"unlocked_skills": ["power_strike"],
|
||||
"inventory": [],
|
||||
"equipped": {},
|
||||
"gold": 1000,
|
||||
"active_quests": ["quest_1"],
|
||||
"discovered_locations": ["town_1"],
|
||||
}
|
||||
|
||||
char = Character.from_dict(data)
|
||||
|
||||
assert char.character_id == "char_002"
|
||||
assert char.name == "Restored Hero"
|
||||
assert char.level == 5
|
||||
assert char.gold == 1000
|
||||
assert "power_strike" in char.unlocked_skills
|
||||
|
||||
|
||||
def test_character_round_trip_serialization(basic_character):
|
||||
"""Test that serialization and deserialization preserve all data."""
|
||||
# Add complex state
|
||||
basic_character.gold = 500
|
||||
basic_character.level = 3
|
||||
basic_character.unlocked_skills = ["power_strike"]
|
||||
|
||||
weapon = Item(
|
||||
item_id="sword",
|
||||
name="Sword",
|
||||
item_type=ItemType.WEAPON,
|
||||
description="A sword",
|
||||
damage=10,
|
||||
)
|
||||
basic_character.equipped["weapon"] = weapon
|
||||
|
||||
# Serialize and deserialize
|
||||
data = basic_character.to_dict()
|
||||
restored = Character.from_dict(data)
|
||||
|
||||
assert restored.character_id == basic_character.character_id
|
||||
assert restored.name == basic_character.name
|
||||
assert restored.level == basic_character.level
|
||||
assert restored.gold == basic_character.gold
|
||||
assert restored.unlocked_skills == basic_character.unlocked_skills
|
||||
assert "weapon" in restored.equipped
|
||||
assert restored.equipped["weapon"].item_id == "sword"
|
||||
547
api/tests/test_character_service.py
Normal file
547
api/tests/test_character_service.py
Normal file
@@ -0,0 +1,547 @@
|
||||
"""
|
||||
Unit tests for CharacterService - character CRUD operations and tier limits.
|
||||
|
||||
These tests verify character creation, retrieval, deletion, skill unlock,
|
||||
and respec functionality with proper validation and error handling.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import json
|
||||
from unittest.mock import Mock, MagicMock, patch
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from app.services.character_service import (
|
||||
CharacterService,
|
||||
CharacterLimitExceeded,
|
||||
CharacterNotFound,
|
||||
SkillUnlockError,
|
||||
InsufficientGold,
|
||||
CHARACTER_LIMITS
|
||||
)
|
||||
from app.models.character import Character
|
||||
from app.models.stats import Stats
|
||||
from app.models.skills import PlayerClass, SkillTree, SkillNode
|
||||
from app.models.origins import Origin, StartingLocation, StartingBonus
|
||||
from app.services.database_service import DatabaseDocument
|
||||
|
||||
|
||||
class TestCharacterService:
|
||||
"""Test suite for CharacterService."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_db(self):
|
||||
"""Mock database service."""
|
||||
return Mock()
|
||||
|
||||
@pytest.fixture
|
||||
def mock_appwrite(self):
|
||||
"""Mock Appwrite service."""
|
||||
return Mock()
|
||||
|
||||
@pytest.fixture
|
||||
def mock_class_loader(self):
|
||||
"""Mock class loader."""
|
||||
return Mock()
|
||||
|
||||
@pytest.fixture
|
||||
def mock_origin_service(self):
|
||||
"""Mock origin service."""
|
||||
return Mock()
|
||||
|
||||
@pytest.fixture
|
||||
def character_service(self, mock_db, mock_appwrite, mock_class_loader, mock_origin_service):
|
||||
"""Create CharacterService with mocked dependencies."""
|
||||
# Patch the singleton getters before instantiation
|
||||
with patch('app.services.character_service.get_database_service', return_value=mock_db), \
|
||||
patch('app.services.character_service.AppwriteService', return_value=mock_appwrite), \
|
||||
patch('app.services.character_service.get_class_loader', return_value=mock_class_loader), \
|
||||
patch('app.services.character_service.get_origin_service', return_value=mock_origin_service):
|
||||
|
||||
service = CharacterService()
|
||||
|
||||
# Ensure mocks are still assigned
|
||||
service.db = mock_db
|
||||
service.appwrite = mock_appwrite
|
||||
service.class_loader = mock_class_loader
|
||||
service.origin_service = mock_origin_service
|
||||
return service
|
||||
|
||||
@pytest.fixture
|
||||
def sample_class(self):
|
||||
"""Create a sample player class."""
|
||||
base_stats = Stats(strength=12, dexterity=10, constitution=14)
|
||||
skill_tree = SkillTree(
|
||||
tree_id="warrior_offense",
|
||||
name="Warrior Offense",
|
||||
description="Offensive skills",
|
||||
nodes=[
|
||||
SkillNode(
|
||||
skill_id="power_strike",
|
||||
name="Power Strike",
|
||||
description="+5 Strength",
|
||||
tier=1,
|
||||
effects={"strength": 5},
|
||||
),
|
||||
SkillNode(
|
||||
skill_id="heavy_blow",
|
||||
name="Heavy Blow",
|
||||
description="+10 Strength",
|
||||
tier=2,
|
||||
prerequisites=["power_strike"],
|
||||
effects={"strength": 10},
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
return PlayerClass(
|
||||
class_id="warrior",
|
||||
name="Warrior",
|
||||
description="Strong fighter",
|
||||
base_stats=base_stats,
|
||||
skill_trees=[skill_tree],
|
||||
starting_equipment=["basic_sword"],
|
||||
starting_abilities=["basic_attack"],
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def sample_origin(self):
|
||||
"""Create a sample origin."""
|
||||
starting_location = StartingLocation(
|
||||
id="test_crypt",
|
||||
name="Test Crypt",
|
||||
region="Test Region",
|
||||
description="A test location"
|
||||
)
|
||||
|
||||
starting_bonus = StartingBonus(
|
||||
trait="Test Trait",
|
||||
description="Test bonus",
|
||||
effect="+1 to all stats"
|
||||
)
|
||||
|
||||
return Origin(
|
||||
id="test_origin",
|
||||
name="Test Origin",
|
||||
description="A test origin",
|
||||
starting_location=starting_location,
|
||||
narrative_hooks=["Test hook"],
|
||||
starting_bonus=starting_bonus
|
||||
)
|
||||
|
||||
def test_create_character_success(
|
||||
self, character_service, mock_appwrite, mock_class_loader,
|
||||
mock_origin_service, mock_db, sample_class, sample_origin
|
||||
):
|
||||
"""Test successful character creation."""
|
||||
# Setup mocks
|
||||
mock_appwrite.get_user_tier.return_value = 'free'
|
||||
mock_class_loader.load_class.return_value = sample_class
|
||||
mock_origin_service.load_origin.return_value = sample_origin
|
||||
|
||||
# Mock count_user_characters to return 0
|
||||
character_service.count_user_characters = Mock(return_value=0)
|
||||
|
||||
# Create character
|
||||
with patch('app.services.character_service.ID') as mock_id:
|
||||
mock_id.unique.return_value = 'char_123'
|
||||
|
||||
character = character_service.create_character(
|
||||
user_id='user_001',
|
||||
name='Test Hero',
|
||||
class_id='warrior',
|
||||
origin_id='test_origin'
|
||||
)
|
||||
|
||||
# Assertions
|
||||
assert character.character_id == 'char_123'
|
||||
assert character.user_id == 'user_001'
|
||||
assert character.name == 'Test Hero'
|
||||
assert character.level == 1
|
||||
assert character.experience == 0
|
||||
assert character.gold == 0
|
||||
assert character.current_location == 'test_crypt'
|
||||
|
||||
# Verify database was called
|
||||
mock_db.create_document.assert_called_once()
|
||||
|
||||
def test_create_character_exceeds_limit(
|
||||
self, character_service, mock_appwrite, mock_class_loader, mock_origin_service
|
||||
):
|
||||
"""Test character creation fails when tier limit exceeded."""
|
||||
# Setup: user on free tier (limit 1) with 1 existing character
|
||||
mock_appwrite.get_user_tier.return_value = 'free'
|
||||
character_service.count_user_characters = Mock(return_value=1)
|
||||
|
||||
# Attempt to create second character
|
||||
with pytest.raises(CharacterLimitExceeded) as exc_info:
|
||||
character_service.create_character(
|
||||
user_id='user_001',
|
||||
name='Second Hero',
|
||||
class_id='warrior',
|
||||
origin_id='test_origin'
|
||||
)
|
||||
|
||||
assert 'free tier' in str(exc_info.value)
|
||||
assert '1/1' in str(exc_info.value)
|
||||
|
||||
def test_create_character_tier_limits(self):
|
||||
"""Test that tier limits are correctly defined."""
|
||||
assert CHARACTER_LIMITS['free'] == 1
|
||||
assert CHARACTER_LIMITS['basic'] == 3
|
||||
assert CHARACTER_LIMITS['premium'] == 5
|
||||
assert CHARACTER_LIMITS['elite'] == 10
|
||||
|
||||
def test_create_character_invalid_class(
|
||||
self, character_service, mock_appwrite, mock_class_loader, mock_origin_service
|
||||
):
|
||||
"""Test character creation fails with invalid class."""
|
||||
mock_appwrite.get_user_tier.return_value = 'free'
|
||||
character_service.count_user_characters = Mock(return_value=0)
|
||||
mock_class_loader.load_class.return_value = None
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
character_service.create_character(
|
||||
user_id='user_001',
|
||||
name='Test Hero',
|
||||
class_id='invalid_class',
|
||||
origin_id='test_origin'
|
||||
)
|
||||
|
||||
assert 'Class not found' in str(exc_info.value)
|
||||
|
||||
def test_create_character_invalid_origin(
|
||||
self, character_service, mock_appwrite, mock_class_loader,
|
||||
mock_origin_service, sample_class
|
||||
):
|
||||
"""Test character creation fails with invalid origin."""
|
||||
mock_appwrite.get_user_tier.return_value = 'free'
|
||||
character_service.count_user_characters = Mock(return_value=0)
|
||||
mock_class_loader.load_class.return_value = sample_class
|
||||
mock_origin_service.load_origin.return_value = None
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
character_service.create_character(
|
||||
user_id='user_001',
|
||||
name='Test Hero',
|
||||
class_id='warrior',
|
||||
origin_id='invalid_origin'
|
||||
)
|
||||
|
||||
assert 'Origin not found' in str(exc_info.value)
|
||||
|
||||
def test_get_character_success(self, character_service, mock_db, sample_class, sample_origin):
|
||||
"""Test successfully retrieving a character."""
|
||||
# Create test character data
|
||||
character = Character(
|
||||
character_id='char_123',
|
||||
user_id='user_001',
|
||||
name='Test Hero',
|
||||
player_class=sample_class,
|
||||
origin=sample_origin,
|
||||
level=5,
|
||||
gold=500
|
||||
)
|
||||
|
||||
# Mock database response
|
||||
mock_doc = Mock(spec=DatabaseDocument)
|
||||
mock_doc.id = 'char_123'
|
||||
mock_doc.data = {
|
||||
'userId': 'user_001',
|
||||
'characterData': json.dumps(character.to_dict()),
|
||||
'is_active': True
|
||||
}
|
||||
mock_db.get_document.return_value = mock_doc
|
||||
|
||||
# Get character
|
||||
result = character_service.get_character('char_123', 'user_001')
|
||||
|
||||
# Assertions
|
||||
assert result.character_id == 'char_123'
|
||||
assert result.name == 'Test Hero'
|
||||
assert result.level == 5
|
||||
assert result.gold == 500
|
||||
|
||||
def test_get_character_not_found(self, character_service, mock_db):
|
||||
"""Test getting non-existent character raises error."""
|
||||
mock_db.get_document.return_value = None
|
||||
|
||||
with pytest.raises(CharacterNotFound):
|
||||
character_service.get_character('nonexistent', 'user_001')
|
||||
|
||||
def test_get_character_wrong_owner(self, character_service, mock_db):
|
||||
"""Test getting character owned by different user raises error."""
|
||||
mock_doc = Mock(spec=DatabaseDocument)
|
||||
mock_doc.data = {'userId': 'user_002'} # Different user
|
||||
mock_db.get_document.return_value = mock_doc
|
||||
|
||||
with pytest.raises(CharacterNotFound):
|
||||
character_service.get_character('char_123', 'user_001')
|
||||
|
||||
def test_get_user_characters(self, character_service, mock_db, sample_class, sample_origin):
|
||||
"""Test getting all characters for a user."""
|
||||
# Create test character data
|
||||
char1_data = Character(
|
||||
character_id='char_1',
|
||||
user_id='user_001',
|
||||
name='Hero 1',
|
||||
player_class=sample_class,
|
||||
origin=sample_origin
|
||||
).to_dict()
|
||||
|
||||
char2_data = Character(
|
||||
character_id='char_2',
|
||||
user_id='user_001',
|
||||
name='Hero 2',
|
||||
player_class=sample_class,
|
||||
origin=sample_origin
|
||||
).to_dict()
|
||||
|
||||
# Mock database response
|
||||
mock_doc1 = Mock(spec=DatabaseDocument)
|
||||
mock_doc1.id = 'char_1'
|
||||
mock_doc1.data = {'characterData': json.dumps(char1_data)}
|
||||
|
||||
mock_doc2 = Mock(spec=DatabaseDocument)
|
||||
mock_doc2.id = 'char_2'
|
||||
mock_doc2.data = {'characterData': json.dumps(char2_data)}
|
||||
|
||||
mock_db.list_rows.return_value = [mock_doc1, mock_doc2]
|
||||
|
||||
# Get characters
|
||||
characters = character_service.get_user_characters('user_001')
|
||||
|
||||
# Assertions
|
||||
assert len(characters) == 2
|
||||
assert characters[0].name == 'Hero 1'
|
||||
assert characters[1].name == 'Hero 2'
|
||||
|
||||
def test_count_user_characters(self, character_service, mock_db):
|
||||
"""Test counting user's characters."""
|
||||
mock_db.count_documents.return_value = 3
|
||||
|
||||
count = character_service.count_user_characters('user_001')
|
||||
|
||||
assert count == 3
|
||||
mock_db.count_documents.assert_called_once()
|
||||
|
||||
def test_delete_character_success(self, character_service, mock_db):
|
||||
"""Test successfully deleting a character."""
|
||||
# Mock get_character to return a valid character
|
||||
character_service.get_character = Mock(return_value=Mock(character_id='char_123'))
|
||||
|
||||
# Delete character
|
||||
result = character_service.delete_character('char_123', 'user_001')
|
||||
|
||||
# Assertions
|
||||
assert result is True
|
||||
mock_db.update_document.assert_called_once()
|
||||
# Verify it's a soft delete (is_active set to False)
|
||||
call_args = mock_db.update_document.call_args
|
||||
assert call_args[1]['data']['is_active'] is False
|
||||
|
||||
def test_delete_character_not_found(self, character_service):
|
||||
"""Test deleting non-existent character raises error."""
|
||||
character_service.get_character = Mock(side_effect=CharacterNotFound("Not found"))
|
||||
|
||||
with pytest.raises(CharacterNotFound):
|
||||
character_service.delete_character('nonexistent', 'user_001')
|
||||
|
||||
def test_unlock_skill_success(self, character_service, sample_class, sample_origin):
|
||||
"""Test successfully unlocking a skill."""
|
||||
# Create character with level 2 (1 skill point available)
|
||||
character = Character(
|
||||
character_id='char_123',
|
||||
user_id='user_001',
|
||||
name='Test Hero',
|
||||
player_class=sample_class,
|
||||
origin=sample_origin,
|
||||
level=2,
|
||||
unlocked_skills=[]
|
||||
)
|
||||
|
||||
# Mock get_character
|
||||
character_service.get_character = Mock(return_value=character)
|
||||
|
||||
# Mock _save_character
|
||||
character_service._save_character = Mock()
|
||||
|
||||
# Unlock skill
|
||||
result = character_service.unlock_skill('char_123', 'user_001', 'power_strike')
|
||||
|
||||
# Assertions
|
||||
assert 'power_strike' in result.unlocked_skills
|
||||
character_service._save_character.assert_called_once()
|
||||
|
||||
def test_unlock_skill_already_unlocked(self, character_service, sample_class, sample_origin):
|
||||
"""Test unlocking already unlocked skill raises error."""
|
||||
character = Character(
|
||||
character_id='char_123',
|
||||
user_id='user_001',
|
||||
name='Test Hero',
|
||||
player_class=sample_class,
|
||||
origin=sample_origin,
|
||||
level=2,
|
||||
unlocked_skills=['power_strike'] # Already unlocked
|
||||
)
|
||||
|
||||
character_service.get_character = Mock(return_value=character)
|
||||
|
||||
with pytest.raises(SkillUnlockError) as exc_info:
|
||||
character_service.unlock_skill('char_123', 'user_001', 'power_strike')
|
||||
|
||||
assert 'already unlocked' in str(exc_info.value)
|
||||
|
||||
def test_unlock_skill_not_in_class(self, character_service, sample_class, sample_origin):
|
||||
"""Test unlocking skill not in class raises error."""
|
||||
character = Character(
|
||||
character_id='char_123',
|
||||
user_id='user_001',
|
||||
name='Test Hero',
|
||||
player_class=sample_class,
|
||||
origin=sample_origin,
|
||||
level=2,
|
||||
unlocked_skills=[]
|
||||
)
|
||||
|
||||
character_service.get_character = Mock(return_value=character)
|
||||
|
||||
with pytest.raises(SkillUnlockError) as exc_info:
|
||||
character_service.unlock_skill('char_123', 'user_001', 'invalid_skill')
|
||||
|
||||
assert 'not found in class' in str(exc_info.value)
|
||||
|
||||
def test_unlock_skill_missing_prerequisite(self, character_service, sample_class, sample_origin):
|
||||
"""Test unlocking skill without prerequisite raises error."""
|
||||
character = Character(
|
||||
character_id='char_123',
|
||||
user_id='user_001',
|
||||
name='Test Hero',
|
||||
player_class=sample_class,
|
||||
origin=sample_origin,
|
||||
level=2,
|
||||
unlocked_skills=[] # Missing 'power_strike' prerequisite
|
||||
)
|
||||
|
||||
character_service.get_character = Mock(return_value=character)
|
||||
|
||||
with pytest.raises(SkillUnlockError) as exc_info:
|
||||
character_service.unlock_skill('char_123', 'user_001', 'heavy_blow')
|
||||
|
||||
assert 'Prerequisite not met' in str(exc_info.value)
|
||||
assert 'power_strike' in str(exc_info.value)
|
||||
|
||||
def test_unlock_skill_no_points_available(self, character_service, sample_class, sample_origin):
|
||||
"""Test unlocking skill without available points raises error."""
|
||||
# Add a tier 3 skill to test with
|
||||
tier3_skill = SkillNode(
|
||||
skill_id="master_strike",
|
||||
name="Master Strike",
|
||||
description="+15 Strength",
|
||||
tier=3,
|
||||
effects={"strength": 15},
|
||||
)
|
||||
sample_class.skill_trees[0].nodes.append(tier3_skill)
|
||||
|
||||
# Level 1 character with 1 skill already unlocked = 0 points remaining
|
||||
character = Character(
|
||||
character_id='char_123',
|
||||
user_id='user_001',
|
||||
name='Test Hero',
|
||||
player_class=sample_class,
|
||||
origin=sample_origin,
|
||||
level=1,
|
||||
unlocked_skills=['power_strike'] # Used the 1 point from level 1
|
||||
)
|
||||
|
||||
character_service.get_character = Mock(return_value=character)
|
||||
|
||||
# Try to unlock another skill (master_strike exists in class)
|
||||
with pytest.raises(SkillUnlockError) as exc_info:
|
||||
character_service.unlock_skill('char_123', 'user_001', 'master_strike')
|
||||
|
||||
assert 'No skill points available' in str(exc_info.value)
|
||||
|
||||
def test_respec_skills_success(self, character_service, sample_class, sample_origin):
|
||||
"""Test successfully respecing character skills."""
|
||||
# Level 5 character with 3 skills and 500 gold
|
||||
# Respec cost = 5 * 100 = 500 gold
|
||||
character = Character(
|
||||
character_id='char_123',
|
||||
user_id='user_001',
|
||||
name='Test Hero',
|
||||
player_class=sample_class,
|
||||
origin=sample_origin,
|
||||
level=5,
|
||||
gold=500,
|
||||
unlocked_skills=['skill1', 'skill2', 'skill3']
|
||||
)
|
||||
|
||||
character_service.get_character = Mock(return_value=character)
|
||||
character_service._save_character = Mock()
|
||||
|
||||
# Respec skills
|
||||
result = character_service.respec_skills('char_123', 'user_001')
|
||||
|
||||
# Assertions
|
||||
assert len(result.unlocked_skills) == 0 # Skills cleared
|
||||
assert result.gold == 0 # 500 - 500 = 0
|
||||
character_service._save_character.assert_called_once()
|
||||
|
||||
def test_respec_skills_insufficient_gold(self, character_service, sample_class, sample_origin):
|
||||
"""Test respec fails with insufficient gold."""
|
||||
# Level 5 character with only 100 gold (needs 500)
|
||||
character = Character(
|
||||
character_id='char_123',
|
||||
user_id='user_001',
|
||||
name='Test Hero',
|
||||
player_class=sample_class,
|
||||
origin=sample_origin,
|
||||
level=5,
|
||||
gold=100, # Not enough
|
||||
unlocked_skills=['skill1']
|
||||
)
|
||||
|
||||
character_service.get_character = Mock(return_value=character)
|
||||
|
||||
with pytest.raises(InsufficientGold) as exc_info:
|
||||
character_service.respec_skills('char_123', 'user_001')
|
||||
|
||||
assert '500' in str(exc_info.value) # Cost
|
||||
assert '100' in str(exc_info.value) # Available
|
||||
|
||||
def test_update_character_success(self, character_service, sample_class, sample_origin):
|
||||
"""Test successfully updating a character."""
|
||||
character = Character(
|
||||
character_id='char_123',
|
||||
user_id='user_001',
|
||||
name='Test Hero',
|
||||
player_class=sample_class,
|
||||
origin=sample_origin,
|
||||
gold=1000 # Updated gold
|
||||
)
|
||||
|
||||
# Mock get_character to verify ownership
|
||||
character_service.get_character = Mock(return_value=character)
|
||||
character_service._save_character = Mock()
|
||||
|
||||
# Update character
|
||||
result = character_service.update_character(character, 'user_001')
|
||||
|
||||
# Assertions
|
||||
assert result.gold == 1000
|
||||
character_service._save_character.assert_called_once()
|
||||
|
||||
def test_update_character_not_found(self, character_service, sample_class, sample_origin):
|
||||
"""Test updating non-existent character raises error."""
|
||||
character = Character(
|
||||
character_id='nonexistent',
|
||||
user_id='user_001',
|
||||
name='Test Hero',
|
||||
player_class=sample_class,
|
||||
origin=sample_origin
|
||||
)
|
||||
|
||||
character_service.get_character = Mock(side_effect=CharacterNotFound("Not found"))
|
||||
|
||||
with pytest.raises(CharacterNotFound):
|
||||
character_service.update_character(character, 'user_001')
|
||||
256
api/tests/test_class_loader.py
Normal file
256
api/tests/test_class_loader.py
Normal file
@@ -0,0 +1,256 @@
|
||||
"""
|
||||
Unit tests for ClassLoader service.
|
||||
|
||||
Tests loading player class definitions from YAML files.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
|
||||
from app.services.class_loader import ClassLoader
|
||||
from app.models.skills import PlayerClass, SkillTree, SkillNode
|
||||
from app.models.stats import Stats
|
||||
|
||||
|
||||
class TestClassLoader:
|
||||
"""Test ClassLoader functionality."""
|
||||
|
||||
@pytest.fixture
|
||||
def loader(self):
|
||||
"""Create a ClassLoader instance for testing."""
|
||||
return ClassLoader()
|
||||
|
||||
def test_load_vanguard_class(self, loader):
|
||||
"""Test loading the Vanguard class."""
|
||||
vanguard = loader.load_class("vanguard")
|
||||
|
||||
# Verify class loaded successfully
|
||||
assert vanguard is not None
|
||||
assert isinstance(vanguard, PlayerClass)
|
||||
|
||||
# Check basic properties
|
||||
assert vanguard.class_id == "vanguard"
|
||||
assert vanguard.name == "Vanguard"
|
||||
assert "melee combat" in vanguard.description.lower()
|
||||
|
||||
def test_vanguard_base_stats(self, loader):
|
||||
"""Test Vanguard base stats are correct."""
|
||||
vanguard = loader.load_class("vanguard")
|
||||
|
||||
assert isinstance(vanguard.base_stats, Stats)
|
||||
assert vanguard.base_stats.strength == 14
|
||||
assert vanguard.base_stats.dexterity == 10
|
||||
assert vanguard.base_stats.constitution == 14
|
||||
assert vanguard.base_stats.intelligence == 8
|
||||
assert vanguard.base_stats.wisdom == 10
|
||||
assert vanguard.base_stats.charisma == 9
|
||||
|
||||
def test_vanguard_skill_trees(self, loader):
|
||||
"""Test Vanguard has 2 skill trees."""
|
||||
vanguard = loader.load_class("vanguard")
|
||||
|
||||
# Should have exactly 2 skill trees
|
||||
assert len(vanguard.skill_trees) == 2
|
||||
|
||||
# Get trees by ID
|
||||
shield_bearer = vanguard.get_skill_tree("shield_bearer")
|
||||
weapon_master = vanguard.get_skill_tree("weapon_master")
|
||||
|
||||
assert shield_bearer is not None
|
||||
assert weapon_master is not None
|
||||
|
||||
# Check tree names
|
||||
assert shield_bearer.name == "Shield Bearer"
|
||||
assert weapon_master.name == "Weapon Master"
|
||||
|
||||
def test_shield_bearer_tree_structure(self, loader):
|
||||
"""Test Shield Bearer tree has correct structure."""
|
||||
vanguard = loader.load_class("vanguard")
|
||||
shield_bearer = vanguard.get_skill_tree("shield_bearer")
|
||||
|
||||
# Should have 10 nodes (5 tiers × 2 nodes)
|
||||
assert len(shield_bearer.nodes) == 10
|
||||
|
||||
# Check tier distribution
|
||||
tier_counts = {}
|
||||
for node in shield_bearer.nodes:
|
||||
tier_counts[node.tier] = tier_counts.get(node.tier, 0) + 1
|
||||
|
||||
# Should have 2 nodes per tier for tiers 1-5
|
||||
assert tier_counts == {1: 2, 2: 2, 3: 2, 4: 2, 5: 2}
|
||||
|
||||
def test_weapon_master_tree_structure(self, loader):
|
||||
"""Test Weapon Master tree has correct structure."""
|
||||
vanguard = loader.load_class("vanguard")
|
||||
weapon_master = vanguard.get_skill_tree("weapon_master")
|
||||
|
||||
# Should have 10 nodes (5 tiers × 2 nodes)
|
||||
assert len(weapon_master.nodes) == 10
|
||||
|
||||
# Check tier distribution
|
||||
tier_counts = {}
|
||||
for node in weapon_master.nodes:
|
||||
tier_counts[node.tier] = tier_counts.get(node.tier, 0) + 1
|
||||
|
||||
# Should have 2 nodes per tier for tiers 1-5
|
||||
assert tier_counts == {1: 2, 2: 2, 3: 2, 4: 2, 5: 2}
|
||||
|
||||
def test_skill_node_prerequisites(self, loader):
|
||||
"""Test skill nodes have correct prerequisites."""
|
||||
vanguard = loader.load_class("vanguard")
|
||||
shield_bearer = vanguard.get_skill_tree("shield_bearer")
|
||||
|
||||
# Find tier 1 and tier 2 nodes
|
||||
tier1_nodes = [n for n in shield_bearer.nodes if n.tier == 1]
|
||||
tier2_nodes = [n for n in shield_bearer.nodes if n.tier == 2]
|
||||
|
||||
# Tier 1 nodes should have no prerequisites
|
||||
for node in tier1_nodes:
|
||||
assert len(node.prerequisites) == 0
|
||||
|
||||
# Tier 2 nodes should have prerequisites
|
||||
for node in tier2_nodes:
|
||||
assert len(node.prerequisites) > 0
|
||||
# Prerequisites should reference tier 1 skills
|
||||
for prereq_id in node.prerequisites:
|
||||
prereq_found = any(n.skill_id == prereq_id for n in tier1_nodes)
|
||||
assert prereq_found, f"Prerequisite {prereq_id} not found in tier 1"
|
||||
|
||||
def test_skill_node_effects(self, loader):
|
||||
"""Test skill nodes have proper effects defined."""
|
||||
vanguard = loader.load_class("vanguard")
|
||||
shield_bearer = vanguard.get_skill_tree("shield_bearer")
|
||||
|
||||
# Find the "fortify" skill (passive defense bonus)
|
||||
fortify = next((n for n in shield_bearer.nodes if n.skill_id == "fortify"), None)
|
||||
assert fortify is not None
|
||||
|
||||
# Should have stat bonuses in effects
|
||||
assert "stat_bonuses" in fortify.effects
|
||||
assert "defense" in fortify.effects["stat_bonuses"]
|
||||
assert fortify.effects["stat_bonuses"]["defense"] == 5
|
||||
|
||||
def test_skill_node_abilities(self, loader):
|
||||
"""Test skill nodes with ability unlocks."""
|
||||
vanguard = loader.load_class("vanguard")
|
||||
shield_bearer = vanguard.get_skill_tree("shield_bearer")
|
||||
|
||||
# Find the "shield_bash" skill (active ability)
|
||||
shield_bash = next((n for n in shield_bearer.nodes if n.skill_id == "shield_bash"), None)
|
||||
assert shield_bash is not None
|
||||
|
||||
# Should have abilities in effects
|
||||
assert "abilities" in shield_bash.effects
|
||||
assert "shield_bash" in shield_bash.effects["abilities"]
|
||||
|
||||
def test_starting_equipment(self, loader):
|
||||
"""Test Vanguard starting equipment."""
|
||||
vanguard = loader.load_class("vanguard")
|
||||
|
||||
# Should have starting equipment
|
||||
assert len(vanguard.starting_equipment) > 0
|
||||
assert "rusty_sword" in vanguard.starting_equipment
|
||||
assert "cloth_armor" in vanguard.starting_equipment
|
||||
assert "rusty_knife" in vanguard.starting_equipment
|
||||
|
||||
def test_starting_abilities(self, loader):
|
||||
"""Test Vanguard starting abilities."""
|
||||
vanguard = loader.load_class("vanguard")
|
||||
|
||||
# Should have basic_attack
|
||||
assert len(vanguard.starting_abilities) > 0
|
||||
assert "basic_attack" in vanguard.starting_abilities
|
||||
|
||||
def test_cache_functionality(self, loader):
|
||||
"""Test that classes are cached after first load."""
|
||||
# Load class twice
|
||||
vanguard1 = loader.load_class("vanguard")
|
||||
vanguard2 = loader.load_class("vanguard")
|
||||
|
||||
# Should be the same object (cached)
|
||||
assert vanguard1 is vanguard2
|
||||
|
||||
def test_reload_class(self, loader):
|
||||
"""Test reload_class forces a fresh load."""
|
||||
# Load class
|
||||
vanguard1 = loader.load_class("vanguard")
|
||||
|
||||
# Reload class
|
||||
vanguard2 = loader.reload_class("vanguard")
|
||||
|
||||
# Should still be equal but different objects
|
||||
assert vanguard1.class_id == vanguard2.class_id
|
||||
# Note: May or may not be same object depending on implementation
|
||||
|
||||
def test_load_nonexistent_class(self, loader):
|
||||
"""Test loading a class that doesn't exist."""
|
||||
result = loader.load_class("nonexistent_class")
|
||||
assert result is None
|
||||
|
||||
def test_get_all_class_ids(self, loader):
|
||||
"""Test getting list of all class IDs."""
|
||||
class_ids = loader.get_all_class_ids()
|
||||
|
||||
# Should include vanguard
|
||||
assert "vanguard" in class_ids
|
||||
|
||||
# Should be a list of strings
|
||||
assert all(isinstance(cid, str) for cid in class_ids)
|
||||
|
||||
def test_load_all_classes(self, loader):
|
||||
"""Test loading all classes at once."""
|
||||
classes = loader.load_all_classes()
|
||||
|
||||
# Should have at least 1 class (vanguard)
|
||||
assert len(classes) >= 1
|
||||
|
||||
# All should be PlayerClass instances
|
||||
assert all(isinstance(c, PlayerClass) for c in classes)
|
||||
|
||||
# Vanguard should be in the list
|
||||
vanguard_found = any(c.class_id == "vanguard" for c in classes)
|
||||
assert vanguard_found
|
||||
|
||||
def test_get_all_skills_method(self, loader):
|
||||
"""Test PlayerClass.get_all_skills() method."""
|
||||
vanguard = loader.load_class("vanguard")
|
||||
|
||||
all_skills = vanguard.get_all_skills()
|
||||
|
||||
# Should have 20 skills (2 trees × 10 nodes)
|
||||
assert len(all_skills) == 20
|
||||
|
||||
# All should be SkillNode instances
|
||||
assert all(isinstance(skill, SkillNode) for skill in all_skills)
|
||||
|
||||
# Should include skills from both trees
|
||||
skill_ids = [s.skill_id for s in all_skills]
|
||||
assert "shield_bash" in skill_ids # Shield Bearer tree
|
||||
assert "power_strike" in skill_ids # Weapon Master tree
|
||||
|
||||
def test_tier_5_ultimate_skills(self, loader):
|
||||
"""Test that tier 5 skills exist and have powerful effects."""
|
||||
vanguard = loader.load_class("vanguard")
|
||||
all_skills = vanguard.get_all_skills()
|
||||
|
||||
# Get tier 5 skills
|
||||
tier5_skills = [s for s in all_skills if s.tier == 5]
|
||||
|
||||
# Should have 4 tier 5 skills (2 per tree)
|
||||
assert len(tier5_skills) == 4
|
||||
|
||||
# Each tier 5 skill should have prerequisites
|
||||
for skill in tier5_skills:
|
||||
assert len(skill.prerequisites) > 0
|
||||
|
||||
# At least one tier 5 skill should have significant stat bonuses
|
||||
has_major_bonuses = any(
|
||||
"stat_bonuses" in s.effects and
|
||||
any(v >= 10 for v in s.effects.get("stat_bonuses", {}).values())
|
||||
for s in tier5_skills
|
||||
)
|
||||
assert has_major_bonuses
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
509
api/tests/test_combat_simulation.py
Normal file
509
api/tests/test_combat_simulation.py
Normal file
@@ -0,0 +1,509 @@
|
||||
"""
|
||||
Integration tests for combat system.
|
||||
|
||||
Tests the complete combat flow including damage calculation, effects, and turn order.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from app.models.stats import Stats
|
||||
from app.models.items import Item
|
||||
from app.models.effects import Effect
|
||||
from app.models.abilities import Ability
|
||||
from app.models.combat import Combatant, CombatEncounter
|
||||
from app.models.character import Character
|
||||
from app.models.skills import PlayerClass
|
||||
from app.models.enums import (
|
||||
ItemType,
|
||||
DamageType,
|
||||
EffectType,
|
||||
StatType,
|
||||
AbilityType,
|
||||
CombatStatus,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def warrior_combatant():
|
||||
"""Create a warrior combatant for testing."""
|
||||
stats = Stats(strength=15, dexterity=10, constitution=14, intelligence=8, wisdom=10, charisma=11)
|
||||
|
||||
return Combatant(
|
||||
combatant_id="warrior_1",
|
||||
name="Test Warrior",
|
||||
is_player=True,
|
||||
current_hp=stats.hit_points,
|
||||
max_hp=stats.hit_points,
|
||||
current_mp=stats.mana_points,
|
||||
max_mp=stats.mana_points,
|
||||
stats=stats,
|
||||
abilities=["basic_attack", "power_strike"],
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def goblin_combatant():
|
||||
"""Create a goblin enemy for testing."""
|
||||
stats = Stats(strength=8, dexterity=12, constitution=10, intelligence=6, wisdom=8, charisma=6)
|
||||
|
||||
return Combatant(
|
||||
combatant_id="goblin_1",
|
||||
name="Goblin",
|
||||
is_player=False,
|
||||
current_hp=stats.hit_points,
|
||||
max_hp=stats.hit_points,
|
||||
current_mp=stats.mana_points,
|
||||
max_mp=stats.mana_points,
|
||||
stats=stats,
|
||||
abilities=["basic_attack"],
|
||||
)
|
||||
|
||||
|
||||
def test_combatant_creation(warrior_combatant):
|
||||
"""Test creating a Combatant."""
|
||||
assert warrior_combatant.combatant_id == "warrior_1"
|
||||
assert warrior_combatant.name == "Test Warrior"
|
||||
assert warrior_combatant.is_player == True
|
||||
assert warrior_combatant.is_alive() == True
|
||||
assert warrior_combatant.is_stunned() == False
|
||||
|
||||
|
||||
def test_combatant_take_damage(warrior_combatant):
|
||||
"""Test taking damage."""
|
||||
initial_hp = warrior_combatant.current_hp
|
||||
|
||||
damage_dealt = warrior_combatant.take_damage(10)
|
||||
|
||||
assert damage_dealt == 10
|
||||
assert warrior_combatant.current_hp == initial_hp - 10
|
||||
|
||||
|
||||
def test_combatant_take_damage_with_shield(warrior_combatant):
|
||||
"""Test taking damage with shield absorption."""
|
||||
# Add a shield effect
|
||||
shield = Effect(
|
||||
effect_id="shield_1",
|
||||
name="Shield",
|
||||
effect_type=EffectType.SHIELD,
|
||||
duration=3,
|
||||
power=15,
|
||||
)
|
||||
warrior_combatant.add_effect(shield)
|
||||
|
||||
initial_hp = warrior_combatant.current_hp
|
||||
|
||||
# Deal 10 damage - should be fully absorbed by shield
|
||||
damage_dealt = warrior_combatant.take_damage(10)
|
||||
|
||||
assert damage_dealt == 0 # No HP damage
|
||||
assert warrior_combatant.current_hp == initial_hp
|
||||
|
||||
|
||||
def test_combatant_death(warrior_combatant):
|
||||
"""Test combatant death."""
|
||||
assert warrior_combatant.is_alive() == True
|
||||
|
||||
# Deal massive damage
|
||||
warrior_combatant.take_damage(1000)
|
||||
|
||||
assert warrior_combatant.is_alive() == False
|
||||
assert warrior_combatant.is_dead() == True
|
||||
|
||||
|
||||
def test_combatant_healing(warrior_combatant):
|
||||
"""Test healing."""
|
||||
# Take some damage first
|
||||
warrior_combatant.take_damage(20)
|
||||
damaged_hp = warrior_combatant.current_hp
|
||||
|
||||
# Heal
|
||||
healed = warrior_combatant.heal(10)
|
||||
|
||||
assert healed == 10
|
||||
assert warrior_combatant.current_hp == damaged_hp + 10
|
||||
|
||||
|
||||
def test_combatant_healing_capped_at_max(warrior_combatant):
|
||||
"""Test that healing cannot exceed max HP."""
|
||||
max_hp = warrior_combatant.max_hp
|
||||
|
||||
# Try to heal beyond max
|
||||
healed = warrior_combatant.heal(1000)
|
||||
|
||||
assert warrior_combatant.current_hp == max_hp
|
||||
|
||||
|
||||
def test_combatant_stun_effect(warrior_combatant):
|
||||
"""Test stun effect prevents actions."""
|
||||
assert warrior_combatant.is_stunned() == False
|
||||
|
||||
# Add stun effect
|
||||
stun = Effect(
|
||||
effect_id="stun_1",
|
||||
name="Stunned",
|
||||
effect_type=EffectType.STUN,
|
||||
duration=1,
|
||||
power=0,
|
||||
)
|
||||
warrior_combatant.add_effect(stun)
|
||||
|
||||
assert warrior_combatant.is_stunned() == True
|
||||
|
||||
|
||||
def test_combatant_tick_effects(warrior_combatant):
|
||||
"""Test that ticking effects deals damage/healing."""
|
||||
# Add a DOT effect
|
||||
poison = Effect(
|
||||
effect_id="poison_1",
|
||||
name="Poison",
|
||||
effect_type=EffectType.DOT,
|
||||
duration=3,
|
||||
power=5,
|
||||
)
|
||||
warrior_combatant.add_effect(poison)
|
||||
|
||||
initial_hp = warrior_combatant.current_hp
|
||||
|
||||
# Tick effects
|
||||
results = warrior_combatant.tick_effects()
|
||||
|
||||
# Should have taken 5 poison damage
|
||||
assert len(results) == 1
|
||||
assert results[0]["effect_type"] == "dot"
|
||||
assert results[0]["value"] == 5
|
||||
assert warrior_combatant.current_hp == initial_hp - 5
|
||||
|
||||
|
||||
def test_combatant_effect_expiration(warrior_combatant):
|
||||
"""Test that expired effects are removed."""
|
||||
# Add effect with 1 turn duration
|
||||
dot = Effect(
|
||||
effect_id="burn_1",
|
||||
name="Burning",
|
||||
effect_type=EffectType.DOT,
|
||||
duration=1,
|
||||
power=5,
|
||||
)
|
||||
warrior_combatant.add_effect(dot)
|
||||
|
||||
assert len(warrior_combatant.active_effects) == 1
|
||||
|
||||
# Tick - effect should expire
|
||||
results = warrior_combatant.tick_effects()
|
||||
|
||||
assert results[0]["expired"] == True
|
||||
assert len(warrior_combatant.active_effects) == 0 # Removed
|
||||
|
||||
|
||||
def test_ability_mana_cost(warrior_combatant):
|
||||
"""Test ability mana cost and usage."""
|
||||
ability = Ability(
|
||||
ability_id="fireball",
|
||||
name="Fireball",
|
||||
description="Fiery explosion",
|
||||
ability_type=AbilityType.SPELL,
|
||||
base_power=30,
|
||||
damage_type=DamageType.FIRE,
|
||||
mana_cost=15,
|
||||
)
|
||||
|
||||
initial_mp = warrior_combatant.current_mp
|
||||
|
||||
# Check if can use
|
||||
assert warrior_combatant.can_use_ability("fireball", ability) == False # Not in ability list
|
||||
warrior_combatant.abilities.append("fireball")
|
||||
assert warrior_combatant.can_use_ability("fireball", ability) == True
|
||||
|
||||
# Use ability
|
||||
warrior_combatant.use_ability_cost(ability, "fireball")
|
||||
|
||||
assert warrior_combatant.current_mp == initial_mp - 15
|
||||
|
||||
|
||||
def test_ability_cooldown(warrior_combatant):
|
||||
"""Test ability cooldowns."""
|
||||
ability = Ability(
|
||||
ability_id="power_strike",
|
||||
name="Power Strike",
|
||||
description="Powerful attack",
|
||||
ability_type=AbilityType.SKILL,
|
||||
base_power=20,
|
||||
cooldown=3,
|
||||
)
|
||||
|
||||
warrior_combatant.abilities.append("power_strike")
|
||||
|
||||
# Can use initially
|
||||
assert warrior_combatant.can_use_ability("power_strike", ability) == True
|
||||
|
||||
# Use ability
|
||||
warrior_combatant.use_ability_cost(ability, "power_strike")
|
||||
|
||||
# Now on cooldown
|
||||
assert "power_strike" in warrior_combatant.cooldowns
|
||||
assert warrior_combatant.cooldowns["power_strike"] == 3
|
||||
assert warrior_combatant.can_use_ability("power_strike", ability) == False
|
||||
|
||||
# Tick cooldown
|
||||
warrior_combatant.tick_cooldowns()
|
||||
assert warrior_combatant.cooldowns["power_strike"] == 2
|
||||
|
||||
# Tick more
|
||||
warrior_combatant.tick_cooldowns()
|
||||
warrior_combatant.tick_cooldowns()
|
||||
|
||||
# Should be available again
|
||||
assert "power_strike" not in warrior_combatant.cooldowns
|
||||
assert warrior_combatant.can_use_ability("power_strike", ability) == True
|
||||
|
||||
|
||||
def test_combat_encounter_initialization(warrior_combatant, goblin_combatant):
|
||||
"""Test initializing a combat encounter."""
|
||||
encounter = CombatEncounter(
|
||||
encounter_id="combat_001",
|
||||
combatants=[warrior_combatant, goblin_combatant],
|
||||
)
|
||||
|
||||
encounter.initialize_combat()
|
||||
|
||||
# Should have turn order
|
||||
assert len(encounter.turn_order) == 2
|
||||
assert encounter.round_number == 1
|
||||
assert encounter.status == CombatStatus.ACTIVE
|
||||
|
||||
# Both combatants should have initiative
|
||||
assert warrior_combatant.initiative > 0
|
||||
assert goblin_combatant.initiative > 0
|
||||
|
||||
|
||||
def test_combat_turn_advancement(warrior_combatant, goblin_combatant):
|
||||
"""Test advancing turns in combat."""
|
||||
encounter = CombatEncounter(
|
||||
encounter_id="combat_001",
|
||||
combatants=[warrior_combatant, goblin_combatant],
|
||||
)
|
||||
|
||||
encounter.initialize_combat()
|
||||
|
||||
# Get first combatant
|
||||
first = encounter.get_current_combatant()
|
||||
assert first is not None
|
||||
|
||||
# Advance turn
|
||||
encounter.advance_turn()
|
||||
|
||||
# Should be second combatant now
|
||||
second = encounter.get_current_combatant()
|
||||
assert second is not None
|
||||
assert second.combatant_id != first.combatant_id
|
||||
|
||||
# Advance again - should cycle back to first and increment round
|
||||
encounter.advance_turn()
|
||||
|
||||
assert encounter.round_number == 2
|
||||
third = encounter.get_current_combatant()
|
||||
assert third.combatant_id == first.combatant_id
|
||||
|
||||
|
||||
def test_combat_victory_condition(warrior_combatant, goblin_combatant):
|
||||
"""Test victory condition detection."""
|
||||
encounter = CombatEncounter(
|
||||
encounter_id="combat_001",
|
||||
combatants=[warrior_combatant, goblin_combatant],
|
||||
)
|
||||
|
||||
encounter.initialize_combat()
|
||||
|
||||
# Kill the goblin
|
||||
goblin_combatant.current_hp = 0
|
||||
|
||||
# Check end condition
|
||||
status = encounter.check_end_condition()
|
||||
|
||||
assert status == CombatStatus.VICTORY
|
||||
assert encounter.status == CombatStatus.VICTORY
|
||||
|
||||
|
||||
def test_combat_defeat_condition(warrior_combatant, goblin_combatant):
|
||||
"""Test defeat condition detection."""
|
||||
encounter = CombatEncounter(
|
||||
encounter_id="combat_001",
|
||||
combatants=[warrior_combatant, goblin_combatant],
|
||||
)
|
||||
|
||||
encounter.initialize_combat()
|
||||
|
||||
# Kill the warrior
|
||||
warrior_combatant.current_hp = 0
|
||||
|
||||
# Check end condition
|
||||
status = encounter.check_end_condition()
|
||||
|
||||
assert status == CombatStatus.DEFEAT
|
||||
assert encounter.status == CombatStatus.DEFEAT
|
||||
|
||||
|
||||
def test_combat_start_turn_processing(warrior_combatant):
|
||||
"""Test start_turn() processes effects and cooldowns."""
|
||||
encounter = CombatEncounter(
|
||||
encounter_id="combat_001",
|
||||
combatants=[warrior_combatant],
|
||||
)
|
||||
|
||||
# Initialize combat to set turn order
|
||||
encounter.initialize_combat()
|
||||
|
||||
# Add a DOT effect
|
||||
poison = Effect(
|
||||
effect_id="poison_1",
|
||||
name="Poison",
|
||||
effect_type=EffectType.DOT,
|
||||
duration=3,
|
||||
power=5,
|
||||
)
|
||||
warrior_combatant.add_effect(poison)
|
||||
|
||||
# Add a cooldown
|
||||
warrior_combatant.cooldowns["power_strike"] = 2
|
||||
|
||||
initial_hp = warrior_combatant.current_hp
|
||||
|
||||
# Start turn
|
||||
results = encounter.start_turn()
|
||||
|
||||
# Effects should have ticked
|
||||
assert len(results) == 1
|
||||
assert warrior_combatant.current_hp == initial_hp - 5
|
||||
|
||||
# Cooldown should have decreased
|
||||
assert warrior_combatant.cooldowns["power_strike"] == 1
|
||||
|
||||
|
||||
def test_combat_logging(warrior_combatant, goblin_combatant):
|
||||
"""Test combat log entries."""
|
||||
encounter = CombatEncounter(
|
||||
encounter_id="combat_001",
|
||||
combatants=[warrior_combatant, goblin_combatant],
|
||||
)
|
||||
|
||||
encounter.log_action("attack", "warrior_1", "Warrior attacks Goblin for 10 damage")
|
||||
|
||||
assert len(encounter.combat_log) == 1
|
||||
assert encounter.combat_log[0]["action_type"] == "attack"
|
||||
assert encounter.combat_log[0]["combatant_id"] == "warrior_1"
|
||||
assert "Warrior attacks Goblin" in encounter.combat_log[0]["message"]
|
||||
|
||||
|
||||
def test_ability_damage_calculation():
|
||||
"""Test ability power calculation with stat scaling."""
|
||||
stats = Stats(strength=20, intelligence=16)
|
||||
|
||||
# Physical ability scaling with strength
|
||||
physical = Ability(
|
||||
ability_id="cleave",
|
||||
name="Cleave",
|
||||
description="Powerful strike",
|
||||
ability_type=AbilityType.SKILL,
|
||||
base_power=15,
|
||||
scaling_stat=StatType.STRENGTH,
|
||||
scaling_factor=0.5,
|
||||
)
|
||||
|
||||
power = physical.calculate_power(stats)
|
||||
# 15 (base) + (20 strength × 0.5) = 15 + 10 = 25
|
||||
assert power == 25
|
||||
|
||||
# Magical ability scaling with intelligence
|
||||
magical = Ability(
|
||||
ability_id="fireball",
|
||||
name="Fireball",
|
||||
description="Fire spell",
|
||||
ability_type=AbilityType.SPELL,
|
||||
base_power=20,
|
||||
scaling_stat=StatType.INTELLIGENCE,
|
||||
scaling_factor=0.5,
|
||||
)
|
||||
|
||||
power = magical.calculate_power(stats)
|
||||
# 20 (base) + (16 intelligence × 0.5) = 20 + 8 = 28
|
||||
assert power == 28
|
||||
|
||||
|
||||
def test_full_combat_simulation():
|
||||
"""Integration test: Full combat simulation with all systems."""
|
||||
# Create warrior
|
||||
warrior_stats = Stats(strength=15, constitution=14)
|
||||
warrior = Combatant(
|
||||
combatant_id="hero",
|
||||
name="Hero",
|
||||
is_player=True,
|
||||
current_hp=warrior_stats.hit_points,
|
||||
max_hp=warrior_stats.hit_points,
|
||||
current_mp=warrior_stats.mana_points,
|
||||
max_mp=warrior_stats.mana_points,
|
||||
stats=warrior_stats,
|
||||
)
|
||||
|
||||
# Create goblin
|
||||
goblin_stats = Stats(strength=8, constitution=10)
|
||||
goblin = Combatant(
|
||||
combatant_id="goblin",
|
||||
name="Goblin",
|
||||
is_player=False,
|
||||
current_hp=goblin_stats.hit_points,
|
||||
max_hp=goblin_stats.hit_points,
|
||||
current_mp=goblin_stats.mana_points,
|
||||
max_mp=goblin_stats.mana_points,
|
||||
stats=goblin_stats,
|
||||
)
|
||||
|
||||
# Create encounter
|
||||
encounter = CombatEncounter(
|
||||
encounter_id="test_combat",
|
||||
combatants=[warrior, goblin],
|
||||
)
|
||||
|
||||
encounter.initialize_combat()
|
||||
|
||||
# Verify setup
|
||||
assert encounter.status == CombatStatus.ACTIVE
|
||||
assert len(encounter.turn_order) == 2
|
||||
assert warrior.is_alive() and goblin.is_alive()
|
||||
|
||||
# Simulate turns until combat ends
|
||||
max_turns = 50 # Increased to ensure combat completes
|
||||
turn_count = 0
|
||||
|
||||
while encounter.status == CombatStatus.ACTIVE and turn_count < max_turns:
|
||||
# Get current combatant
|
||||
current = encounter.get_current_combatant()
|
||||
|
||||
# Start turn (tick effects)
|
||||
encounter.start_turn()
|
||||
|
||||
if current and current.is_alive() and not current.is_stunned():
|
||||
# Simple AI: deal damage to opponent
|
||||
if current.combatant_id == "hero":
|
||||
target = goblin
|
||||
else:
|
||||
target = warrior
|
||||
|
||||
# Calculate simple attack damage: strength / 2 - target defense
|
||||
damage = max(1, (current.stats.strength // 2) - target.stats.defense)
|
||||
target.take_damage(damage)
|
||||
|
||||
encounter.log_action(
|
||||
"attack",
|
||||
current.combatant_id,
|
||||
f"{current.name} attacks {target.name} for {damage} damage",
|
||||
)
|
||||
|
||||
# Check for combat end
|
||||
encounter.check_end_condition()
|
||||
|
||||
# Advance turn
|
||||
encounter.advance_turn()
|
||||
turn_count += 1
|
||||
|
||||
# Combat should have ended
|
||||
assert encounter.status in [CombatStatus.VICTORY, CombatStatus.DEFEAT]
|
||||
assert len(encounter.combat_log) > 0
|
||||
361
api/tests/test_effects.py
Normal file
361
api/tests/test_effects.py
Normal file
@@ -0,0 +1,361 @@
|
||||
"""
|
||||
Unit tests for Effect dataclass.
|
||||
|
||||
Tests all effect types, tick() method, stacking, and serialization.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from app.models.effects import Effect
|
||||
from app.models.enums import EffectType, StatType
|
||||
|
||||
|
||||
def test_effect_creation():
|
||||
"""Test creating an Effect instance."""
|
||||
effect = Effect(
|
||||
effect_id="burn_1",
|
||||
name="Burning",
|
||||
effect_type=EffectType.DOT,
|
||||
duration=3,
|
||||
power=5,
|
||||
)
|
||||
|
||||
assert effect.effect_id == "burn_1"
|
||||
assert effect.name == "Burning"
|
||||
assert effect.effect_type == EffectType.DOT
|
||||
assert effect.duration == 3
|
||||
assert effect.power == 5
|
||||
assert effect.stacks == 1
|
||||
assert effect.max_stacks == 5
|
||||
|
||||
|
||||
def test_dot_effect_tick():
|
||||
"""Test DOT (damage over time) effect ticking."""
|
||||
effect = Effect(
|
||||
effect_id="poison_1",
|
||||
name="Poisoned",
|
||||
effect_type=EffectType.DOT,
|
||||
duration=3,
|
||||
power=10,
|
||||
stacks=2,
|
||||
)
|
||||
|
||||
result = effect.tick()
|
||||
|
||||
assert result["effect_type"] == "dot"
|
||||
assert result["value"] == 20 # 10 power × 2 stacks
|
||||
assert result["expired"] == False
|
||||
assert effect.duration == 2 # Reduced by 1
|
||||
|
||||
|
||||
def test_hot_effect_tick():
|
||||
"""Test HOT (heal over time) effect ticking."""
|
||||
effect = Effect(
|
||||
effect_id="regen_1",
|
||||
name="Regeneration",
|
||||
effect_type=EffectType.HOT,
|
||||
duration=5,
|
||||
power=8,
|
||||
stacks=1,
|
||||
)
|
||||
|
||||
result = effect.tick()
|
||||
|
||||
assert result["effect_type"] == "hot"
|
||||
assert result["value"] == 8 # 8 power × 1 stack
|
||||
assert result["expired"] == False
|
||||
assert effect.duration == 4
|
||||
|
||||
|
||||
def test_stun_effect_tick():
|
||||
"""Test STUN effect ticking."""
|
||||
effect = Effect(
|
||||
effect_id="stun_1",
|
||||
name="Stunned",
|
||||
effect_type=EffectType.STUN,
|
||||
duration=1,
|
||||
power=0,
|
||||
)
|
||||
|
||||
result = effect.tick()
|
||||
|
||||
assert result["effect_type"] == "stun"
|
||||
assert result.get("stunned") == True
|
||||
assert effect.duration == 0
|
||||
|
||||
|
||||
def test_shield_effect_tick():
|
||||
"""Test SHIELD effect ticking."""
|
||||
effect = Effect(
|
||||
effect_id="shield_1",
|
||||
name="Shield",
|
||||
effect_type=EffectType.SHIELD,
|
||||
duration=3,
|
||||
power=50,
|
||||
)
|
||||
|
||||
result = effect.tick()
|
||||
|
||||
assert result["effect_type"] == "shield"
|
||||
assert result["shield_remaining"] == 50
|
||||
assert effect.duration == 2
|
||||
|
||||
|
||||
def test_buff_effect_tick():
|
||||
"""Test BUFF effect ticking."""
|
||||
effect = Effect(
|
||||
effect_id="str_buff_1",
|
||||
name="Strength Boost",
|
||||
effect_type=EffectType.BUFF,
|
||||
duration=4,
|
||||
power=5,
|
||||
stat_affected=StatType.STRENGTH,
|
||||
stacks=2,
|
||||
)
|
||||
|
||||
result = effect.tick()
|
||||
|
||||
assert result["effect_type"] == "buff"
|
||||
assert result["stat_affected"] == "strength"
|
||||
assert result["stat_modifier"] == 10 # 5 power × 2 stacks
|
||||
assert effect.duration == 3
|
||||
|
||||
|
||||
def test_debuff_effect_tick():
|
||||
"""Test DEBUFF effect ticking."""
|
||||
effect = Effect(
|
||||
effect_id="weak_1",
|
||||
name="Weakened",
|
||||
effect_type=EffectType.DEBUFF,
|
||||
duration=2,
|
||||
power=3,
|
||||
stat_affected=StatType.STRENGTH,
|
||||
)
|
||||
|
||||
result = effect.tick()
|
||||
|
||||
assert result["effect_type"] == "debuff"
|
||||
assert result["stat_affected"] == "strength"
|
||||
assert result["stat_modifier"] == 3
|
||||
assert effect.duration == 1
|
||||
|
||||
|
||||
def test_effect_expiration():
|
||||
"""Test that effect expires when duration reaches 0."""
|
||||
effect = Effect(
|
||||
effect_id="burn_1",
|
||||
name="Burning",
|
||||
effect_type=EffectType.DOT,
|
||||
duration=1,
|
||||
power=5,
|
||||
)
|
||||
|
||||
result = effect.tick()
|
||||
|
||||
assert result["expired"] == True
|
||||
assert effect.duration == 0
|
||||
|
||||
|
||||
def test_effect_stacking():
|
||||
"""Test apply_stack() increases stacks up to max."""
|
||||
effect = Effect(
|
||||
effect_id="poison_1",
|
||||
name="Poison",
|
||||
effect_type=EffectType.DOT,
|
||||
duration=3,
|
||||
power=5,
|
||||
max_stacks=5,
|
||||
)
|
||||
|
||||
assert effect.stacks == 1
|
||||
|
||||
effect.apply_stack()
|
||||
assert effect.stacks == 2
|
||||
|
||||
effect.apply_stack()
|
||||
assert effect.stacks == 3
|
||||
|
||||
# Apply 3 more to reach max
|
||||
effect.apply_stack()
|
||||
effect.apply_stack()
|
||||
effect.apply_stack()
|
||||
assert effect.stacks == 5 # Capped at max_stacks
|
||||
|
||||
# Try to apply one more - should still be 5
|
||||
effect.apply_stack()
|
||||
assert effect.stacks == 5
|
||||
|
||||
|
||||
def test_shield_damage_absorption():
|
||||
"""Test reduce_shield() method."""
|
||||
effect = Effect(
|
||||
effect_id="shield_1",
|
||||
name="Shield",
|
||||
effect_type=EffectType.SHIELD,
|
||||
duration=3,
|
||||
power=50,
|
||||
stacks=1,
|
||||
)
|
||||
|
||||
# Shield absorbs 20 damage
|
||||
remaining = effect.reduce_shield(20)
|
||||
assert remaining == 0 # All damage absorbed
|
||||
assert effect.power == 30 # Shield reduced by 20
|
||||
|
||||
# Shield absorbs 20 more
|
||||
remaining = effect.reduce_shield(20)
|
||||
assert remaining == 0
|
||||
assert effect.power == 10
|
||||
|
||||
# Shield takes 15 damage, breaks completely
|
||||
remaining = effect.reduce_shield(15)
|
||||
assert remaining == 5 # 5 damage passes through
|
||||
assert effect.power == 0
|
||||
assert effect.duration == 0 # Effect expires
|
||||
|
||||
|
||||
def test_shield_with_stacks():
|
||||
"""Test shield absorption with multiple stacks."""
|
||||
effect = Effect(
|
||||
effect_id="shield_1",
|
||||
name="Shield",
|
||||
effect_type=EffectType.SHIELD,
|
||||
duration=3,
|
||||
power=20,
|
||||
stacks=3,
|
||||
)
|
||||
|
||||
# Total shield = 20 × 3 = 60
|
||||
# Apply 50 damage
|
||||
remaining = effect.reduce_shield(50)
|
||||
assert remaining == 0 # All absorbed
|
||||
# Shield reduced: 50 / 3 stacks = 16.67 per stack
|
||||
assert effect.power < 20 # Power reduced
|
||||
|
||||
|
||||
def test_effect_serialization():
|
||||
"""Test to_dict() serialization."""
|
||||
effect = Effect(
|
||||
effect_id="burn_1",
|
||||
name="Burning",
|
||||
effect_type=EffectType.DOT,
|
||||
duration=3,
|
||||
power=5,
|
||||
stacks=2,
|
||||
max_stacks=5,
|
||||
source="fireball",
|
||||
)
|
||||
|
||||
data = effect.to_dict()
|
||||
|
||||
assert data["effect_id"] == "burn_1"
|
||||
assert data["name"] == "Burning"
|
||||
assert data["effect_type"] == "dot"
|
||||
assert data["duration"] == 3
|
||||
assert data["power"] == 5
|
||||
assert data["stacks"] == 2
|
||||
assert data["max_stacks"] == 5
|
||||
assert data["source"] == "fireball"
|
||||
|
||||
|
||||
def test_effect_deserialization():
|
||||
"""Test from_dict() deserialization."""
|
||||
data = {
|
||||
"effect_id": "regen_1",
|
||||
"name": "Regeneration",
|
||||
"effect_type": "hot",
|
||||
"duration": 5,
|
||||
"power": 10,
|
||||
"stat_affected": None,
|
||||
"stacks": 1,
|
||||
"max_stacks": 5,
|
||||
"source": "potion",
|
||||
}
|
||||
|
||||
effect = Effect.from_dict(data)
|
||||
|
||||
assert effect.effect_id == "regen_1"
|
||||
assert effect.name == "Regeneration"
|
||||
assert effect.effect_type == EffectType.HOT
|
||||
assert effect.duration == 5
|
||||
assert effect.power == 10
|
||||
assert effect.stacks == 1
|
||||
|
||||
|
||||
def test_effect_with_stat_deserialization():
|
||||
"""Test deserializing effect with stat_affected."""
|
||||
data = {
|
||||
"effect_id": "buff_1",
|
||||
"name": "Strength Boost",
|
||||
"effect_type": "buff",
|
||||
"duration": 3,
|
||||
"power": 5,
|
||||
"stat_affected": "strength",
|
||||
"stacks": 1,
|
||||
"max_stacks": 5,
|
||||
"source": "spell",
|
||||
}
|
||||
|
||||
effect = Effect.from_dict(data)
|
||||
|
||||
assert effect.stat_affected == StatType.STRENGTH
|
||||
assert effect.effect_type == EffectType.BUFF
|
||||
|
||||
|
||||
def test_effect_round_trip_serialization():
|
||||
"""Test that serialization and deserialization preserve data."""
|
||||
original = Effect(
|
||||
effect_id="test_effect",
|
||||
name="Test Effect",
|
||||
effect_type=EffectType.DEBUFF,
|
||||
duration=10,
|
||||
power=15,
|
||||
stat_affected=StatType.CONSTITUTION,
|
||||
stacks=3,
|
||||
max_stacks=5,
|
||||
source="enemy_spell",
|
||||
)
|
||||
|
||||
# Serialize then deserialize
|
||||
data = original.to_dict()
|
||||
restored = Effect.from_dict(data)
|
||||
|
||||
assert restored.effect_id == original.effect_id
|
||||
assert restored.name == original.name
|
||||
assert restored.effect_type == original.effect_type
|
||||
assert restored.duration == original.duration
|
||||
assert restored.power == original.power
|
||||
assert restored.stat_affected == original.stat_affected
|
||||
assert restored.stacks == original.stacks
|
||||
assert restored.max_stacks == original.max_stacks
|
||||
assert restored.source == original.source
|
||||
|
||||
|
||||
def test_effect_repr():
|
||||
"""Test string representation."""
|
||||
effect = Effect(
|
||||
effect_id="poison_1",
|
||||
name="Poison",
|
||||
effect_type=EffectType.DOT,
|
||||
duration=3,
|
||||
power=5,
|
||||
stacks=2,
|
||||
)
|
||||
|
||||
repr_str = repr(effect)
|
||||
|
||||
assert "Poison" in repr_str
|
||||
assert "dot" in repr_str
|
||||
|
||||
|
||||
def test_non_shield_reduce_shield():
|
||||
"""Test that reduce_shield() on non-SHIELD effects returns full damage."""
|
||||
effect = Effect(
|
||||
effect_id="burn_1",
|
||||
name="Burning",
|
||||
effect_type=EffectType.DOT,
|
||||
duration=3,
|
||||
power=5,
|
||||
)
|
||||
|
||||
remaining = effect.reduce_shield(50)
|
||||
assert remaining == 50 # All damage passes through for non-shield effects
|
||||
294
api/tests/test_model_selector.py
Normal file
294
api/tests/test_model_selector.py
Normal file
@@ -0,0 +1,294 @@
|
||||
"""
|
||||
Unit tests for model selector module.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from app.ai import (
|
||||
ModelSelector,
|
||||
ModelConfig,
|
||||
UserTier,
|
||||
ContextType,
|
||||
ModelType,
|
||||
)
|
||||
|
||||
|
||||
class TestModelSelector:
|
||||
"""Tests for ModelSelector class."""
|
||||
|
||||
def setup_method(self):
|
||||
"""Set up test fixtures."""
|
||||
self.selector = ModelSelector()
|
||||
|
||||
def test_initialization(self):
|
||||
"""Test ModelSelector initializes correctly."""
|
||||
assert self.selector is not None
|
||||
|
||||
# Test tier to model mapping
|
||||
def test_free_tier_gets_llama(self):
|
||||
"""Free tier should get Llama-3 8B."""
|
||||
config = self.selector.select_model(UserTier.FREE)
|
||||
assert config.model_type == ModelType.LLAMA_3_8B
|
||||
|
||||
def test_basic_tier_gets_haiku(self):
|
||||
"""Basic tier should get Claude Haiku."""
|
||||
config = self.selector.select_model(UserTier.BASIC)
|
||||
assert config.model_type == ModelType.CLAUDE_HAIKU
|
||||
|
||||
def test_premium_tier_gets_sonnet(self):
|
||||
"""Premium tier should get Claude Sonnet."""
|
||||
config = self.selector.select_model(UserTier.PREMIUM)
|
||||
assert config.model_type == ModelType.CLAUDE_SONNET
|
||||
|
||||
def test_elite_tier_gets_opus(self):
|
||||
"""Elite tier should get Claude Opus."""
|
||||
config = self.selector.select_model(UserTier.ELITE)
|
||||
assert config.model_type == ModelType.CLAUDE_SONNET_4
|
||||
|
||||
# Test token limits by tier (using STORY_PROGRESSION for full allocation)
|
||||
def test_free_tier_token_limit(self):
|
||||
"""Free tier should have 256 base tokens."""
|
||||
config = self.selector.select_model(UserTier.FREE, ContextType.STORY_PROGRESSION)
|
||||
assert config.max_tokens == 256
|
||||
|
||||
def test_basic_tier_token_limit(self):
|
||||
"""Basic tier should have 512 base tokens."""
|
||||
config = self.selector.select_model(UserTier.BASIC, ContextType.STORY_PROGRESSION)
|
||||
assert config.max_tokens == 512
|
||||
|
||||
def test_premium_tier_token_limit(self):
|
||||
"""Premium tier should have 1024 base tokens."""
|
||||
config = self.selector.select_model(UserTier.PREMIUM, ContextType.STORY_PROGRESSION)
|
||||
assert config.max_tokens == 1024
|
||||
|
||||
def test_elite_tier_token_limit(self):
|
||||
"""Elite tier should have 2048 base tokens."""
|
||||
config = self.selector.select_model(UserTier.ELITE, ContextType.STORY_PROGRESSION)
|
||||
assert config.max_tokens == 2048
|
||||
|
||||
# Test context-based token adjustments
|
||||
def test_story_progression_full_tokens(self):
|
||||
"""Story progression should use full token allocation."""
|
||||
config = self.selector.select_model(
|
||||
UserTier.PREMIUM,
|
||||
ContextType.STORY_PROGRESSION
|
||||
)
|
||||
# Full allocation = 1024 tokens for premium
|
||||
assert config.max_tokens == 1024
|
||||
|
||||
def test_combat_narration_reduced_tokens(self):
|
||||
"""Combat narration should use 75% of tokens."""
|
||||
config = self.selector.select_model(
|
||||
UserTier.PREMIUM,
|
||||
ContextType.COMBAT_NARRATION
|
||||
)
|
||||
# 75% of 1024 = 768
|
||||
assert config.max_tokens == 768
|
||||
|
||||
def test_quest_selection_half_tokens(self):
|
||||
"""Quest selection should use 50% of tokens."""
|
||||
config = self.selector.select_model(
|
||||
UserTier.PREMIUM,
|
||||
ContextType.QUEST_SELECTION
|
||||
)
|
||||
# 50% of 1024 = 512
|
||||
assert config.max_tokens == 512
|
||||
|
||||
def test_npc_dialogue_reduced_tokens(self):
|
||||
"""NPC dialogue should use 75% of tokens."""
|
||||
config = self.selector.select_model(
|
||||
UserTier.PREMIUM,
|
||||
ContextType.NPC_DIALOGUE
|
||||
)
|
||||
# 75% of 1024 = 768
|
||||
assert config.max_tokens == 768
|
||||
|
||||
def test_simple_response_half_tokens(self):
|
||||
"""Simple response should use 50% of tokens."""
|
||||
config = self.selector.select_model(
|
||||
UserTier.PREMIUM,
|
||||
ContextType.SIMPLE_RESPONSE
|
||||
)
|
||||
# 50% of 1024 = 512
|
||||
assert config.max_tokens == 512
|
||||
|
||||
# Test context-based temperature settings
|
||||
def test_story_progression_high_temperature(self):
|
||||
"""Story progression should have high temperature (0.9)."""
|
||||
config = self.selector.select_model(
|
||||
UserTier.PREMIUM,
|
||||
ContextType.STORY_PROGRESSION
|
||||
)
|
||||
assert config.temperature == 0.9
|
||||
|
||||
def test_combat_narration_medium_high_temperature(self):
|
||||
"""Combat narration should have medium-high temperature (0.8)."""
|
||||
config = self.selector.select_model(
|
||||
UserTier.PREMIUM,
|
||||
ContextType.COMBAT_NARRATION
|
||||
)
|
||||
assert config.temperature == 0.8
|
||||
|
||||
def test_quest_selection_low_temperature(self):
|
||||
"""Quest selection should have low temperature (0.5)."""
|
||||
config = self.selector.select_model(
|
||||
UserTier.PREMIUM,
|
||||
ContextType.QUEST_SELECTION
|
||||
)
|
||||
assert config.temperature == 0.5
|
||||
|
||||
def test_npc_dialogue_medium_temperature(self):
|
||||
"""NPC dialogue should have medium temperature (0.85)."""
|
||||
config = self.selector.select_model(
|
||||
UserTier.PREMIUM,
|
||||
ContextType.NPC_DIALOGUE
|
||||
)
|
||||
assert config.temperature == 0.85
|
||||
|
||||
def test_simple_response_balanced_temperature(self):
|
||||
"""Simple response should have balanced temperature (0.7)."""
|
||||
config = self.selector.select_model(
|
||||
UserTier.PREMIUM,
|
||||
ContextType.SIMPLE_RESPONSE
|
||||
)
|
||||
assert config.temperature == 0.7
|
||||
|
||||
# Test ModelConfig properties
|
||||
def test_model_config_model_property(self):
|
||||
"""ModelConfig.model should return model identifier string."""
|
||||
config = self.selector.select_model(UserTier.PREMIUM)
|
||||
assert config.model == "anthropic/claude-3.5-sonnet"
|
||||
|
||||
# Test get_model_for_tier method
|
||||
def test_get_model_for_tier_free(self):
|
||||
"""get_model_for_tier should return correct model for free tier."""
|
||||
model = self.selector.get_model_for_tier(UserTier.FREE)
|
||||
assert model == ModelType.LLAMA_3_8B
|
||||
|
||||
def test_get_model_for_tier_elite(self):
|
||||
"""get_model_for_tier should return correct model for elite tier."""
|
||||
model = self.selector.get_model_for_tier(UserTier.ELITE)
|
||||
assert model == ModelType.CLAUDE_SONNET_4
|
||||
|
||||
# Test get_tier_info method
|
||||
def test_get_tier_info_structure(self):
|
||||
"""get_tier_info should return complete tier information."""
|
||||
info = self.selector.get_tier_info(UserTier.PREMIUM)
|
||||
|
||||
assert "tier" in info
|
||||
assert "model" in info
|
||||
assert "model_name" in info
|
||||
assert "base_tokens" in info
|
||||
assert "quality" in info
|
||||
|
||||
def test_get_tier_info_premium_values(self):
|
||||
"""get_tier_info should return correct values for premium tier."""
|
||||
info = self.selector.get_tier_info(UserTier.PREMIUM)
|
||||
|
||||
assert info["tier"] == "premium"
|
||||
assert info["model"] == "anthropic/claude-3.5-sonnet"
|
||||
assert info["model_name"] == "Claude 3.5 Sonnet"
|
||||
assert info["base_tokens"] == 1024
|
||||
|
||||
def test_get_tier_info_free_values(self):
|
||||
"""get_tier_info should return correct values for free tier."""
|
||||
info = self.selector.get_tier_info(UserTier.FREE)
|
||||
|
||||
assert info["tier"] == "free"
|
||||
assert info["model_name"] == "Llama 3 8B"
|
||||
assert info["base_tokens"] == 256
|
||||
|
||||
# Test estimate_cost_per_request method
|
||||
def test_free_tier_zero_cost(self):
|
||||
"""Free tier should have zero cost."""
|
||||
cost = self.selector.estimate_cost_per_request(UserTier.FREE)
|
||||
assert cost == 0.0
|
||||
|
||||
def test_basic_tier_has_cost(self):
|
||||
"""Basic tier should have non-zero cost."""
|
||||
cost = self.selector.estimate_cost_per_request(UserTier.BASIC)
|
||||
assert cost > 0
|
||||
|
||||
def test_premium_tier_higher_cost(self):
|
||||
"""Premium tier should have higher cost than basic."""
|
||||
basic_cost = self.selector.estimate_cost_per_request(UserTier.BASIC)
|
||||
premium_cost = self.selector.estimate_cost_per_request(UserTier.PREMIUM)
|
||||
assert premium_cost > basic_cost
|
||||
|
||||
def test_elite_tier_highest_cost(self):
|
||||
"""Elite tier should have highest cost."""
|
||||
premium_cost = self.selector.estimate_cost_per_request(UserTier.PREMIUM)
|
||||
elite_cost = self.selector.estimate_cost_per_request(UserTier.ELITE)
|
||||
assert elite_cost > premium_cost
|
||||
|
||||
# Test all tier combinations
|
||||
def test_all_tiers_return_valid_config(self):
|
||||
"""All tiers should return valid ModelConfig objects."""
|
||||
for tier in UserTier:
|
||||
config = self.selector.select_model(tier)
|
||||
assert isinstance(config, ModelConfig)
|
||||
assert config.model_type in ModelType
|
||||
assert config.max_tokens > 0
|
||||
assert 0 <= config.temperature <= 1
|
||||
|
||||
# Test all context combinations
|
||||
def test_all_contexts_return_valid_config(self):
|
||||
"""All context types should return valid ModelConfig objects."""
|
||||
for context in ContextType:
|
||||
config = self.selector.select_model(UserTier.PREMIUM, context)
|
||||
assert isinstance(config, ModelConfig)
|
||||
assert config.max_tokens > 0
|
||||
assert 0 <= config.temperature <= 1
|
||||
|
||||
|
||||
class TestUserTierEnum:
|
||||
"""Tests for UserTier enum."""
|
||||
|
||||
def test_tier_values(self):
|
||||
"""Test UserTier enum values are correct strings."""
|
||||
assert UserTier.FREE.value == "free"
|
||||
assert UserTier.BASIC.value == "basic"
|
||||
assert UserTier.PREMIUM.value == "premium"
|
||||
assert UserTier.ELITE.value == "elite"
|
||||
|
||||
def test_tier_string_conversion(self):
|
||||
"""Test UserTier can be converted to string."""
|
||||
assert str(UserTier.FREE) == "UserTier.FREE"
|
||||
|
||||
|
||||
class TestContextTypeEnum:
|
||||
"""Tests for ContextType enum."""
|
||||
|
||||
def test_context_values(self):
|
||||
"""Test ContextType enum values are correct strings."""
|
||||
assert ContextType.STORY_PROGRESSION.value == "story_progression"
|
||||
assert ContextType.COMBAT_NARRATION.value == "combat_narration"
|
||||
assert ContextType.QUEST_SELECTION.value == "quest_selection"
|
||||
assert ContextType.NPC_DIALOGUE.value == "npc_dialogue"
|
||||
assert ContextType.SIMPLE_RESPONSE.value == "simple_response"
|
||||
|
||||
|
||||
class TestModelConfig:
|
||||
"""Tests for ModelConfig dataclass."""
|
||||
|
||||
def test_model_config_creation(self):
|
||||
"""Test ModelConfig can be created with valid data."""
|
||||
config = ModelConfig(
|
||||
model_type=ModelType.CLAUDE_SONNET,
|
||||
max_tokens=1024,
|
||||
temperature=0.9
|
||||
)
|
||||
|
||||
assert config.model_type == ModelType.CLAUDE_SONNET
|
||||
assert config.max_tokens == 1024
|
||||
assert config.temperature == 0.9
|
||||
|
||||
def test_model_property(self):
|
||||
"""Test model property returns model identifier."""
|
||||
config = ModelConfig(
|
||||
model_type=ModelType.LLAMA_3_8B,
|
||||
max_tokens=256,
|
||||
temperature=0.7
|
||||
)
|
||||
|
||||
assert config.model == "meta/meta-llama-3-8b-instruct"
|
||||
583
api/tests/test_narrative_generator.py
Normal file
583
api/tests/test_narrative_generator.py
Normal file
@@ -0,0 +1,583 @@
|
||||
"""
|
||||
Tests for the NarrativeGenerator wrapper.
|
||||
|
||||
These tests use mocked AI clients to verify the generator's
|
||||
behavior without making actual API calls.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from app.ai.narrative_generator import (
|
||||
NarrativeGenerator,
|
||||
NarrativeResponse,
|
||||
NarrativeGeneratorError,
|
||||
)
|
||||
from app.ai.replicate_client import ReplicateResponse, ReplicateClientError
|
||||
from app.ai.model_selector import UserTier, ContextType, ModelSelector, ModelConfig, ModelType
|
||||
from app.ai.prompt_templates import PromptTemplates
|
||||
|
||||
|
||||
# Test fixtures
|
||||
@pytest.fixture
|
||||
def mock_replicate_client():
|
||||
"""Create a mock Replicate client."""
|
||||
client = MagicMock()
|
||||
client.generate.return_value = ReplicateResponse(
|
||||
text="The tavern falls silent as you step through the doorway...",
|
||||
tokens_used=150,
|
||||
model="meta/meta-llama-3-8b-instruct",
|
||||
generation_time=2.5
|
||||
)
|
||||
return client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def generator(mock_replicate_client):
|
||||
"""Create a NarrativeGenerator with mocked dependencies."""
|
||||
return NarrativeGenerator(replicate_client=mock_replicate_client)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_character():
|
||||
"""Sample character data for tests."""
|
||||
return {
|
||||
"name": "Aldric",
|
||||
"level": 3,
|
||||
"player_class": "Fighter",
|
||||
"current_hp": 25,
|
||||
"max_hp": 30,
|
||||
"stats": {
|
||||
"strength": 16,
|
||||
"dexterity": 12,
|
||||
"constitution": 14,
|
||||
"intelligence": 10,
|
||||
"wisdom": 11,
|
||||
"charisma": 8
|
||||
},
|
||||
"skills": [
|
||||
{"name": "Sword Mastery", "level": 2},
|
||||
{"name": "Shield Block", "level": 1}
|
||||
],
|
||||
"effects": [],
|
||||
"completed_quests": []
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_game_state():
|
||||
"""Sample game state for tests."""
|
||||
return {
|
||||
"current_location": "The Rusty Anchor Tavern",
|
||||
"location_type": "TAVERN",
|
||||
"discovered_locations": ["Crossroads Village", "Dark Forest"],
|
||||
"active_quests": [],
|
||||
"time_of_day": "Evening"
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_combat_state():
|
||||
"""Sample combat state for tests."""
|
||||
return {
|
||||
"round_number": 2,
|
||||
"current_turn": "player",
|
||||
"enemies": [
|
||||
{
|
||||
"name": "Goblin Scout",
|
||||
"current_hp": 8,
|
||||
"max_hp": 12,
|
||||
"effects": []
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_npc():
|
||||
"""Sample NPC data for tests."""
|
||||
return {
|
||||
"name": "Barkeep Magnus",
|
||||
"role": "Tavern Owner",
|
||||
"personality": "Gruff but kind-hearted",
|
||||
"speaking_style": "Short sentences, occasional jokes",
|
||||
"goals": "Keep the tavern running, help adventurers"
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_quests():
|
||||
"""Sample eligible quests for tests."""
|
||||
return [
|
||||
{
|
||||
"quest_id": "goblin_cave",
|
||||
"name": "Clear the Goblin Cave",
|
||||
"difficulty": "EASY",
|
||||
"quest_giver": "Village Elder",
|
||||
"description": "A nearby cave has been overrun by goblins.",
|
||||
"narrative_hooks": [
|
||||
"Farmers complain about stolen livestock",
|
||||
"Goblin tracks spotted on the road"
|
||||
]
|
||||
},
|
||||
{
|
||||
"quest_id": "missing_merchant",
|
||||
"name": "Find the Missing Merchant",
|
||||
"difficulty": "MEDIUM",
|
||||
"quest_giver": "Guild Master",
|
||||
"description": "A merchant caravan has gone missing.",
|
||||
"narrative_hooks": [
|
||||
"The merchant was carrying valuable goods",
|
||||
"His family is worried sick"
|
||||
]
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
class TestNarrativeGeneratorInit:
|
||||
"""Tests for NarrativeGenerator initialization."""
|
||||
|
||||
def test_init_with_defaults(self):
|
||||
"""Test initialization with default dependencies."""
|
||||
with patch('app.ai.narrative_generator.ReplicateClient'):
|
||||
generator = NarrativeGenerator()
|
||||
assert generator.model_selector is not None
|
||||
assert generator.prompt_templates is not None
|
||||
|
||||
def test_init_with_custom_selector(self):
|
||||
"""Test initialization with custom model selector."""
|
||||
custom_selector = ModelSelector()
|
||||
generator = NarrativeGenerator(model_selector=custom_selector)
|
||||
assert generator.model_selector is custom_selector
|
||||
|
||||
def test_init_with_custom_client(self, mock_replicate_client):
|
||||
"""Test initialization with custom Replicate client."""
|
||||
generator = NarrativeGenerator(replicate_client=mock_replicate_client)
|
||||
assert generator.replicate_client is mock_replicate_client
|
||||
|
||||
|
||||
class TestGenerateStoryResponse:
|
||||
"""Tests for generate_story_response method."""
|
||||
|
||||
def test_basic_story_generation(
|
||||
self, generator, mock_replicate_client, sample_character, sample_game_state
|
||||
):
|
||||
"""Test basic story response generation."""
|
||||
response = generator.generate_story_response(
|
||||
character=sample_character,
|
||||
action="I search the room for hidden doors",
|
||||
game_state=sample_game_state,
|
||||
user_tier=UserTier.FREE
|
||||
)
|
||||
|
||||
assert isinstance(response, NarrativeResponse)
|
||||
assert response.narrative is not None
|
||||
assert response.tokens_used > 0
|
||||
assert response.context_type == "story_progression"
|
||||
assert response.generation_time > 0
|
||||
|
||||
# Verify client was called
|
||||
mock_replicate_client.generate.assert_called_once()
|
||||
|
||||
def test_story_with_conversation_history(
|
||||
self, generator, mock_replicate_client, sample_character, sample_game_state
|
||||
):
|
||||
"""Test story generation with conversation history."""
|
||||
history = [
|
||||
{
|
||||
"turn": 1,
|
||||
"action": "I enter the tavern",
|
||||
"dm_response": "The tavern is warm and inviting..."
|
||||
},
|
||||
{
|
||||
"turn": 2,
|
||||
"action": "I approach the bar",
|
||||
"dm_response": "The barkeep nods in greeting..."
|
||||
}
|
||||
]
|
||||
|
||||
response = generator.generate_story_response(
|
||||
character=sample_character,
|
||||
action="I ask about local rumors",
|
||||
game_state=sample_game_state,
|
||||
user_tier=UserTier.PREMIUM,
|
||||
conversation_history=history
|
||||
)
|
||||
|
||||
assert isinstance(response, NarrativeResponse)
|
||||
mock_replicate_client.generate.assert_called_once()
|
||||
|
||||
def test_story_with_world_context(
|
||||
self, generator, mock_replicate_client, sample_character, sample_game_state
|
||||
):
|
||||
"""Test story generation with additional world context."""
|
||||
response = generator.generate_story_response(
|
||||
character=sample_character,
|
||||
action="I look around",
|
||||
game_state=sample_game_state,
|
||||
user_tier=UserTier.ELITE,
|
||||
world_context="A festival is being celebrated in the village"
|
||||
)
|
||||
|
||||
assert isinstance(response, NarrativeResponse)
|
||||
|
||||
def test_story_uses_correct_model_for_tier(
|
||||
self, generator, mock_replicate_client, sample_character, sample_game_state
|
||||
):
|
||||
"""Test that correct model is selected based on tier."""
|
||||
for tier in UserTier:
|
||||
mock_replicate_client.generate.reset_mock()
|
||||
|
||||
generator.generate_story_response(
|
||||
character=sample_character,
|
||||
action="Test action",
|
||||
game_state=sample_game_state,
|
||||
user_tier=tier
|
||||
)
|
||||
|
||||
# Verify generate was called with appropriate model
|
||||
call_kwargs = mock_replicate_client.generate.call_args
|
||||
assert call_kwargs is not None
|
||||
|
||||
def test_story_handles_client_error(
|
||||
self, generator, mock_replicate_client, sample_character, sample_game_state
|
||||
):
|
||||
"""Test error handling when client fails."""
|
||||
mock_replicate_client.generate.side_effect = ReplicateClientError("API error")
|
||||
|
||||
with pytest.raises(NarrativeGeneratorError) as exc_info:
|
||||
generator.generate_story_response(
|
||||
character=sample_character,
|
||||
action="Test action",
|
||||
game_state=sample_game_state,
|
||||
user_tier=UserTier.FREE
|
||||
)
|
||||
|
||||
assert "AI generation failed" in str(exc_info.value)
|
||||
|
||||
|
||||
class TestGenerateCombatNarration:
|
||||
"""Tests for generate_combat_narration method."""
|
||||
|
||||
def test_basic_combat_narration(
|
||||
self, generator, mock_replicate_client, sample_character, sample_combat_state
|
||||
):
|
||||
"""Test basic combat narration generation."""
|
||||
action_result = {
|
||||
"hit": True,
|
||||
"damage": 8,
|
||||
"target": "Goblin Scout"
|
||||
}
|
||||
|
||||
response = generator.generate_combat_narration(
|
||||
character=sample_character,
|
||||
combat_state=sample_combat_state,
|
||||
action="swings their sword",
|
||||
action_result=action_result,
|
||||
user_tier=UserTier.BASIC
|
||||
)
|
||||
|
||||
assert isinstance(response, NarrativeResponse)
|
||||
assert response.context_type == "combat_narration"
|
||||
|
||||
def test_critical_hit_narration(
|
||||
self, generator, mock_replicate_client, sample_character, sample_combat_state
|
||||
):
|
||||
"""Test combat narration for critical hit."""
|
||||
action_result = {
|
||||
"hit": True,
|
||||
"damage": 16,
|
||||
"target": "Goblin Scout"
|
||||
}
|
||||
|
||||
response = generator.generate_combat_narration(
|
||||
character=sample_character,
|
||||
combat_state=sample_combat_state,
|
||||
action="strikes with precision",
|
||||
action_result=action_result,
|
||||
user_tier=UserTier.PREMIUM,
|
||||
is_critical=True
|
||||
)
|
||||
|
||||
assert isinstance(response, NarrativeResponse)
|
||||
|
||||
def test_finishing_blow_narration(
|
||||
self, generator, mock_replicate_client, sample_character, sample_combat_state
|
||||
):
|
||||
"""Test combat narration for finishing blow."""
|
||||
action_result = {
|
||||
"hit": True,
|
||||
"damage": 10,
|
||||
"target": "Goblin Scout"
|
||||
}
|
||||
|
||||
response = generator.generate_combat_narration(
|
||||
character=sample_character,
|
||||
combat_state=sample_combat_state,
|
||||
action="delivers the final blow",
|
||||
action_result=action_result,
|
||||
user_tier=UserTier.ELITE,
|
||||
is_finishing_blow=True
|
||||
)
|
||||
|
||||
assert isinstance(response, NarrativeResponse)
|
||||
|
||||
def test_miss_narration(
|
||||
self, generator, mock_replicate_client, sample_character, sample_combat_state
|
||||
):
|
||||
"""Test combat narration for a miss."""
|
||||
action_result = {
|
||||
"hit": False,
|
||||
"damage": 0,
|
||||
"target": "Goblin Scout"
|
||||
}
|
||||
|
||||
response = generator.generate_combat_narration(
|
||||
character=sample_character,
|
||||
combat_state=sample_combat_state,
|
||||
action="swings wildly",
|
||||
action_result=action_result,
|
||||
user_tier=UserTier.FREE
|
||||
)
|
||||
|
||||
assert isinstance(response, NarrativeResponse)
|
||||
|
||||
|
||||
class TestGenerateQuestSelection:
|
||||
"""Tests for generate_quest_selection method."""
|
||||
|
||||
def test_basic_quest_selection(
|
||||
self, generator, mock_replicate_client, sample_character, sample_game_state, sample_quests
|
||||
):
|
||||
"""Test basic quest selection."""
|
||||
# Mock response to return a valid quest_id
|
||||
mock_replicate_client.generate.return_value = ReplicateResponse(
|
||||
text="goblin_cave",
|
||||
tokens_used=50,
|
||||
model="meta/meta-llama-3-8b-instruct",
|
||||
generation_time=1.0
|
||||
)
|
||||
|
||||
quest_id = generator.generate_quest_selection(
|
||||
character=sample_character,
|
||||
eligible_quests=sample_quests,
|
||||
game_context=sample_game_state,
|
||||
user_tier=UserTier.FREE
|
||||
)
|
||||
|
||||
assert quest_id == "goblin_cave"
|
||||
|
||||
def test_quest_selection_with_recent_actions(
|
||||
self, generator, mock_replicate_client, sample_character, sample_game_state, sample_quests
|
||||
):
|
||||
"""Test quest selection with recent actions."""
|
||||
mock_replicate_client.generate.return_value = ReplicateResponse(
|
||||
text="missing_merchant",
|
||||
tokens_used=50,
|
||||
model="meta/meta-llama-3-8b-instruct",
|
||||
generation_time=1.0
|
||||
)
|
||||
|
||||
recent_actions = [
|
||||
"Asked about missing traders",
|
||||
"Investigated the market square"
|
||||
]
|
||||
|
||||
quest_id = generator.generate_quest_selection(
|
||||
character=sample_character,
|
||||
eligible_quests=sample_quests,
|
||||
game_context=sample_game_state,
|
||||
user_tier=UserTier.PREMIUM,
|
||||
recent_actions=recent_actions
|
||||
)
|
||||
|
||||
assert quest_id == "missing_merchant"
|
||||
|
||||
def test_quest_selection_invalid_response_fallback(
|
||||
self, generator, mock_replicate_client, sample_character, sample_game_state, sample_quests
|
||||
):
|
||||
"""Test fallback when AI returns invalid quest_id."""
|
||||
# Mock response with invalid quest_id
|
||||
mock_replicate_client.generate.return_value = ReplicateResponse(
|
||||
text="invalid_quest_id",
|
||||
tokens_used=50,
|
||||
model="meta/meta-llama-3-8b-instruct",
|
||||
generation_time=1.0
|
||||
)
|
||||
|
||||
quest_id = generator.generate_quest_selection(
|
||||
character=sample_character,
|
||||
eligible_quests=sample_quests,
|
||||
game_context=sample_game_state,
|
||||
user_tier=UserTier.FREE
|
||||
)
|
||||
|
||||
# Should fall back to first eligible quest
|
||||
assert quest_id == "goblin_cave"
|
||||
|
||||
def test_quest_selection_no_quests_error(
|
||||
self, generator, sample_character, sample_game_state
|
||||
):
|
||||
"""Test error when no eligible quests provided."""
|
||||
with pytest.raises(NarrativeGeneratorError) as exc_info:
|
||||
generator.generate_quest_selection(
|
||||
character=sample_character,
|
||||
eligible_quests=[],
|
||||
game_context=sample_game_state,
|
||||
user_tier=UserTier.FREE
|
||||
)
|
||||
|
||||
assert "No eligible quests" in str(exc_info.value)
|
||||
|
||||
|
||||
class TestGenerateNPCDialogue:
|
||||
"""Tests for generate_npc_dialogue method."""
|
||||
|
||||
def test_basic_npc_dialogue(
|
||||
self, generator, mock_replicate_client, sample_character, sample_npc, sample_game_state
|
||||
):
|
||||
"""Test basic NPC dialogue generation."""
|
||||
mock_replicate_client.generate.return_value = ReplicateResponse(
|
||||
text='*wipes down the bar* "Aye, rumors aplenty around here..."',
|
||||
tokens_used=100,
|
||||
model="anthropic/claude-3.5-haiku",
|
||||
generation_time=1.5
|
||||
)
|
||||
|
||||
response = generator.generate_npc_dialogue(
|
||||
character=sample_character,
|
||||
npc=sample_npc,
|
||||
conversation_topic="What rumors have you heard?",
|
||||
game_state=sample_game_state,
|
||||
user_tier=UserTier.BASIC
|
||||
)
|
||||
|
||||
assert isinstance(response, NarrativeResponse)
|
||||
assert response.context_type == "npc_dialogue"
|
||||
|
||||
def test_npc_dialogue_with_relationship(
|
||||
self, generator, mock_replicate_client, sample_character, sample_npc, sample_game_state
|
||||
):
|
||||
"""Test NPC dialogue with established relationship."""
|
||||
response = generator.generate_npc_dialogue(
|
||||
character=sample_character,
|
||||
npc=sample_npc,
|
||||
conversation_topic="Hello old friend",
|
||||
game_state=sample_game_state,
|
||||
user_tier=UserTier.PREMIUM,
|
||||
npc_relationship="Friendly - helped defend the tavern last month"
|
||||
)
|
||||
|
||||
assert isinstance(response, NarrativeResponse)
|
||||
|
||||
def test_npc_dialogue_with_previous_conversation(
|
||||
self, generator, mock_replicate_client, sample_character, sample_npc, sample_game_state
|
||||
):
|
||||
"""Test NPC dialogue with previous conversation history."""
|
||||
previous = [
|
||||
{
|
||||
"player_line": "What's on tap tonight?",
|
||||
"npc_response": "Got some fine ale from the southern vineyards."
|
||||
}
|
||||
]
|
||||
|
||||
response = generator.generate_npc_dialogue(
|
||||
character=sample_character,
|
||||
npc=sample_npc,
|
||||
conversation_topic="I'll take one",
|
||||
game_state=sample_game_state,
|
||||
user_tier=UserTier.FREE,
|
||||
previous_dialogue=previous
|
||||
)
|
||||
|
||||
assert isinstance(response, NarrativeResponse)
|
||||
|
||||
def test_npc_dialogue_with_special_knowledge(
|
||||
self, generator, mock_replicate_client, sample_character, sample_npc, sample_game_state
|
||||
):
|
||||
"""Test NPC dialogue with special knowledge."""
|
||||
response = generator.generate_npc_dialogue(
|
||||
character=sample_character,
|
||||
npc=sample_npc,
|
||||
conversation_topic="Have you seen anything strange?",
|
||||
game_state=sample_game_state,
|
||||
user_tier=UserTier.ELITE,
|
||||
npc_knowledge=["Secret passage in the cellar", "Hidden treasure map"]
|
||||
)
|
||||
|
||||
assert isinstance(response, NarrativeResponse)
|
||||
|
||||
|
||||
class TestModelSelection:
|
||||
"""Tests for model selection behavior."""
|
||||
|
||||
def test_free_tier_uses_llama(
|
||||
self, generator, mock_replicate_client, sample_character, sample_game_state
|
||||
):
|
||||
"""Test that free tier uses Llama model."""
|
||||
generator.generate_story_response(
|
||||
character=sample_character,
|
||||
action="Test",
|
||||
game_state=sample_game_state,
|
||||
user_tier=UserTier.FREE
|
||||
)
|
||||
|
||||
call_kwargs = mock_replicate_client.generate.call_args
|
||||
assert call_kwargs is not None
|
||||
# Model type should be Llama for free tier
|
||||
model_arg = call_kwargs.kwargs.get('model')
|
||||
if model_arg:
|
||||
assert "llama" in str(model_arg).lower() or model_arg == ModelType.LLAMA_3_8B
|
||||
|
||||
def test_different_contexts_use_different_temperatures(self, generator):
|
||||
"""Test that different contexts have different temperature settings."""
|
||||
# Get model configs for different contexts
|
||||
config_story = generator.model_selector.select_model(
|
||||
UserTier.FREE, ContextType.STORY_PROGRESSION
|
||||
)
|
||||
config_quest = generator.model_selector.select_model(
|
||||
UserTier.FREE, ContextType.QUEST_SELECTION
|
||||
)
|
||||
|
||||
# Story should have higher temperature (more creative)
|
||||
assert config_story.temperature > config_quest.temperature
|
||||
|
||||
|
||||
class TestErrorHandling:
|
||||
"""Tests for error handling behavior."""
|
||||
|
||||
def test_template_error_handling(self, mock_replicate_client):
|
||||
"""Test handling of template errors."""
|
||||
from app.ai.prompt_templates import PromptTemplateError
|
||||
|
||||
# Create generator with bad template path
|
||||
with patch.object(PromptTemplates, 'render') as mock_render:
|
||||
mock_render.side_effect = PromptTemplateError("Template not found")
|
||||
|
||||
generator = NarrativeGenerator(replicate_client=mock_replicate_client)
|
||||
|
||||
with pytest.raises(NarrativeGeneratorError) as exc_info:
|
||||
generator.generate_story_response(
|
||||
character={"name": "Test"},
|
||||
action="Test",
|
||||
game_state={"current_location": "Test"},
|
||||
user_tier=UserTier.FREE
|
||||
)
|
||||
|
||||
assert "Prompt template error" in str(exc_info.value)
|
||||
|
||||
def test_api_error_handling(
|
||||
self, generator, mock_replicate_client, sample_character, sample_game_state
|
||||
):
|
||||
"""Test handling of API errors."""
|
||||
mock_replicate_client.generate.side_effect = ReplicateClientError("Connection failed")
|
||||
|
||||
with pytest.raises(NarrativeGeneratorError) as exc_info:
|
||||
generator.generate_story_response(
|
||||
character=sample_character,
|
||||
action="Test",
|
||||
game_state=sample_game_state,
|
||||
user_tier=UserTier.FREE
|
||||
)
|
||||
|
||||
assert "AI generation failed" in str(exc_info.value)
|
||||
200
api/tests/test_origin_service.py
Normal file
200
api/tests/test_origin_service.py
Normal file
@@ -0,0 +1,200 @@
|
||||
"""
|
||||
Unit tests for OriginService - character origin loading and validation.
|
||||
|
||||
Tests verify that origins load correctly from YAML, contain all required
|
||||
data, and provide proper narrative hooks for the AI DM.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from app.services.origin_service import OriginService, get_origin_service
|
||||
from app.models.origins import Origin, StartingLocation, StartingBonus
|
||||
|
||||
|
||||
class TestOriginService:
|
||||
"""Test suite for OriginService functionality."""
|
||||
|
||||
@pytest.fixture
|
||||
def origin_service(self):
|
||||
"""Create a fresh OriginService instance for each test."""
|
||||
service = OriginService()
|
||||
service.clear_cache()
|
||||
return service
|
||||
|
||||
def test_service_initializes(self, origin_service):
|
||||
"""Test that OriginService initializes correctly."""
|
||||
assert origin_service is not None
|
||||
assert origin_service.data_file.exists()
|
||||
|
||||
def test_singleton_pattern(self):
|
||||
"""Test that get_origin_service returns a singleton."""
|
||||
service1 = get_origin_service()
|
||||
service2 = get_origin_service()
|
||||
assert service1 is service2
|
||||
|
||||
def test_load_all_origins(self, origin_service):
|
||||
"""Test loading all origins from YAML file."""
|
||||
origins = origin_service.load_all_origins()
|
||||
|
||||
assert len(origins) == 4
|
||||
origin_ids = [origin.id for origin in origins]
|
||||
assert "soul_revenant" in origin_ids
|
||||
assert "memory_thief" in origin_ids
|
||||
assert "shadow_apprentice" in origin_ids
|
||||
assert "escaped_captive" in origin_ids
|
||||
|
||||
def test_load_origin_by_id(self, origin_service):
|
||||
"""Test loading a specific origin by ID."""
|
||||
origin = origin_service.load_origin("soul_revenant")
|
||||
|
||||
assert origin is not None
|
||||
assert origin.id == "soul_revenant"
|
||||
assert origin.name == "Soul Revenant"
|
||||
assert isinstance(origin.description, str)
|
||||
assert len(origin.description) > 0
|
||||
|
||||
def test_origin_not_found(self, origin_service):
|
||||
"""Test that loading non-existent origin returns None."""
|
||||
origin = origin_service.load_origin("nonexistent_origin")
|
||||
assert origin is None
|
||||
|
||||
def test_origin_has_starting_location(self, origin_service):
|
||||
"""Test that all origins have valid starting locations."""
|
||||
origins = origin_service.load_all_origins()
|
||||
|
||||
for origin in origins:
|
||||
assert origin.starting_location is not None
|
||||
assert isinstance(origin.starting_location, StartingLocation)
|
||||
assert origin.starting_location.id
|
||||
assert origin.starting_location.name
|
||||
assert origin.starting_location.region
|
||||
assert origin.starting_location.description
|
||||
|
||||
def test_origin_has_narrative_hooks(self, origin_service):
|
||||
"""Test that all origins have narrative hooks for AI DM."""
|
||||
origins = origin_service.load_all_origins()
|
||||
|
||||
for origin in origins:
|
||||
assert origin.narrative_hooks is not None
|
||||
assert isinstance(origin.narrative_hooks, list)
|
||||
assert len(origin.narrative_hooks) > 0, f"{origin.id} has no narrative hooks"
|
||||
|
||||
def test_origin_has_starting_bonus(self, origin_service):
|
||||
"""Test that all origins have starting bonuses."""
|
||||
origins = origin_service.load_all_origins()
|
||||
|
||||
for origin in origins:
|
||||
assert origin.starting_bonus is not None
|
||||
assert isinstance(origin.starting_bonus, StartingBonus)
|
||||
assert origin.starting_bonus.trait
|
||||
assert origin.starting_bonus.description
|
||||
assert origin.starting_bonus.effect
|
||||
|
||||
def test_soul_revenant_details(self, origin_service):
|
||||
"""Test Soul Revenant origin has correct details."""
|
||||
origin = origin_service.load_origin("soul_revenant")
|
||||
|
||||
assert origin.name == "Soul Revenant"
|
||||
assert "centuries" in origin.description.lower()
|
||||
assert origin.starting_location.id == "forgotten_crypt"
|
||||
assert "Deathless Resolve" in origin.starting_bonus.trait
|
||||
assert any("past" in hook.lower() for hook in origin.narrative_hooks)
|
||||
|
||||
def test_memory_thief_details(self, origin_service):
|
||||
"""Test Memory Thief origin has correct details."""
|
||||
origin = origin_service.load_origin("memory_thief")
|
||||
|
||||
assert origin.name == "Memory Thief"
|
||||
assert "memory" in origin.description.lower()
|
||||
assert origin.starting_location.id == "thornfield_plains"
|
||||
assert "Blank Slate" in origin.starting_bonus.trait
|
||||
assert any("memory" in hook.lower() for hook in origin.narrative_hooks)
|
||||
|
||||
def test_shadow_apprentice_details(self, origin_service):
|
||||
"""Test Shadow Apprentice origin has correct details."""
|
||||
origin = origin_service.load_origin("shadow_apprentice")
|
||||
|
||||
assert origin.name == "Shadow Apprentice"
|
||||
assert "master" in origin.description.lower()
|
||||
assert origin.starting_location.id == "shadowfen"
|
||||
assert "Trained in Shadows" in origin.starting_bonus.trait
|
||||
assert any("master" in hook.lower() for hook in origin.narrative_hooks)
|
||||
|
||||
def test_escaped_captive_details(self, origin_service):
|
||||
"""Test Escaped Captive origin has correct details."""
|
||||
origin = origin_service.load_origin("escaped_captive")
|
||||
|
||||
assert origin.name == "The Escaped Captive"
|
||||
assert "prison" in origin.description.lower() or "ironpeak" in origin.description.lower()
|
||||
assert origin.starting_location.id == "ironpeak_pass"
|
||||
assert "Hardened Survivor" in origin.starting_bonus.trait
|
||||
assert any("prison" in hook.lower() or "past" in hook.lower() for hook in origin.narrative_hooks)
|
||||
|
||||
def test_origin_serialization(self, origin_service):
|
||||
"""Test that origins can be serialized to dict and back."""
|
||||
original = origin_service.load_origin("soul_revenant")
|
||||
|
||||
# Serialize to dict
|
||||
origin_dict = original.to_dict()
|
||||
assert isinstance(origin_dict, dict)
|
||||
assert origin_dict["id"] == "soul_revenant"
|
||||
assert origin_dict["name"] == "Soul Revenant"
|
||||
|
||||
# Deserialize from dict
|
||||
restored = Origin.from_dict(origin_dict)
|
||||
assert restored.id == original.id
|
||||
assert restored.name == original.name
|
||||
assert restored.description == original.description
|
||||
assert restored.starting_location.id == original.starting_location.id
|
||||
assert restored.starting_bonus.trait == original.starting_bonus.trait
|
||||
|
||||
def test_caching_works(self, origin_service):
|
||||
"""Test that caching improves performance on repeated loads."""
|
||||
# First load
|
||||
origin1 = origin_service.load_origin("soul_revenant")
|
||||
|
||||
# Second load should come from cache
|
||||
origin2 = origin_service.load_origin("soul_revenant")
|
||||
|
||||
# Should be the exact same instance from cache
|
||||
assert origin1 is origin2
|
||||
|
||||
def test_cache_clear(self, origin_service):
|
||||
"""Test that cache clearing works correctly."""
|
||||
# Load origins to populate cache
|
||||
origin_service.load_all_origins()
|
||||
assert len(origin_service._origins_cache) > 0
|
||||
|
||||
# Clear cache
|
||||
origin_service.clear_cache()
|
||||
assert len(origin_service._origins_cache) == 0
|
||||
assert origin_service._all_origins_loaded is False
|
||||
|
||||
def test_reload_origins(self, origin_service):
|
||||
"""Test that reloading clears cache and reloads from file."""
|
||||
# Load origins
|
||||
origins1 = origin_service.load_all_origins()
|
||||
|
||||
# Reload
|
||||
origins2 = origin_service.reload_origins()
|
||||
|
||||
# Should have same content but be fresh instances
|
||||
assert len(origins1) == len(origins2)
|
||||
assert all(o1.id == o2.id for o1, o2 in zip(origins1, origins2))
|
||||
|
||||
def test_get_all_origin_ids(self, origin_service):
|
||||
"""Test getting list of all origin IDs."""
|
||||
origin_ids = origin_service.get_all_origin_ids()
|
||||
|
||||
assert isinstance(origin_ids, list)
|
||||
assert len(origin_ids) == 4
|
||||
assert "soul_revenant" in origin_ids
|
||||
assert "memory_thief" in origin_ids
|
||||
assert "shadow_apprentice" in origin_ids
|
||||
assert "escaped_captive" in origin_ids
|
||||
|
||||
def test_get_origin_by_id_alias(self, origin_service):
|
||||
"""Test that get_origin_by_id is an alias for load_origin."""
|
||||
origin1 = origin_service.load_origin("soul_revenant")
|
||||
origin2 = origin_service.get_origin_by_id("soul_revenant")
|
||||
|
||||
assert origin1 is origin2
|
||||
321
api/tests/test_prompt_templates.py
Normal file
321
api/tests/test_prompt_templates.py
Normal file
@@ -0,0 +1,321 @@
|
||||
"""
|
||||
Unit tests for prompt templates module.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
|
||||
from app.ai import (
|
||||
PromptTemplates,
|
||||
PromptTemplateError,
|
||||
get_prompt_templates,
|
||||
render_prompt,
|
||||
)
|
||||
|
||||
|
||||
class TestPromptTemplates:
|
||||
"""Tests for PromptTemplates class."""
|
||||
|
||||
def setup_method(self):
|
||||
"""Set up test fixtures."""
|
||||
self.templates = PromptTemplates()
|
||||
|
||||
def test_initialization(self):
|
||||
"""Test PromptTemplates initializes correctly."""
|
||||
assert self.templates is not None
|
||||
assert self.templates.env is not None
|
||||
|
||||
def test_template_directory_exists(self):
|
||||
"""Test that template directory is created."""
|
||||
assert self.templates.template_dir.exists()
|
||||
|
||||
def test_get_template_names(self):
|
||||
"""Test listing available templates."""
|
||||
names = self.templates.get_template_names()
|
||||
assert isinstance(names, list)
|
||||
# Should have our 4 core templates
|
||||
assert 'story_action.j2' in names
|
||||
assert 'combat_action.j2' in names
|
||||
assert 'quest_offering.j2' in names
|
||||
assert 'npc_dialogue.j2' in names
|
||||
|
||||
def test_render_string_simple(self):
|
||||
"""Test rendering a simple template string."""
|
||||
result = self.templates.render_string(
|
||||
"Hello, {{ name }}!",
|
||||
name="Player"
|
||||
)
|
||||
assert result == "Hello, Player!"
|
||||
|
||||
def test_render_string_with_filter(self):
|
||||
"""Test rendering with custom filter."""
|
||||
result = self.templates.render_string(
|
||||
"{{ items | format_inventory }}",
|
||||
items=[
|
||||
{"name": "Sword", "quantity": 1},
|
||||
{"name": "Potion", "quantity": 3}
|
||||
]
|
||||
)
|
||||
assert "Sword" in result
|
||||
assert "Potion (x3)" in result
|
||||
|
||||
def test_render_story_action_template(self):
|
||||
"""Test rendering story_action template."""
|
||||
result = self.templates.render(
|
||||
"story_action.j2",
|
||||
character={
|
||||
"name": "Aldric",
|
||||
"level": 5,
|
||||
"player_class": "Warrior",
|
||||
"current_hp": 45,
|
||||
"max_hp": 50,
|
||||
"stats": {"strength": 16, "dexterity": 12},
|
||||
"skills": [],
|
||||
"effects": []
|
||||
},
|
||||
game_state={
|
||||
"current_location": "The Rusty Anchor",
|
||||
"location_type": "TAVERN",
|
||||
"active_quests": [],
|
||||
"discovered_locations": []
|
||||
},
|
||||
action="I look around the tavern for anyone suspicious"
|
||||
)
|
||||
|
||||
assert "Aldric" in result
|
||||
assert "Warrior" in result
|
||||
assert "Rusty Anchor" in result
|
||||
assert "suspicious" in result
|
||||
|
||||
def test_render_combat_action_template(self):
|
||||
"""Test rendering combat_action template."""
|
||||
result = self.templates.render(
|
||||
"combat_action.j2",
|
||||
character={
|
||||
"name": "Aldric",
|
||||
"level": 5,
|
||||
"player_class": "Warrior",
|
||||
"current_hp": 45,
|
||||
"max_hp": 50,
|
||||
"effects": []
|
||||
},
|
||||
combat_state={
|
||||
"round_number": 2,
|
||||
"current_turn": "Player",
|
||||
"enemies": [
|
||||
{
|
||||
"name": "Goblin",
|
||||
"current_hp": 8,
|
||||
"max_hp": 15,
|
||||
"effects": []
|
||||
}
|
||||
]
|
||||
},
|
||||
action="swings their sword at the Goblin",
|
||||
action_result={
|
||||
"hit": True,
|
||||
"damage": 7,
|
||||
"effects_applied": []
|
||||
},
|
||||
is_critical=False,
|
||||
is_finishing_blow=False
|
||||
)
|
||||
|
||||
assert "Aldric" in result
|
||||
assert "Goblin" in result
|
||||
assert "sword" in result
|
||||
|
||||
def test_render_quest_offering_template(self):
|
||||
"""Test rendering quest_offering template."""
|
||||
result = self.templates.render(
|
||||
"quest_offering.j2",
|
||||
character={
|
||||
"name": "Aldric",
|
||||
"level": 3,
|
||||
"player_class": "Warrior",
|
||||
"completed_quests": []
|
||||
},
|
||||
game_context={
|
||||
"current_location": "Village Square",
|
||||
"location_type": "TOWN",
|
||||
"active_quests": [],
|
||||
"world_events": []
|
||||
},
|
||||
eligible_quests=[
|
||||
{
|
||||
"quest_id": "quest_goblin_cave",
|
||||
"name": "Clear the Goblin Cave",
|
||||
"difficulty": "EASY",
|
||||
"quest_giver": "Village Elder",
|
||||
"description": "Goblins have been raiding farms",
|
||||
"narrative_hooks": [
|
||||
"Farmers complaining about lost livestock"
|
||||
]
|
||||
}
|
||||
],
|
||||
recent_actions=["Talked to locals"]
|
||||
)
|
||||
|
||||
assert "quest_goblin_cave" in result
|
||||
assert "Clear the Goblin Cave" in result
|
||||
assert "Village Elder" in result
|
||||
|
||||
def test_render_npc_dialogue_template(self):
|
||||
"""Test rendering npc_dialogue template."""
|
||||
result = self.templates.render(
|
||||
"npc_dialogue.j2",
|
||||
character={
|
||||
"name": "Aldric",
|
||||
"level": 5,
|
||||
"player_class": "Warrior"
|
||||
},
|
||||
npc={
|
||||
"name": "Grizzled Bartender",
|
||||
"role": "Tavern Owner",
|
||||
"personality": "Gruff but kind",
|
||||
"speaking_style": "Short sentences, common slang"
|
||||
},
|
||||
conversation_topic="What's the latest news around here?",
|
||||
game_state={
|
||||
"current_location": "The Rusty Anchor",
|
||||
"time_of_day": "Evening",
|
||||
"active_quests": []
|
||||
}
|
||||
)
|
||||
|
||||
assert "Grizzled Bartender" in result
|
||||
assert "Aldric" in result
|
||||
assert "news" in result
|
||||
|
||||
def test_format_inventory_filter_empty(self):
|
||||
"""Test format_inventory filter with empty list."""
|
||||
result = PromptTemplates._format_inventory([])
|
||||
assert result == "Empty inventory"
|
||||
|
||||
def test_format_inventory_filter_single(self):
|
||||
"""Test format_inventory filter with single item."""
|
||||
result = PromptTemplates._format_inventory([
|
||||
{"name": "Sword", "quantity": 1}
|
||||
])
|
||||
assert result == "Sword"
|
||||
|
||||
def test_format_inventory_filter_multiple(self):
|
||||
"""Test format_inventory filter with multiple items."""
|
||||
result = PromptTemplates._format_inventory([
|
||||
{"name": "Sword", "quantity": 1},
|
||||
{"name": "Shield", "quantity": 1},
|
||||
{"name": "Potion", "quantity": 5}
|
||||
])
|
||||
assert "Sword" in result
|
||||
assert "Shield" in result
|
||||
assert "Potion (x5)" in result
|
||||
|
||||
def test_format_inventory_filter_truncation(self):
|
||||
"""Test format_inventory filter truncates long lists."""
|
||||
items = [{"name": f"Item{i}", "quantity": 1} for i in range(15)]
|
||||
result = PromptTemplates._format_inventory(items, max_items=10)
|
||||
assert "and 5 more items" in result
|
||||
|
||||
def test_format_stats_filter(self):
|
||||
"""Test format_stats filter."""
|
||||
result = PromptTemplates._format_stats({
|
||||
"strength": 16,
|
||||
"dexterity": 14
|
||||
})
|
||||
assert "Strength: 16" in result
|
||||
assert "Dexterity: 14" in result
|
||||
|
||||
def test_format_stats_filter_empty(self):
|
||||
"""Test format_stats filter with empty dict."""
|
||||
result = PromptTemplates._format_stats({})
|
||||
assert result == "No stats available"
|
||||
|
||||
def test_format_skills_filter(self):
|
||||
"""Test format_skills filter."""
|
||||
result = PromptTemplates._format_skills([
|
||||
{"name": "Sword Mastery", "level": 3},
|
||||
{"name": "Shield Block", "level": 2}
|
||||
])
|
||||
assert "Sword Mastery (Lv.3)" in result
|
||||
assert "Shield Block (Lv.2)" in result
|
||||
|
||||
def test_format_skills_filter_empty(self):
|
||||
"""Test format_skills filter with empty list."""
|
||||
result = PromptTemplates._format_skills([])
|
||||
assert result == "No skills"
|
||||
|
||||
def test_format_effects_filter(self):
|
||||
"""Test format_effects filter."""
|
||||
result = PromptTemplates._format_effects([
|
||||
{"name": "Blessed", "remaining_turns": 3},
|
||||
{"name": "Strength Buff"}
|
||||
])
|
||||
assert "Blessed (3 turns)" in result
|
||||
assert "Strength Buff" in result
|
||||
|
||||
def test_format_effects_filter_empty(self):
|
||||
"""Test format_effects filter with empty list."""
|
||||
result = PromptTemplates._format_effects([])
|
||||
assert result == "No active effects"
|
||||
|
||||
def test_truncate_text_filter_short(self):
|
||||
"""Test truncate_text filter with short text."""
|
||||
result = PromptTemplates._truncate_text("Hello", 100)
|
||||
assert result == "Hello"
|
||||
|
||||
def test_truncate_text_filter_long(self):
|
||||
"""Test truncate_text filter with long text."""
|
||||
long_text = "A" * 150
|
||||
result = PromptTemplates._truncate_text(long_text, 100)
|
||||
assert len(result) == 100
|
||||
assert result.endswith("...")
|
||||
|
||||
def test_format_gold_filter(self):
|
||||
"""Test format_gold filter."""
|
||||
assert PromptTemplates._format_gold(1000) == "1,000 gold"
|
||||
assert PromptTemplates._format_gold(1000000) == "1,000,000 gold"
|
||||
assert PromptTemplates._format_gold(50) == "50 gold"
|
||||
|
||||
def test_invalid_template_raises_error(self):
|
||||
"""Test that invalid template raises PromptTemplateError."""
|
||||
with pytest.raises(PromptTemplateError):
|
||||
self.templates.render("nonexistent_template.j2")
|
||||
|
||||
def test_invalid_template_string_raises_error(self):
|
||||
"""Test that invalid template string raises PromptTemplateError."""
|
||||
with pytest.raises(PromptTemplateError):
|
||||
self.templates.render_string("{{ invalid syntax")
|
||||
|
||||
|
||||
class TestPromptTemplateConvenienceFunctions:
|
||||
"""Tests for module-level convenience functions."""
|
||||
|
||||
def test_get_prompt_templates_singleton(self):
|
||||
"""Test get_prompt_templates returns singleton."""
|
||||
templates1 = get_prompt_templates()
|
||||
templates2 = get_prompt_templates()
|
||||
assert templates1 is templates2
|
||||
|
||||
def test_render_prompt_function(self):
|
||||
"""Test render_prompt convenience function."""
|
||||
result = render_prompt(
|
||||
"story_action.j2",
|
||||
character={
|
||||
"name": "Test",
|
||||
"level": 1,
|
||||
"player_class": "Warrior",
|
||||
"current_hp": 10,
|
||||
"max_hp": 10,
|
||||
"stats": {},
|
||||
"skills": [],
|
||||
"effects": []
|
||||
},
|
||||
game_state={
|
||||
"current_location": "Test Location",
|
||||
"location_type": "TOWN",
|
||||
"active_quests": []
|
||||
},
|
||||
action="test action"
|
||||
)
|
||||
assert "Test" in result
|
||||
assert "test action" in result
|
||||
342
api/tests/test_rate_limiter_service.py
Normal file
342
api/tests/test_rate_limiter_service.py
Normal file
@@ -0,0 +1,342 @@
|
||||
"""
|
||||
Tests for RateLimiterService
|
||||
|
||||
Tests the tier-based rate limiting functionality including:
|
||||
- Daily limits per tier
|
||||
- Usage tracking and incrementing
|
||||
- Rate limit checks and exceptions
|
||||
- Reset functionality
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
from datetime import datetime, timezone, timedelta
|
||||
|
||||
from app.services.rate_limiter_service import (
|
||||
RateLimiterService,
|
||||
RateLimitExceeded,
|
||||
)
|
||||
from app.ai.model_selector import UserTier
|
||||
|
||||
|
||||
class TestRateLimiterService:
|
||||
"""Tests for RateLimiterService."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_redis(self):
|
||||
"""Create a mock Redis service."""
|
||||
mock = MagicMock()
|
||||
mock.get.return_value = None
|
||||
mock.incr.return_value = 1
|
||||
mock.expire.return_value = True
|
||||
mock.delete.return_value = 1
|
||||
return mock
|
||||
|
||||
@pytest.fixture
|
||||
def rate_limiter(self, mock_redis):
|
||||
"""Create a RateLimiterService with mock Redis."""
|
||||
return RateLimiterService(redis_service=mock_redis)
|
||||
|
||||
def test_tier_limits(self, rate_limiter):
|
||||
"""Test that tier limits are correctly defined."""
|
||||
assert rate_limiter.get_limit_for_tier(UserTier.FREE) == 20
|
||||
assert rate_limiter.get_limit_for_tier(UserTier.BASIC) == 50
|
||||
assert rate_limiter.get_limit_for_tier(UserTier.PREMIUM) == 100
|
||||
assert rate_limiter.get_limit_for_tier(UserTier.ELITE) == 200
|
||||
|
||||
def test_get_daily_key(self, rate_limiter):
|
||||
"""Test Redis key generation."""
|
||||
from datetime import date
|
||||
|
||||
key = rate_limiter._get_daily_key("user_123", date(2025, 1, 15))
|
||||
assert key == "rate_limit:daily:user_123:2025-01-15"
|
||||
|
||||
def test_get_current_usage_no_usage(self, rate_limiter, mock_redis):
|
||||
"""Test getting current usage when no usage exists."""
|
||||
mock_redis.get.return_value = None
|
||||
|
||||
usage = rate_limiter.get_current_usage("user_123")
|
||||
|
||||
assert usage == 0
|
||||
mock_redis.get.assert_called_once()
|
||||
|
||||
def test_get_current_usage_with_usage(self, rate_limiter, mock_redis):
|
||||
"""Test getting current usage when usage exists."""
|
||||
mock_redis.get.return_value = "15"
|
||||
|
||||
usage = rate_limiter.get_current_usage("user_123")
|
||||
|
||||
assert usage == 15
|
||||
|
||||
def test_check_rate_limit_under_limit(self, rate_limiter, mock_redis):
|
||||
"""Test that check passes when under limit."""
|
||||
mock_redis.get.return_value = "10"
|
||||
|
||||
# Should not raise
|
||||
rate_limiter.check_rate_limit("user_123", UserTier.FREE)
|
||||
|
||||
def test_check_rate_limit_at_limit(self, rate_limiter, mock_redis):
|
||||
"""Test that check fails when at limit."""
|
||||
mock_redis.get.return_value = "20" # Free tier limit
|
||||
|
||||
with pytest.raises(RateLimitExceeded) as exc_info:
|
||||
rate_limiter.check_rate_limit("user_123", UserTier.FREE)
|
||||
|
||||
exc = exc_info.value
|
||||
assert exc.user_id == "user_123"
|
||||
assert exc.user_tier == UserTier.FREE
|
||||
assert exc.limit == 20
|
||||
assert exc.current_usage == 20
|
||||
|
||||
def test_check_rate_limit_over_limit(self, rate_limiter, mock_redis):
|
||||
"""Test that check fails when over limit."""
|
||||
mock_redis.get.return_value = "25" # Over free tier limit
|
||||
|
||||
with pytest.raises(RateLimitExceeded):
|
||||
rate_limiter.check_rate_limit("user_123", UserTier.FREE)
|
||||
|
||||
def test_check_rate_limit_premium_tier(self, rate_limiter, mock_redis):
|
||||
"""Test that premium tier has higher limit."""
|
||||
mock_redis.get.return_value = "50" # Over free limit, under premium
|
||||
|
||||
# Should not raise for premium
|
||||
rate_limiter.check_rate_limit("user_123", UserTier.PREMIUM)
|
||||
|
||||
# Should raise for free
|
||||
with pytest.raises(RateLimitExceeded):
|
||||
rate_limiter.check_rate_limit("user_123", UserTier.FREE)
|
||||
|
||||
def test_increment_usage_first_time(self, rate_limiter, mock_redis):
|
||||
"""Test incrementing usage for the first time (sets expiration)."""
|
||||
mock_redis.incr.return_value = 1
|
||||
|
||||
new_count = rate_limiter.increment_usage("user_123")
|
||||
|
||||
assert new_count == 1
|
||||
mock_redis.incr.assert_called_once()
|
||||
mock_redis.expire.assert_called_once() # Should set expiration
|
||||
|
||||
def test_increment_usage_subsequent(self, rate_limiter, mock_redis):
|
||||
"""Test incrementing usage after first time (no expiration set)."""
|
||||
mock_redis.incr.return_value = 5
|
||||
|
||||
new_count = rate_limiter.increment_usage("user_123")
|
||||
|
||||
assert new_count == 5
|
||||
mock_redis.incr.assert_called_once()
|
||||
mock_redis.expire.assert_not_called() # Should NOT set expiration
|
||||
|
||||
def test_get_remaining_turns_full(self, rate_limiter, mock_redis):
|
||||
"""Test remaining turns when no usage."""
|
||||
mock_redis.get.return_value = None
|
||||
|
||||
remaining = rate_limiter.get_remaining_turns("user_123", UserTier.FREE)
|
||||
|
||||
assert remaining == 20
|
||||
|
||||
def test_get_remaining_turns_partial(self, rate_limiter, mock_redis):
|
||||
"""Test remaining turns with partial usage."""
|
||||
mock_redis.get.return_value = "12"
|
||||
|
||||
remaining = rate_limiter.get_remaining_turns("user_123", UserTier.FREE)
|
||||
|
||||
assert remaining == 8
|
||||
|
||||
def test_get_remaining_turns_exhausted(self, rate_limiter, mock_redis):
|
||||
"""Test remaining turns when limit reached."""
|
||||
mock_redis.get.return_value = "20"
|
||||
|
||||
remaining = rate_limiter.get_remaining_turns("user_123", UserTier.FREE)
|
||||
|
||||
assert remaining == 0
|
||||
|
||||
def test_get_remaining_turns_over_limit(self, rate_limiter, mock_redis):
|
||||
"""Test remaining turns when over limit (should be 0, not negative)."""
|
||||
mock_redis.get.return_value = "25"
|
||||
|
||||
remaining = rate_limiter.get_remaining_turns("user_123", UserTier.FREE)
|
||||
|
||||
assert remaining == 0
|
||||
|
||||
def test_get_usage_info(self, rate_limiter, mock_redis):
|
||||
"""Test getting comprehensive usage info."""
|
||||
mock_redis.get.return_value = "15"
|
||||
|
||||
info = rate_limiter.get_usage_info("user_123", UserTier.FREE)
|
||||
|
||||
assert info["user_id"] == "user_123"
|
||||
assert info["user_tier"] == "free"
|
||||
assert info["current_usage"] == 15
|
||||
assert info["daily_limit"] == 20
|
||||
assert info["remaining"] == 5
|
||||
assert info["is_limited"] is False
|
||||
assert "reset_time" in info
|
||||
|
||||
def test_get_usage_info_limited(self, rate_limiter, mock_redis):
|
||||
"""Test usage info when limited."""
|
||||
mock_redis.get.return_value = "20"
|
||||
|
||||
info = rate_limiter.get_usage_info("user_123", UserTier.FREE)
|
||||
|
||||
assert info["is_limited"] is True
|
||||
assert info["remaining"] == 0
|
||||
|
||||
def test_reset_usage(self, rate_limiter, mock_redis):
|
||||
"""Test resetting usage counter."""
|
||||
mock_redis.delete.return_value = 1
|
||||
|
||||
result = rate_limiter.reset_usage("user_123")
|
||||
|
||||
assert result is True
|
||||
mock_redis.delete.assert_called_once()
|
||||
|
||||
def test_reset_usage_no_key(self, rate_limiter, mock_redis):
|
||||
"""Test resetting when no usage exists."""
|
||||
mock_redis.delete.return_value = 0
|
||||
|
||||
result = rate_limiter.reset_usage("user_123")
|
||||
|
||||
assert result is False
|
||||
|
||||
|
||||
class TestRateLimitExceeded:
|
||||
"""Tests for RateLimitExceeded exception."""
|
||||
|
||||
def test_exception_attributes(self):
|
||||
"""Test that exception has correct attributes."""
|
||||
reset_time = datetime(2025, 1, 16, 0, 0, 0, tzinfo=timezone.utc)
|
||||
|
||||
exc = RateLimitExceeded(
|
||||
user_id="user_123",
|
||||
user_tier=UserTier.FREE,
|
||||
limit=20,
|
||||
current_usage=20,
|
||||
reset_time=reset_time
|
||||
)
|
||||
|
||||
assert exc.user_id == "user_123"
|
||||
assert exc.user_tier == UserTier.FREE
|
||||
assert exc.limit == 20
|
||||
assert exc.current_usage == 20
|
||||
assert exc.reset_time == reset_time
|
||||
|
||||
def test_exception_message(self):
|
||||
"""Test that exception message is formatted correctly."""
|
||||
reset_time = datetime(2025, 1, 16, 0, 0, 0, tzinfo=timezone.utc)
|
||||
|
||||
exc = RateLimitExceeded(
|
||||
user_id="user_123",
|
||||
user_tier=UserTier.FREE,
|
||||
limit=20,
|
||||
current_usage=20,
|
||||
reset_time=reset_time
|
||||
)
|
||||
|
||||
message = str(exc)
|
||||
assert "user_123" in message
|
||||
assert "free tier" in message
|
||||
assert "20/20" in message
|
||||
|
||||
|
||||
class TestRateLimiterIntegration:
|
||||
"""Integration tests for rate limiter workflow."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_redis(self):
|
||||
"""Create a mock Redis that simulates real behavior."""
|
||||
storage = {}
|
||||
|
||||
mock = MagicMock()
|
||||
|
||||
def mock_get(key):
|
||||
return storage.get(key)
|
||||
|
||||
def mock_incr(key):
|
||||
if key not in storage:
|
||||
storage[key] = 0
|
||||
storage[key] = int(storage[key]) + 1
|
||||
return storage[key]
|
||||
|
||||
def mock_delete(key):
|
||||
if key in storage:
|
||||
del storage[key]
|
||||
return 1
|
||||
return 0
|
||||
|
||||
mock.get.side_effect = mock_get
|
||||
mock.incr.side_effect = mock_incr
|
||||
mock.delete.side_effect = mock_delete
|
||||
mock.expire.return_value = True
|
||||
|
||||
return mock
|
||||
|
||||
@pytest.fixture
|
||||
def rate_limiter(self, mock_redis):
|
||||
"""Create rate limiter with simulated Redis."""
|
||||
return RateLimiterService(redis_service=mock_redis)
|
||||
|
||||
def test_full_workflow(self, rate_limiter):
|
||||
"""Test complete rate limiting workflow."""
|
||||
user_id = "user_123"
|
||||
tier = UserTier.FREE # 20 turns/day
|
||||
|
||||
# Initial state - should pass
|
||||
rate_limiter.check_rate_limit(user_id, tier)
|
||||
assert rate_limiter.get_remaining_turns(user_id, tier) == 20
|
||||
|
||||
# Use some turns
|
||||
for i in range(15):
|
||||
rate_limiter.check_rate_limit(user_id, tier)
|
||||
rate_limiter.increment_usage(user_id)
|
||||
|
||||
# Check remaining
|
||||
assert rate_limiter.get_remaining_turns(user_id, tier) == 5
|
||||
|
||||
# Use remaining turns
|
||||
for i in range(5):
|
||||
rate_limiter.check_rate_limit(user_id, tier)
|
||||
rate_limiter.increment_usage(user_id)
|
||||
|
||||
# Now should be limited
|
||||
assert rate_limiter.get_remaining_turns(user_id, tier) == 0
|
||||
|
||||
with pytest.raises(RateLimitExceeded):
|
||||
rate_limiter.check_rate_limit(user_id, tier)
|
||||
|
||||
def test_different_tiers_same_usage(self, rate_limiter):
|
||||
"""Test that same usage affects different tiers differently."""
|
||||
user_id = "user_123"
|
||||
|
||||
# Use 30 turns
|
||||
for _ in range(30):
|
||||
rate_limiter.increment_usage(user_id)
|
||||
|
||||
# Free tier (20) should be limited
|
||||
with pytest.raises(RateLimitExceeded):
|
||||
rate_limiter.check_rate_limit(user_id, UserTier.FREE)
|
||||
|
||||
# Basic tier (50) should not be limited
|
||||
rate_limiter.check_rate_limit(user_id, UserTier.BASIC)
|
||||
|
||||
# Premium tier (100) should not be limited
|
||||
rate_limiter.check_rate_limit(user_id, UserTier.PREMIUM)
|
||||
|
||||
def test_reset_clears_usage(self, rate_limiter):
|
||||
"""Test that reset allows new usage."""
|
||||
user_id = "user_123"
|
||||
tier = UserTier.FREE
|
||||
|
||||
# Use all turns
|
||||
for _ in range(20):
|
||||
rate_limiter.increment_usage(user_id)
|
||||
|
||||
# Should be limited
|
||||
with pytest.raises(RateLimitExceeded):
|
||||
rate_limiter.check_rate_limit(user_id, tier)
|
||||
|
||||
# Reset usage
|
||||
rate_limiter.reset_usage(user_id)
|
||||
|
||||
# Should be able to use again
|
||||
rate_limiter.check_rate_limit(user_id, tier)
|
||||
assert rate_limiter.get_remaining_turns(user_id, tier) == 20
|
||||
573
api/tests/test_redis_service.py
Normal file
573
api/tests/test_redis_service.py
Normal file
@@ -0,0 +1,573 @@
|
||||
"""
|
||||
Unit tests for Redis Service.
|
||||
|
||||
These tests use mocking to test the RedisService without requiring
|
||||
a real Redis connection.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock, patch, MagicMock
|
||||
import json
|
||||
|
||||
from redis.exceptions import RedisError, ConnectionError as RedisConnectionError
|
||||
|
||||
from app.services.redis_service import (
|
||||
RedisService,
|
||||
RedisServiceError,
|
||||
RedisConnectionFailed
|
||||
)
|
||||
|
||||
|
||||
class TestRedisServiceInit:
|
||||
"""Test RedisService initialization."""
|
||||
|
||||
@patch('app.services.redis_service.redis.ConnectionPool.from_url')
|
||||
@patch('app.services.redis_service.redis.Redis')
|
||||
def test_init_success(self, mock_redis_class, mock_pool_from_url):
|
||||
"""Test successful initialization."""
|
||||
# Setup mocks
|
||||
mock_pool = MagicMock()
|
||||
mock_pool_from_url.return_value = mock_pool
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.ping.return_value = True
|
||||
mock_redis_class.return_value = mock_client
|
||||
|
||||
# Create service
|
||||
service = RedisService(redis_url='redis://localhost:6379/0')
|
||||
|
||||
# Verify
|
||||
mock_pool_from_url.assert_called_once()
|
||||
mock_redis_class.assert_called_once_with(connection_pool=mock_pool)
|
||||
mock_client.ping.assert_called_once()
|
||||
assert service.client == mock_client
|
||||
|
||||
@patch('app.services.redis_service.redis.ConnectionPool.from_url')
|
||||
def test_init_connection_failed(self, mock_pool_from_url):
|
||||
"""Test initialization fails when Redis is unavailable."""
|
||||
mock_pool_from_url.side_effect = RedisConnectionError("Connection refused")
|
||||
|
||||
with pytest.raises(RedisConnectionFailed) as exc_info:
|
||||
RedisService(redis_url='redis://localhost:6379/0')
|
||||
|
||||
assert "Could not connect to Redis" in str(exc_info.value)
|
||||
|
||||
def test_init_missing_url(self):
|
||||
"""Test initialization fails with missing URL."""
|
||||
# Clear environment variable and pass empty string
|
||||
with patch.dict('os.environ', {'REDIS_URL': ''}, clear=True):
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
RedisService(redis_url='')
|
||||
|
||||
assert "Redis URL not configured" in str(exc_info.value)
|
||||
|
||||
|
||||
class TestRedisServiceOperations:
|
||||
"""Test Redis operations (get, set, delete, exists)."""
|
||||
|
||||
@pytest.fixture
|
||||
def redis_service(self):
|
||||
"""Create a RedisService with mocked client."""
|
||||
with patch('app.services.redis_service.redis.ConnectionPool.from_url') as mock_pool_from_url:
|
||||
with patch('app.services.redis_service.redis.Redis') as mock_redis_class:
|
||||
mock_pool = MagicMock()
|
||||
mock_pool_from_url.return_value = mock_pool
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.ping.return_value = True
|
||||
mock_redis_class.return_value = mock_client
|
||||
|
||||
service = RedisService(redis_url='redis://localhost:6379/0')
|
||||
yield service
|
||||
|
||||
def test_get_existing_key(self, redis_service):
|
||||
"""Test getting an existing key."""
|
||||
redis_service.client.get.return_value = "test_value"
|
||||
|
||||
result = redis_service.get("test_key")
|
||||
|
||||
redis_service.client.get.assert_called_once_with("test_key")
|
||||
assert result == "test_value"
|
||||
|
||||
def test_get_nonexistent_key(self, redis_service):
|
||||
"""Test getting a non-existent key returns None."""
|
||||
redis_service.client.get.return_value = None
|
||||
|
||||
result = redis_service.get("nonexistent_key")
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_get_error(self, redis_service):
|
||||
"""Test get raises RedisServiceError on failure."""
|
||||
redis_service.client.get.side_effect = RedisError("Connection lost")
|
||||
|
||||
with pytest.raises(RedisServiceError) as exc_info:
|
||||
redis_service.get("test_key")
|
||||
|
||||
assert "Failed to get key" in str(exc_info.value)
|
||||
|
||||
def test_set_basic(self, redis_service):
|
||||
"""Test basic set operation."""
|
||||
redis_service.client.set.return_value = True
|
||||
|
||||
result = redis_service.set("test_key", "test_value")
|
||||
|
||||
redis_service.client.set.assert_called_once_with(
|
||||
"test_key", "test_value", ex=None, nx=False, xx=False
|
||||
)
|
||||
assert result is True
|
||||
|
||||
def test_set_with_ttl(self, redis_service):
|
||||
"""Test set with TTL."""
|
||||
redis_service.client.set.return_value = True
|
||||
|
||||
result = redis_service.set("test_key", "test_value", ttl=3600)
|
||||
|
||||
redis_service.client.set.assert_called_once_with(
|
||||
"test_key", "test_value", ex=3600, nx=False, xx=False
|
||||
)
|
||||
assert result is True
|
||||
|
||||
def test_set_nx_success(self, redis_service):
|
||||
"""Test set with NX (only if not exists) - success."""
|
||||
redis_service.client.set.return_value = True
|
||||
|
||||
result = redis_service.set("test_key", "test_value", nx=True)
|
||||
|
||||
redis_service.client.set.assert_called_once_with(
|
||||
"test_key", "test_value", ex=None, nx=True, xx=False
|
||||
)
|
||||
assert result is True
|
||||
|
||||
def test_set_nx_failure(self, redis_service):
|
||||
"""Test set with NX fails when key exists."""
|
||||
redis_service.client.set.return_value = None # NX returns None if key exists
|
||||
|
||||
result = redis_service.set("test_key", "test_value", nx=True)
|
||||
|
||||
assert result is False
|
||||
|
||||
def test_set_error(self, redis_service):
|
||||
"""Test set raises RedisServiceError on failure."""
|
||||
redis_service.client.set.side_effect = RedisError("Connection lost")
|
||||
|
||||
with pytest.raises(RedisServiceError) as exc_info:
|
||||
redis_service.set("test_key", "test_value")
|
||||
|
||||
assert "Failed to set key" in str(exc_info.value)
|
||||
|
||||
def test_delete_single_key(self, redis_service):
|
||||
"""Test deleting a single key."""
|
||||
redis_service.client.delete.return_value = 1
|
||||
|
||||
result = redis_service.delete("test_key")
|
||||
|
||||
redis_service.client.delete.assert_called_once_with("test_key")
|
||||
assert result == 1
|
||||
|
||||
def test_delete_multiple_keys(self, redis_service):
|
||||
"""Test deleting multiple keys."""
|
||||
redis_service.client.delete.return_value = 3
|
||||
|
||||
result = redis_service.delete("key1", "key2", "key3")
|
||||
|
||||
redis_service.client.delete.assert_called_once_with("key1", "key2", "key3")
|
||||
assert result == 3
|
||||
|
||||
def test_delete_no_keys(self, redis_service):
|
||||
"""Test delete with no keys returns 0."""
|
||||
result = redis_service.delete()
|
||||
|
||||
redis_service.client.delete.assert_not_called()
|
||||
assert result == 0
|
||||
|
||||
def test_delete_error(self, redis_service):
|
||||
"""Test delete raises RedisServiceError on failure."""
|
||||
redis_service.client.delete.side_effect = RedisError("Connection lost")
|
||||
|
||||
with pytest.raises(RedisServiceError) as exc_info:
|
||||
redis_service.delete("test_key")
|
||||
|
||||
assert "Failed to delete keys" in str(exc_info.value)
|
||||
|
||||
def test_exists_single_key(self, redis_service):
|
||||
"""Test checking existence of a single key."""
|
||||
redis_service.client.exists.return_value = 1
|
||||
|
||||
result = redis_service.exists("test_key")
|
||||
|
||||
redis_service.client.exists.assert_called_once_with("test_key")
|
||||
assert result == 1
|
||||
|
||||
def test_exists_multiple_keys(self, redis_service):
|
||||
"""Test checking existence of multiple keys."""
|
||||
redis_service.client.exists.return_value = 2
|
||||
|
||||
result = redis_service.exists("key1", "key2", "key3")
|
||||
|
||||
redis_service.client.exists.assert_called_once_with("key1", "key2", "key3")
|
||||
assert result == 2
|
||||
|
||||
def test_exists_no_keys(self, redis_service):
|
||||
"""Test exists with no keys returns 0."""
|
||||
result = redis_service.exists()
|
||||
|
||||
redis_service.client.exists.assert_not_called()
|
||||
assert result == 0
|
||||
|
||||
def test_exists_error(self, redis_service):
|
||||
"""Test exists raises RedisServiceError on failure."""
|
||||
redis_service.client.exists.side_effect = RedisError("Connection lost")
|
||||
|
||||
with pytest.raises(RedisServiceError) as exc_info:
|
||||
redis_service.exists("test_key")
|
||||
|
||||
assert "Failed to check existence" in str(exc_info.value)
|
||||
|
||||
|
||||
class TestRedisServiceJSON:
|
||||
"""Test JSON serialization methods."""
|
||||
|
||||
@pytest.fixture
|
||||
def redis_service(self):
|
||||
"""Create a RedisService with mocked client."""
|
||||
with patch('app.services.redis_service.redis.ConnectionPool.from_url') as mock_pool_from_url:
|
||||
with patch('app.services.redis_service.redis.Redis') as mock_redis_class:
|
||||
mock_pool = MagicMock()
|
||||
mock_pool_from_url.return_value = mock_pool
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.ping.return_value = True
|
||||
mock_redis_class.return_value = mock_client
|
||||
|
||||
service = RedisService(redis_url='redis://localhost:6379/0')
|
||||
yield service
|
||||
|
||||
def test_get_json_success(self, redis_service):
|
||||
"""Test getting and deserializing JSON."""
|
||||
test_data = {"name": "test", "value": 123, "nested": {"key": "value"}}
|
||||
redis_service.client.get.return_value = json.dumps(test_data)
|
||||
|
||||
result = redis_service.get_json("test_key")
|
||||
|
||||
assert result == test_data
|
||||
|
||||
def test_get_json_none(self, redis_service):
|
||||
"""Test get_json returns None for non-existent key."""
|
||||
redis_service.client.get.return_value = None
|
||||
|
||||
result = redis_service.get_json("nonexistent_key")
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_get_json_invalid(self, redis_service):
|
||||
"""Test get_json raises error for invalid JSON."""
|
||||
redis_service.client.get.return_value = "not valid json {"
|
||||
|
||||
with pytest.raises(RedisServiceError) as exc_info:
|
||||
redis_service.get_json("test_key")
|
||||
|
||||
assert "Failed to decode JSON" in str(exc_info.value)
|
||||
|
||||
def test_set_json_success(self, redis_service):
|
||||
"""Test serializing and setting JSON."""
|
||||
redis_service.client.set.return_value = True
|
||||
test_data = {"name": "test", "value": 123}
|
||||
|
||||
result = redis_service.set_json("test_key", test_data, ttl=3600)
|
||||
|
||||
# Verify the value was serialized
|
||||
call_args = redis_service.client.set.call_args
|
||||
stored_value = call_args[0][1]
|
||||
assert json.loads(stored_value) == test_data
|
||||
assert result is True
|
||||
|
||||
def test_set_json_non_serializable(self, redis_service):
|
||||
"""Test set_json raises error for non-serializable data."""
|
||||
non_serializable = {"func": lambda x: x}
|
||||
|
||||
with pytest.raises(RedisServiceError) as exc_info:
|
||||
redis_service.set_json("test_key", non_serializable)
|
||||
|
||||
assert "Failed to serialize value" in str(exc_info.value)
|
||||
|
||||
|
||||
class TestRedisServiceTTL:
|
||||
"""Test TTL-related operations."""
|
||||
|
||||
@pytest.fixture
|
||||
def redis_service(self):
|
||||
"""Create a RedisService with mocked client."""
|
||||
with patch('app.services.redis_service.redis.ConnectionPool.from_url') as mock_pool_from_url:
|
||||
with patch('app.services.redis_service.redis.Redis') as mock_redis_class:
|
||||
mock_pool = MagicMock()
|
||||
mock_pool_from_url.return_value = mock_pool
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.ping.return_value = True
|
||||
mock_redis_class.return_value = mock_client
|
||||
|
||||
service = RedisService(redis_url='redis://localhost:6379/0')
|
||||
yield service
|
||||
|
||||
def test_expire_success(self, redis_service):
|
||||
"""Test setting expiration on existing key."""
|
||||
redis_service.client.expire.return_value = True
|
||||
|
||||
result = redis_service.expire("test_key", 3600)
|
||||
|
||||
redis_service.client.expire.assert_called_once_with("test_key", 3600)
|
||||
assert result is True
|
||||
|
||||
def test_expire_nonexistent_key(self, redis_service):
|
||||
"""Test expire returns False for non-existent key."""
|
||||
redis_service.client.expire.return_value = False
|
||||
|
||||
result = redis_service.expire("nonexistent_key", 3600)
|
||||
|
||||
assert result is False
|
||||
|
||||
def test_ttl_existing_key(self, redis_service):
|
||||
"""Test getting TTL of existing key."""
|
||||
redis_service.client.ttl.return_value = 3500
|
||||
|
||||
result = redis_service.ttl("test_key")
|
||||
|
||||
redis_service.client.ttl.assert_called_once_with("test_key")
|
||||
assert result == 3500
|
||||
|
||||
def test_ttl_no_expiry(self, redis_service):
|
||||
"""Test TTL returns -1 for key without expiry."""
|
||||
redis_service.client.ttl.return_value = -1
|
||||
|
||||
result = redis_service.ttl("test_key")
|
||||
|
||||
assert result == -1
|
||||
|
||||
def test_ttl_nonexistent_key(self, redis_service):
|
||||
"""Test TTL returns -2 for non-existent key."""
|
||||
redis_service.client.ttl.return_value = -2
|
||||
|
||||
result = redis_service.ttl("test_key")
|
||||
|
||||
assert result == -2
|
||||
|
||||
|
||||
class TestRedisServiceIncrement:
|
||||
"""Test increment/decrement operations."""
|
||||
|
||||
@pytest.fixture
|
||||
def redis_service(self):
|
||||
"""Create a RedisService with mocked client."""
|
||||
with patch('app.services.redis_service.redis.ConnectionPool.from_url') as mock_pool_from_url:
|
||||
with patch('app.services.redis_service.redis.Redis') as mock_redis_class:
|
||||
mock_pool = MagicMock()
|
||||
mock_pool_from_url.return_value = mock_pool
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.ping.return_value = True
|
||||
mock_redis_class.return_value = mock_client
|
||||
|
||||
service = RedisService(redis_url='redis://localhost:6379/0')
|
||||
yield service
|
||||
|
||||
def test_incr_default(self, redis_service):
|
||||
"""Test incrementing by default amount (1)."""
|
||||
redis_service.client.incrby.return_value = 5
|
||||
|
||||
result = redis_service.incr("counter")
|
||||
|
||||
redis_service.client.incrby.assert_called_once_with("counter", 1)
|
||||
assert result == 5
|
||||
|
||||
def test_incr_custom_amount(self, redis_service):
|
||||
"""Test incrementing by custom amount."""
|
||||
redis_service.client.incrby.return_value = 15
|
||||
|
||||
result = redis_service.incr("counter", 10)
|
||||
|
||||
redis_service.client.incrby.assert_called_once_with("counter", 10)
|
||||
assert result == 15
|
||||
|
||||
def test_decr_default(self, redis_service):
|
||||
"""Test decrementing by default amount (1)."""
|
||||
redis_service.client.decrby.return_value = 4
|
||||
|
||||
result = redis_service.decr("counter")
|
||||
|
||||
redis_service.client.decrby.assert_called_once_with("counter", 1)
|
||||
assert result == 4
|
||||
|
||||
def test_decr_custom_amount(self, redis_service):
|
||||
"""Test decrementing by custom amount."""
|
||||
redis_service.client.decrby.return_value = 0
|
||||
|
||||
result = redis_service.decr("counter", 5)
|
||||
|
||||
redis_service.client.decrby.assert_called_once_with("counter", 5)
|
||||
assert result == 0
|
||||
|
||||
|
||||
class TestRedisServiceHealth:
|
||||
"""Test health check and info operations."""
|
||||
|
||||
@pytest.fixture
|
||||
def redis_service(self):
|
||||
"""Create a RedisService with mocked client."""
|
||||
with patch('app.services.redis_service.redis.ConnectionPool.from_url') as mock_pool_from_url:
|
||||
with patch('app.services.redis_service.redis.Redis') as mock_redis_class:
|
||||
mock_pool = MagicMock()
|
||||
mock_pool_from_url.return_value = mock_pool
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.ping.return_value = True
|
||||
mock_redis_class.return_value = mock_client
|
||||
|
||||
service = RedisService(redis_url='redis://localhost:6379/0')
|
||||
yield service
|
||||
|
||||
def test_health_check_success(self, redis_service):
|
||||
"""Test health check when Redis is healthy."""
|
||||
redis_service.client.ping.return_value = True
|
||||
|
||||
result = redis_service.health_check()
|
||||
|
||||
assert result is True
|
||||
|
||||
def test_health_check_failure(self, redis_service):
|
||||
"""Test health check when Redis is unhealthy."""
|
||||
redis_service.client.ping.side_effect = RedisError("Connection lost")
|
||||
|
||||
result = redis_service.health_check()
|
||||
|
||||
assert result is False
|
||||
|
||||
def test_info_success(self, redis_service):
|
||||
"""Test getting Redis info."""
|
||||
mock_info = {
|
||||
'redis_version': '7.0.0',
|
||||
'used_memory': 1000000,
|
||||
'connected_clients': 5
|
||||
}
|
||||
redis_service.client.info.return_value = mock_info
|
||||
|
||||
result = redis_service.info()
|
||||
|
||||
assert result == mock_info
|
||||
|
||||
def test_info_error(self, redis_service):
|
||||
"""Test info raises error on failure."""
|
||||
redis_service.client.info.side_effect = RedisError("Connection lost")
|
||||
|
||||
with pytest.raises(RedisServiceError) as exc_info:
|
||||
redis_service.info()
|
||||
|
||||
assert "Failed to get Redis info" in str(exc_info.value)
|
||||
|
||||
|
||||
class TestRedisServiceUtility:
|
||||
"""Test utility methods."""
|
||||
|
||||
@pytest.fixture
|
||||
def redis_service(self):
|
||||
"""Create a RedisService with mocked client."""
|
||||
with patch('app.services.redis_service.redis.ConnectionPool.from_url') as mock_pool_from_url:
|
||||
with patch('app.services.redis_service.redis.Redis') as mock_redis_class:
|
||||
mock_pool = MagicMock()
|
||||
mock_pool_from_url.return_value = mock_pool
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.ping.return_value = True
|
||||
mock_redis_class.return_value = mock_client
|
||||
|
||||
service = RedisService(redis_url='redis://localhost:6379/0')
|
||||
yield service
|
||||
|
||||
def test_flush_db(self, redis_service):
|
||||
"""Test flushing database."""
|
||||
redis_service.client.flushdb.return_value = True
|
||||
|
||||
result = redis_service.flush_db()
|
||||
|
||||
redis_service.client.flushdb.assert_called_once()
|
||||
assert result is True
|
||||
|
||||
def test_close(self, redis_service):
|
||||
"""Test closing connection pool."""
|
||||
redis_service.close()
|
||||
|
||||
redis_service.pool.disconnect.assert_called_once()
|
||||
|
||||
def test_context_manager(self, redis_service):
|
||||
"""Test using service as context manager."""
|
||||
with redis_service as service:
|
||||
assert service is not None
|
||||
|
||||
redis_service.pool.disconnect.assert_called_once()
|
||||
|
||||
def test_sanitize_url_with_password(self, redis_service):
|
||||
"""Test URL sanitization masks password."""
|
||||
url = "redis://user:secretpassword@localhost:6379/0"
|
||||
|
||||
result = redis_service._sanitize_url(url)
|
||||
|
||||
assert "secretpassword" not in result
|
||||
assert "***" in result
|
||||
assert "localhost:6379/0" in result
|
||||
|
||||
def test_sanitize_url_without_password(self, redis_service):
|
||||
"""Test URL sanitization with no password."""
|
||||
url = "redis://localhost:6379/0"
|
||||
|
||||
result = redis_service._sanitize_url(url)
|
||||
|
||||
assert result == url
|
||||
|
||||
|
||||
class TestRedisServiceIntegration:
|
||||
"""Integration-style tests that verify the flow of operations."""
|
||||
|
||||
@pytest.fixture
|
||||
def redis_service(self):
|
||||
"""Create a RedisService with mocked client."""
|
||||
with patch('app.services.redis_service.redis.ConnectionPool.from_url') as mock_pool_from_url:
|
||||
with patch('app.services.redis_service.redis.Redis') as mock_redis_class:
|
||||
mock_pool = MagicMock()
|
||||
mock_pool_from_url.return_value = mock_pool
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.ping.return_value = True
|
||||
mock_redis_class.return_value = mock_client
|
||||
|
||||
service = RedisService(redis_url='redis://localhost:6379/0')
|
||||
yield service
|
||||
|
||||
def test_set_then_get(self, redis_service):
|
||||
"""Test setting and then getting a value."""
|
||||
# Set
|
||||
redis_service.client.set.return_value = True
|
||||
redis_service.set("test_key", "test_value")
|
||||
|
||||
# Get
|
||||
redis_service.client.get.return_value = "test_value"
|
||||
result = redis_service.get("test_key")
|
||||
|
||||
assert result == "test_value"
|
||||
|
||||
def test_json_roundtrip(self, redis_service):
|
||||
"""Test JSON serialization roundtrip."""
|
||||
test_data = {
|
||||
"user_id": "user_123",
|
||||
"tokens_used": 450,
|
||||
"model": "claude-3-5-haiku"
|
||||
}
|
||||
|
||||
# Set JSON
|
||||
redis_service.client.set.return_value = True
|
||||
redis_service.set_json("job_result", test_data, ttl=3600)
|
||||
|
||||
# Get JSON
|
||||
redis_service.client.get.return_value = json.dumps(test_data)
|
||||
result = redis_service.get_json("job_result")
|
||||
|
||||
assert result == test_data
|
||||
462
api/tests/test_replicate_client.py
Normal file
462
api/tests/test_replicate_client.py
Normal file
@@ -0,0 +1,462 @@
|
||||
"""
|
||||
Tests for Replicate API client.
|
||||
|
||||
Tests cover initialization, prompt formatting, generation,
|
||||
retry logic, and error handling.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
from app.ai.replicate_client import (
|
||||
ReplicateClient,
|
||||
ReplicateResponse,
|
||||
ReplicateClientError,
|
||||
ReplicateAPIError,
|
||||
ReplicateRateLimitError,
|
||||
ReplicateTimeoutError,
|
||||
ModelType,
|
||||
)
|
||||
|
||||
|
||||
class TestReplicateClientInit:
|
||||
"""Tests for ReplicateClient initialization."""
|
||||
|
||||
@patch('app.ai.replicate_client.get_config')
|
||||
def test_init_with_token(self, mock_config):
|
||||
"""Test initialization with explicit API token."""
|
||||
mock_config.return_value = MagicMock(
|
||||
replicate_api_token=None,
|
||||
REPLICATE_MODEL=None
|
||||
)
|
||||
|
||||
client = ReplicateClient(api_token="test_token_123")
|
||||
|
||||
assert client.api_token == "test_token_123"
|
||||
assert client.model == ReplicateClient.DEFAULT_MODEL.value
|
||||
|
||||
@patch('app.ai.replicate_client.get_config')
|
||||
def test_init_from_config(self, mock_config):
|
||||
"""Test initialization from config."""
|
||||
mock_config.return_value = MagicMock(
|
||||
replicate_api_token="config_token",
|
||||
REPLICATE_MODEL="custom/model"
|
||||
)
|
||||
|
||||
client = ReplicateClient()
|
||||
|
||||
assert client.api_token == "config_token"
|
||||
assert client.model == "custom/model"
|
||||
|
||||
@patch('app.ai.replicate_client.get_config')
|
||||
def test_init_missing_token(self, mock_config):
|
||||
"""Test initialization fails without API token."""
|
||||
mock_config.return_value = MagicMock(
|
||||
replicate_api_token=None,
|
||||
REPLICATE_MODEL=None
|
||||
)
|
||||
|
||||
with pytest.raises(ReplicateClientError) as exc_info:
|
||||
ReplicateClient()
|
||||
|
||||
assert "API token not configured" in str(exc_info.value)
|
||||
|
||||
@patch('app.ai.replicate_client.get_config')
|
||||
def test_init_custom_model(self, mock_config):
|
||||
"""Test initialization with custom model."""
|
||||
mock_config.return_value = MagicMock(
|
||||
replicate_api_token="token",
|
||||
REPLICATE_MODEL=None
|
||||
)
|
||||
|
||||
client = ReplicateClient(model="meta/llama-2-70b")
|
||||
|
||||
assert client.model == "meta/llama-2-70b"
|
||||
|
||||
|
||||
class TestPromptFormatting:
|
||||
"""Tests for Llama-3 prompt formatting."""
|
||||
|
||||
@patch('app.ai.replicate_client.get_config')
|
||||
def test_format_prompt_user_only(self, mock_config):
|
||||
"""Test formatting with only user prompt."""
|
||||
mock_config.return_value = MagicMock(
|
||||
replicate_api_token="token",
|
||||
REPLICATE_MODEL=None
|
||||
)
|
||||
|
||||
client = ReplicateClient()
|
||||
formatted = client._format_llama_prompt("Hello world")
|
||||
|
||||
assert "<|begin_of_text|>" in formatted
|
||||
assert "<|start_header_id|>user<|end_header_id|>" in formatted
|
||||
assert "Hello world" in formatted
|
||||
assert "<|start_header_id|>assistant<|end_header_id|>" in formatted
|
||||
# No system header without system prompt
|
||||
assert "system<|end_header_id|>" not in formatted
|
||||
|
||||
@patch('app.ai.replicate_client.get_config')
|
||||
def test_format_prompt_with_system(self, mock_config):
|
||||
"""Test formatting with system and user prompts."""
|
||||
mock_config.return_value = MagicMock(
|
||||
replicate_api_token="token",
|
||||
REPLICATE_MODEL=None
|
||||
)
|
||||
|
||||
client = ReplicateClient()
|
||||
formatted = client._format_llama_prompt(
|
||||
"What is 2+2?",
|
||||
system_prompt="You are a helpful assistant."
|
||||
)
|
||||
|
||||
assert "<|start_header_id|>system<|end_header_id|>" in formatted
|
||||
assert "You are a helpful assistant." in formatted
|
||||
assert "<|start_header_id|>user<|end_header_id|>" in formatted
|
||||
assert "What is 2+2?" in formatted
|
||||
|
||||
|
||||
class TestGenerate:
|
||||
"""Tests for text generation."""
|
||||
|
||||
@patch('app.ai.replicate_client.get_config')
|
||||
@patch('app.ai.replicate_client.replicate')
|
||||
def test_generate_success(self, mock_replicate, mock_config):
|
||||
"""Test successful text generation."""
|
||||
mock_config.return_value = MagicMock(
|
||||
replicate_api_token="token",
|
||||
REPLICATE_MODEL=None
|
||||
)
|
||||
|
||||
# Mock streaming response
|
||||
mock_replicate.run.return_value = iter(["Hello ", "world ", "!"])
|
||||
|
||||
client = ReplicateClient()
|
||||
response = client.generate("Say hello")
|
||||
|
||||
assert isinstance(response, ReplicateResponse)
|
||||
assert response.text == "Hello world !"
|
||||
assert response.tokens_used > 0
|
||||
assert response.model == ReplicateClient.DEFAULT_MODEL.value
|
||||
assert response.generation_time > 0
|
||||
|
||||
@patch('app.ai.replicate_client.get_config')
|
||||
@patch('app.ai.replicate_client.replicate')
|
||||
def test_generate_with_parameters(self, mock_replicate, mock_config):
|
||||
"""Test generation with custom parameters."""
|
||||
mock_config.return_value = MagicMock(
|
||||
replicate_api_token="token",
|
||||
REPLICATE_MODEL=None
|
||||
)
|
||||
|
||||
mock_replicate.run.return_value = iter(["Response"])
|
||||
|
||||
client = ReplicateClient()
|
||||
response = client.generate(
|
||||
prompt="Test",
|
||||
system_prompt="Be concise",
|
||||
max_tokens=100,
|
||||
temperature=0.5,
|
||||
top_p=0.8,
|
||||
timeout=60
|
||||
)
|
||||
|
||||
# Verify parameters were passed
|
||||
call_args = mock_replicate.run.call_args
|
||||
assert call_args[1]['input']['max_tokens'] == 100
|
||||
assert call_args[1]['input']['temperature'] == 0.5
|
||||
assert call_args[1]['input']['top_p'] == 0.8
|
||||
|
||||
@patch('app.ai.replicate_client.get_config')
|
||||
@patch('app.ai.replicate_client.replicate')
|
||||
def test_generate_string_response(self, mock_replicate, mock_config):
|
||||
"""Test handling string response (non-streaming)."""
|
||||
mock_config.return_value = MagicMock(
|
||||
replicate_api_token="token",
|
||||
REPLICATE_MODEL=None
|
||||
)
|
||||
|
||||
mock_replicate.run.return_value = "Direct string response"
|
||||
|
||||
client = ReplicateClient()
|
||||
response = client.generate("Test")
|
||||
|
||||
assert response.text == "Direct string response"
|
||||
|
||||
|
||||
class TestRetryLogic:
|
||||
"""Tests for retry and error handling."""
|
||||
|
||||
@patch('app.ai.replicate_client.get_config')
|
||||
@patch('app.ai.replicate_client.replicate')
|
||||
@patch('app.ai.replicate_client.time.sleep')
|
||||
def test_retry_on_rate_limit(self, mock_sleep, mock_replicate, mock_config):
|
||||
"""Test retry logic on rate limit errors."""
|
||||
mock_config.return_value = MagicMock(
|
||||
replicate_api_token="token",
|
||||
REPLICATE_MODEL=None
|
||||
)
|
||||
|
||||
# First call raises rate limit, second succeeds
|
||||
mock_replicate.exceptions.ReplicateError = Exception
|
||||
mock_replicate.run.side_effect = [
|
||||
Exception("Rate limit exceeded 429"),
|
||||
iter(["Success"])
|
||||
]
|
||||
|
||||
client = ReplicateClient()
|
||||
response = client.generate("Test")
|
||||
|
||||
assert response.text == "Success"
|
||||
assert mock_sleep.called # Verify backoff happened
|
||||
|
||||
@patch('app.ai.replicate_client.get_config')
|
||||
@patch('app.ai.replicate_client.replicate')
|
||||
@patch('app.ai.replicate_client.time.sleep')
|
||||
def test_max_retries_exceeded(self, mock_sleep, mock_replicate, mock_config):
|
||||
"""Test that max retries raises error."""
|
||||
mock_config.return_value = MagicMock(
|
||||
replicate_api_token="token",
|
||||
REPLICATE_MODEL=None
|
||||
)
|
||||
|
||||
# All calls fail with rate limit
|
||||
mock_replicate.exceptions.ReplicateError = Exception
|
||||
mock_replicate.run.side_effect = Exception("Rate limit exceeded 429")
|
||||
|
||||
client = ReplicateClient()
|
||||
|
||||
with pytest.raises(ReplicateRateLimitError):
|
||||
client.generate("Test")
|
||||
|
||||
# Should have retried MAX_RETRIES times
|
||||
assert mock_replicate.run.call_count == ReplicateClient.MAX_RETRIES
|
||||
|
||||
@patch('app.ai.replicate_client.get_config')
|
||||
@patch('app.ai.replicate_client.replicate')
|
||||
def test_timeout_error(self, mock_replicate, mock_config):
|
||||
"""Test timeout error handling."""
|
||||
mock_config.return_value = MagicMock(
|
||||
replicate_api_token="token",
|
||||
REPLICATE_MODEL=None
|
||||
)
|
||||
|
||||
mock_replicate.exceptions.ReplicateError = Exception
|
||||
mock_replicate.run.side_effect = Exception("Request timeout")
|
||||
|
||||
client = ReplicateClient()
|
||||
|
||||
with pytest.raises(ReplicateTimeoutError):
|
||||
client.generate("Test")
|
||||
|
||||
@patch('app.ai.replicate_client.get_config')
|
||||
@patch('app.ai.replicate_client.replicate')
|
||||
def test_api_error(self, mock_replicate, mock_config):
|
||||
"""Test generic API error handling."""
|
||||
mock_config.return_value = MagicMock(
|
||||
replicate_api_token="token",
|
||||
REPLICATE_MODEL=None
|
||||
)
|
||||
|
||||
mock_replicate.exceptions.ReplicateError = Exception
|
||||
mock_replicate.run.side_effect = Exception("Invalid model")
|
||||
|
||||
client = ReplicateClient()
|
||||
|
||||
with pytest.raises(ReplicateAPIError):
|
||||
client.generate("Test")
|
||||
|
||||
|
||||
class TestValidation:
|
||||
"""Tests for API key validation."""
|
||||
|
||||
@patch('app.ai.replicate_client.get_config')
|
||||
@patch('app.ai.replicate_client.replicate')
|
||||
def test_validate_api_key_success(self, mock_replicate, mock_config):
|
||||
"""Test successful API key validation."""
|
||||
mock_config.return_value = MagicMock(
|
||||
replicate_api_token="valid_token",
|
||||
REPLICATE_MODEL=None
|
||||
)
|
||||
|
||||
mock_replicate.models.get.return_value = MagicMock()
|
||||
|
||||
client = ReplicateClient()
|
||||
assert client.validate_api_key() is True
|
||||
|
||||
@patch('app.ai.replicate_client.get_config')
|
||||
@patch('app.ai.replicate_client.replicate')
|
||||
def test_validate_api_key_failure(self, mock_replicate, mock_config):
|
||||
"""Test failed API key validation."""
|
||||
mock_config.return_value = MagicMock(
|
||||
replicate_api_token="invalid_token",
|
||||
REPLICATE_MODEL=None
|
||||
)
|
||||
|
||||
mock_replicate.models.get.side_effect = Exception("Invalid API token")
|
||||
|
||||
client = ReplicateClient()
|
||||
assert client.validate_api_key() is False
|
||||
|
||||
|
||||
class TestResponseDataclass:
|
||||
"""Tests for ReplicateResponse dataclass."""
|
||||
|
||||
def test_response_creation(self):
|
||||
"""Test creating ReplicateResponse."""
|
||||
response = ReplicateResponse(
|
||||
text="Hello world",
|
||||
tokens_used=50,
|
||||
model="meta/llama-3-8b",
|
||||
generation_time=1.5
|
||||
)
|
||||
|
||||
assert response.text == "Hello world"
|
||||
assert response.tokens_used == 50
|
||||
assert response.model == "meta/llama-3-8b"
|
||||
assert response.generation_time == 1.5
|
||||
|
||||
def test_response_immutability(self):
|
||||
"""Test that response fields are accessible."""
|
||||
response = ReplicateResponse(
|
||||
text="Test",
|
||||
tokens_used=10,
|
||||
model="test",
|
||||
generation_time=0.5
|
||||
)
|
||||
|
||||
# Dataclass should allow attribute access
|
||||
assert hasattr(response, 'text')
|
||||
assert hasattr(response, 'tokens_used')
|
||||
|
||||
|
||||
class TestModelType:
|
||||
"""Tests for ModelType enum and multi-model support."""
|
||||
|
||||
def test_model_type_values(self):
|
||||
"""Test ModelType enum has expected values."""
|
||||
assert ModelType.LLAMA_3_8B.value == "meta/meta-llama-3-8b-instruct"
|
||||
assert ModelType.CLAUDE_HAIKU.value == "anthropic/claude-3.5-haiku"
|
||||
assert ModelType.CLAUDE_SONNET.value == "anthropic/claude-3.5-sonnet"
|
||||
assert ModelType.CLAUDE_SONNET_4.value == "anthropic/claude-sonnet-4"
|
||||
|
||||
@patch('app.ai.replicate_client.get_config')
|
||||
def test_init_with_model_type_enum(self, mock_config):
|
||||
"""Test initialization with ModelType enum."""
|
||||
mock_config.return_value = MagicMock(
|
||||
replicate_api_token="token",
|
||||
REPLICATE_MODEL=None
|
||||
)
|
||||
|
||||
client = ReplicateClient(model=ModelType.CLAUDE_HAIKU)
|
||||
|
||||
assert client.model == "anthropic/claude-3.5-haiku"
|
||||
assert client.model_type == ModelType.CLAUDE_HAIKU
|
||||
|
||||
@patch('app.ai.replicate_client.get_config')
|
||||
def test_is_claude_model(self, mock_config):
|
||||
"""Test _is_claude_model helper."""
|
||||
mock_config.return_value = MagicMock(
|
||||
replicate_api_token="token",
|
||||
REPLICATE_MODEL=None
|
||||
)
|
||||
|
||||
# Llama model
|
||||
client = ReplicateClient(model=ModelType.LLAMA_3_8B)
|
||||
assert client._is_claude_model() is False
|
||||
|
||||
# Claude models
|
||||
client = ReplicateClient(model=ModelType.CLAUDE_HAIKU)
|
||||
assert client._is_claude_model() is True
|
||||
|
||||
client = ReplicateClient(model=ModelType.CLAUDE_SONNET)
|
||||
assert client._is_claude_model() is True
|
||||
|
||||
|
||||
class TestClaudeModels:
|
||||
"""Tests for Claude model generation via Replicate."""
|
||||
|
||||
@patch('app.ai.replicate_client.get_config')
|
||||
@patch('app.ai.replicate_client.replicate')
|
||||
def test_generate_with_claude_haiku(self, mock_replicate, mock_config):
|
||||
"""Test generation with Claude Haiku model."""
|
||||
mock_config.return_value = MagicMock(
|
||||
replicate_api_token="token",
|
||||
REPLICATE_MODEL=None
|
||||
)
|
||||
|
||||
mock_replicate.run.return_value = iter(["Claude ", "response"])
|
||||
|
||||
client = ReplicateClient(model=ModelType.CLAUDE_HAIKU)
|
||||
response = client.generate("Test prompt")
|
||||
|
||||
assert response.text == "Claude response"
|
||||
assert response.model == "anthropic/claude-3.5-haiku"
|
||||
|
||||
# Verify Claude-style params (not Llama formatted prompt)
|
||||
call_args = mock_replicate.run.call_args
|
||||
assert "prompt" in call_args[1]['input']
|
||||
# Claude params don't include Llama special tokens
|
||||
assert "<|begin_of_text|>" not in call_args[1]['input']['prompt']
|
||||
|
||||
@patch('app.ai.replicate_client.get_config')
|
||||
@patch('app.ai.replicate_client.replicate')
|
||||
def test_generate_with_claude_system_prompt(self, mock_replicate, mock_config):
|
||||
"""Test Claude generation includes system_prompt parameter."""
|
||||
mock_config.return_value = MagicMock(
|
||||
replicate_api_token="token",
|
||||
REPLICATE_MODEL=None
|
||||
)
|
||||
|
||||
mock_replicate.run.return_value = iter(["Response"])
|
||||
|
||||
client = ReplicateClient(model=ModelType.CLAUDE_SONNET)
|
||||
client.generate(
|
||||
prompt="User message",
|
||||
system_prompt="You are a DM"
|
||||
)
|
||||
|
||||
call_args = mock_replicate.run.call_args
|
||||
assert call_args[1]['input']['system_prompt'] == "You are a DM"
|
||||
assert call_args[1]['input']['prompt'] == "User message"
|
||||
|
||||
@patch('app.ai.replicate_client.get_config')
|
||||
@patch('app.ai.replicate_client.replicate')
|
||||
def test_model_specific_defaults(self, mock_replicate, mock_config):
|
||||
"""Test that model-specific defaults are applied."""
|
||||
mock_config.return_value = MagicMock(
|
||||
replicate_api_token="token",
|
||||
REPLICATE_MODEL=None
|
||||
)
|
||||
|
||||
mock_replicate.run.return_value = iter(["Response"])
|
||||
|
||||
# Claude Sonnet should use higher max_tokens by default
|
||||
client = ReplicateClient(model=ModelType.CLAUDE_SONNET)
|
||||
client.generate("Test")
|
||||
|
||||
call_args = mock_replicate.run.call_args
|
||||
# Sonnet default is 1024 tokens
|
||||
assert call_args[1]['input']['max_tokens'] == 1024
|
||||
assert call_args[1]['input']['temperature'] == 0.9
|
||||
|
||||
@patch('app.ai.replicate_client.get_config')
|
||||
@patch('app.ai.replicate_client.replicate')
|
||||
def test_model_override_in_generate(self, mock_replicate, mock_config):
|
||||
"""Test overriding model in generate() call."""
|
||||
mock_config.return_value = MagicMock(
|
||||
replicate_api_token="token",
|
||||
REPLICATE_MODEL=None
|
||||
)
|
||||
|
||||
mock_replicate.run.return_value = iter(["Response"])
|
||||
|
||||
# Init with Llama, but call with Claude
|
||||
client = ReplicateClient(model=ModelType.LLAMA_3_8B)
|
||||
response = client.generate("Test", model=ModelType.CLAUDE_HAIKU)
|
||||
|
||||
# Response should reflect the overridden model
|
||||
assert response.model == "anthropic/claude-3.5-haiku"
|
||||
|
||||
# Verify correct model was called
|
||||
call_args = mock_replicate.run.call_args
|
||||
assert call_args[0][0] == "anthropic/claude-3.5-haiku"
|
||||
390
api/tests/test_session_model.py
Normal file
390
api/tests/test_session_model.py
Normal file
@@ -0,0 +1,390 @@
|
||||
"""
|
||||
Tests for the GameSession model with solo play support.
|
||||
|
||||
Tests cover:
|
||||
- Solo session creation and serialization
|
||||
- Multiplayer session creation and serialization
|
||||
- ConversationEntry with timestamps
|
||||
- GameState with location_type
|
||||
- Session type detection
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from datetime import datetime
|
||||
|
||||
from app.models.session import (
|
||||
GameSession,
|
||||
GameState,
|
||||
ConversationEntry,
|
||||
SessionConfig,
|
||||
)
|
||||
from app.models.enums import (
|
||||
SessionStatus,
|
||||
SessionType,
|
||||
LocationType,
|
||||
)
|
||||
|
||||
|
||||
class TestGameState:
|
||||
"""Tests for GameState dataclass."""
|
||||
|
||||
def test_default_values(self):
|
||||
"""Test GameState has correct defaults."""
|
||||
state = GameState()
|
||||
assert state.current_location == "Crossroads Village"
|
||||
assert state.location_type == LocationType.TOWN
|
||||
assert state.discovered_locations == []
|
||||
assert state.active_quests == []
|
||||
assert state.world_events == []
|
||||
|
||||
def test_to_dict_serializes_location_type(self):
|
||||
"""Test location_type is serialized as string."""
|
||||
state = GameState(
|
||||
current_location="The Rusty Anchor",
|
||||
location_type=LocationType.TAVERN,
|
||||
)
|
||||
data = state.to_dict()
|
||||
assert data["current_location"] == "The Rusty Anchor"
|
||||
assert data["location_type"] == "tavern"
|
||||
|
||||
def test_from_dict_deserializes_location_type(self):
|
||||
"""Test location_type is deserialized from string."""
|
||||
data = {
|
||||
"current_location": "Dark Forest",
|
||||
"location_type": "wilderness",
|
||||
"discovered_locations": ["Town A"],
|
||||
"active_quests": ["quest_1"],
|
||||
"world_events": [],
|
||||
}
|
||||
state = GameState.from_dict(data)
|
||||
assert state.current_location == "Dark Forest"
|
||||
assert state.location_type == LocationType.WILDERNESS
|
||||
assert state.discovered_locations == ["Town A"]
|
||||
|
||||
def test_roundtrip_serialization(self):
|
||||
"""Test GameState serializes and deserializes correctly."""
|
||||
state = GameState(
|
||||
current_location="Ancient Library",
|
||||
location_type=LocationType.LIBRARY,
|
||||
discovered_locations=["Town", "Forest"],
|
||||
active_quests=["quest_1", "quest_2"],
|
||||
world_events=[{"type": "festival"}],
|
||||
)
|
||||
data = state.to_dict()
|
||||
restored = GameState.from_dict(data)
|
||||
|
||||
assert restored.current_location == state.current_location
|
||||
assert restored.location_type == state.location_type
|
||||
assert restored.discovered_locations == state.discovered_locations
|
||||
assert restored.active_quests == state.active_quests
|
||||
|
||||
|
||||
class TestConversationEntry:
|
||||
"""Tests for ConversationEntry dataclass."""
|
||||
|
||||
def test_auto_timestamp(self):
|
||||
"""Test timestamp is auto-generated."""
|
||||
entry = ConversationEntry(
|
||||
turn=1,
|
||||
character_id="char_123",
|
||||
character_name="Hero",
|
||||
action="I explore",
|
||||
dm_response="You find a chest",
|
||||
)
|
||||
assert entry.timestamp
|
||||
assert entry.timestamp.endswith("Z")
|
||||
|
||||
def test_provided_timestamp_preserved(self):
|
||||
"""Test provided timestamp is not overwritten."""
|
||||
ts = "2025-11-21T10:30:00Z"
|
||||
entry = ConversationEntry(
|
||||
turn=1,
|
||||
character_id="char_123",
|
||||
character_name="Hero",
|
||||
action="I explore",
|
||||
dm_response="You find a chest",
|
||||
timestamp=ts,
|
||||
)
|
||||
assert entry.timestamp == ts
|
||||
|
||||
def test_to_dict_with_quest_offered(self):
|
||||
"""Test serialization includes quest_offered when present."""
|
||||
entry = ConversationEntry(
|
||||
turn=5,
|
||||
character_id="char_123",
|
||||
character_name="Hero",
|
||||
action="Talk to elder",
|
||||
dm_response="The elder offers you a quest",
|
||||
quest_offered={
|
||||
"quest_id": "quest_goblin_cave",
|
||||
"quest_name": "Clear the Goblin Cave",
|
||||
},
|
||||
)
|
||||
data = entry.to_dict()
|
||||
assert "quest_offered" in data
|
||||
assert data["quest_offered"]["quest_id"] == "quest_goblin_cave"
|
||||
|
||||
def test_to_dict_without_quest_offered(self):
|
||||
"""Test serialization omits quest_offered when None."""
|
||||
entry = ConversationEntry(
|
||||
turn=1,
|
||||
character_id="char_123",
|
||||
character_name="Hero",
|
||||
action="I explore",
|
||||
dm_response="You find nothing",
|
||||
)
|
||||
data = entry.to_dict()
|
||||
assert "quest_offered" not in data
|
||||
|
||||
def test_from_dict_roundtrip(self):
|
||||
"""Test ConversationEntry roundtrip serialization."""
|
||||
entry = ConversationEntry(
|
||||
turn=3,
|
||||
character_id="char_456",
|
||||
character_name="Wizard",
|
||||
action="Cast fireball",
|
||||
dm_response="The spell illuminates the cave",
|
||||
combat_log=[{"action": "attack", "damage": 15}],
|
||||
quest_offered={"quest_id": "q1"},
|
||||
)
|
||||
data = entry.to_dict()
|
||||
restored = ConversationEntry.from_dict(data)
|
||||
|
||||
assert restored.turn == entry.turn
|
||||
assert restored.character_id == entry.character_id
|
||||
assert restored.action == entry.action
|
||||
assert restored.dm_response == entry.dm_response
|
||||
assert restored.timestamp == entry.timestamp
|
||||
assert restored.combat_log == entry.combat_log
|
||||
assert restored.quest_offered == entry.quest_offered
|
||||
|
||||
|
||||
class TestGameSessionSolo:
|
||||
"""Tests for solo GameSession functionality."""
|
||||
|
||||
def test_create_solo_session(self):
|
||||
"""Test creating a solo session."""
|
||||
session = GameSession(
|
||||
session_id="sess_123",
|
||||
session_type=SessionType.SOLO,
|
||||
solo_character_id="char_456",
|
||||
user_id="user_789",
|
||||
)
|
||||
assert session.session_type == SessionType.SOLO
|
||||
assert session.solo_character_id == "char_456"
|
||||
assert session.user_id == "user_789"
|
||||
assert session.is_solo() is True
|
||||
|
||||
def test_is_solo_method(self):
|
||||
"""Test is_solo returns correct values."""
|
||||
solo = GameSession(
|
||||
session_id="s1",
|
||||
session_type=SessionType.SOLO,
|
||||
solo_character_id="c1",
|
||||
)
|
||||
multi = GameSession(
|
||||
session_id="s2",
|
||||
session_type=SessionType.MULTIPLAYER,
|
||||
party_member_ids=["c1", "c2"],
|
||||
)
|
||||
assert solo.is_solo() is True
|
||||
assert multi.is_solo() is False
|
||||
|
||||
def test_get_character_id_solo(self):
|
||||
"""Test get_character_id returns solo_character_id for solo sessions."""
|
||||
session = GameSession(
|
||||
session_id="sess_123",
|
||||
session_type=SessionType.SOLO,
|
||||
solo_character_id="char_456",
|
||||
)
|
||||
assert session.get_character_id() == "char_456"
|
||||
|
||||
def test_get_character_id_multiplayer(self):
|
||||
"""Test get_character_id returns current turn character for multiplayer."""
|
||||
session = GameSession(
|
||||
session_id="sess_123",
|
||||
session_type=SessionType.MULTIPLAYER,
|
||||
party_member_ids=["c1", "c2", "c3"],
|
||||
turn_order=["c2", "c1", "c3"],
|
||||
current_turn=1,
|
||||
)
|
||||
assert session.get_character_id() == "c1"
|
||||
|
||||
def test_to_dict_includes_new_fields(self):
|
||||
"""Test to_dict includes session_type, solo_character_id, user_id."""
|
||||
session = GameSession(
|
||||
session_id="sess_123",
|
||||
session_type=SessionType.SOLO,
|
||||
solo_character_id="char_456",
|
||||
user_id="user_789",
|
||||
)
|
||||
data = session.to_dict()
|
||||
|
||||
assert data["session_id"] == "sess_123"
|
||||
assert data["session_type"] == "solo"
|
||||
assert data["solo_character_id"] == "char_456"
|
||||
assert data["user_id"] == "user_789"
|
||||
|
||||
def test_from_dict_solo_session(self):
|
||||
"""Test from_dict correctly deserializes solo session."""
|
||||
data = {
|
||||
"session_id": "sess_123",
|
||||
"session_type": "solo",
|
||||
"solo_character_id": "char_456",
|
||||
"user_id": "user_789",
|
||||
"party_member_ids": [],
|
||||
"turn_number": 5,
|
||||
"game_state": {
|
||||
"current_location": "Town",
|
||||
"location_type": "town",
|
||||
},
|
||||
}
|
||||
session = GameSession.from_dict(data)
|
||||
|
||||
assert session.session_id == "sess_123"
|
||||
assert session.session_type == SessionType.SOLO
|
||||
assert session.solo_character_id == "char_456"
|
||||
assert session.user_id == "user_789"
|
||||
assert session.turn_number == 5
|
||||
assert session.game_state.location_type == LocationType.TOWN
|
||||
|
||||
def test_roundtrip_serialization(self):
|
||||
"""Test complete roundtrip of solo session."""
|
||||
session = GameSession(
|
||||
session_id="sess_test",
|
||||
session_type=SessionType.SOLO,
|
||||
solo_character_id="char_hero",
|
||||
user_id="user_player",
|
||||
turn_number=10,
|
||||
game_state=GameState(
|
||||
current_location="Dark Dungeon",
|
||||
location_type=LocationType.DUNGEON,
|
||||
active_quests=["quest_1"],
|
||||
),
|
||||
conversation_history=[
|
||||
ConversationEntry(
|
||||
turn=1,
|
||||
character_id="char_hero",
|
||||
character_name="Hero",
|
||||
action="Enter dungeon",
|
||||
dm_response="The darkness swallows you...",
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
data = session.to_dict()
|
||||
restored = GameSession.from_dict(data)
|
||||
|
||||
assert restored.session_id == session.session_id
|
||||
assert restored.session_type == session.session_type
|
||||
assert restored.solo_character_id == session.solo_character_id
|
||||
assert restored.user_id == session.user_id
|
||||
assert restored.turn_number == session.turn_number
|
||||
assert restored.game_state.current_location == session.game_state.current_location
|
||||
assert restored.game_state.location_type == session.game_state.location_type
|
||||
assert len(restored.conversation_history) == 1
|
||||
assert restored.conversation_history[0].action == "Enter dungeon"
|
||||
|
||||
def test_repr_solo(self):
|
||||
"""Test __repr__ for solo session."""
|
||||
session = GameSession(
|
||||
session_id="sess_123",
|
||||
session_type=SessionType.SOLO,
|
||||
solo_character_id="char_456",
|
||||
turn_number=5,
|
||||
)
|
||||
repr_str = repr(session)
|
||||
assert "type=solo" in repr_str
|
||||
assert "char=char_456" in repr_str
|
||||
assert "turn=5" in repr_str
|
||||
|
||||
def test_repr_multiplayer(self):
|
||||
"""Test __repr__ for multiplayer session."""
|
||||
session = GameSession(
|
||||
session_id="sess_123",
|
||||
session_type=SessionType.MULTIPLAYER,
|
||||
party_member_ids=["c1", "c2", "c3"],
|
||||
turn_number=10,
|
||||
)
|
||||
repr_str = repr(session)
|
||||
assert "type=multiplayer" in repr_str
|
||||
assert "party=3" in repr_str
|
||||
assert "turn=10" in repr_str
|
||||
|
||||
|
||||
class TestGameSessionBackwardsCompatibility:
|
||||
"""Tests for backwards compatibility with existing sessions."""
|
||||
|
||||
def test_default_session_type_is_solo(self):
|
||||
"""Test new sessions default to solo type."""
|
||||
session = GameSession(session_id="test")
|
||||
assert session.session_type == SessionType.SOLO
|
||||
|
||||
def test_from_dict_without_session_type(self):
|
||||
"""Test from_dict handles missing session_type (defaults to solo)."""
|
||||
data = {
|
||||
"session_id": "old_session",
|
||||
"party_member_ids": ["c1"],
|
||||
}
|
||||
session = GameSession.from_dict(data)
|
||||
assert session.session_type == SessionType.SOLO
|
||||
|
||||
def test_from_dict_without_location_type(self):
|
||||
"""Test from_dict handles missing location_type in game_state."""
|
||||
data = {
|
||||
"session_id": "old_session",
|
||||
"game_state": {
|
||||
"current_location": "Old Town",
|
||||
},
|
||||
}
|
||||
session = GameSession.from_dict(data)
|
||||
assert session.game_state.location_type == LocationType.TOWN
|
||||
|
||||
def test_existing_methods_still_work(self):
|
||||
"""Test existing session methods work with new fields."""
|
||||
session = GameSession(
|
||||
session_id="test",
|
||||
session_type=SessionType.SOLO,
|
||||
solo_character_id="char_1",
|
||||
)
|
||||
|
||||
# Test existing methods
|
||||
assert session.is_in_combat() is False
|
||||
session.update_activity()
|
||||
assert session.last_activity
|
||||
|
||||
# Add conversation entry
|
||||
entry = ConversationEntry(
|
||||
turn=1,
|
||||
character_id="char_1",
|
||||
character_name="Hero",
|
||||
action="test",
|
||||
dm_response="response",
|
||||
)
|
||||
session.add_conversation_entry(entry)
|
||||
assert len(session.conversation_history) == 1
|
||||
|
||||
|
||||
class TestLocationTypeEnum:
|
||||
"""Tests for LocationType enum."""
|
||||
|
||||
def test_all_location_types_defined(self):
|
||||
"""Test all expected location types exist."""
|
||||
expected = ["town", "tavern", "wilderness", "dungeon", "ruins", "library", "safe_area"]
|
||||
actual = [lt.value for lt in LocationType]
|
||||
assert sorted(actual) == sorted(expected)
|
||||
|
||||
def test_location_type_from_string(self):
|
||||
"""Test LocationType can be created from string."""
|
||||
assert LocationType("town") == LocationType.TOWN
|
||||
assert LocationType("wilderness") == LocationType.WILDERNESS
|
||||
assert LocationType("dungeon") == LocationType.DUNGEON
|
||||
|
||||
|
||||
class TestSessionTypeEnum:
|
||||
"""Tests for SessionType enum."""
|
||||
|
||||
def test_session_types_defined(self):
|
||||
"""Test session types are defined correctly."""
|
||||
assert SessionType.SOLO.value == "solo"
|
||||
assert SessionType.MULTIPLAYER.value == "multiplayer"
|
||||
566
api/tests/test_session_service.py
Normal file
566
api/tests/test_session_service.py
Normal file
@@ -0,0 +1,566 @@
|
||||
"""
|
||||
Tests for the SessionService.
|
||||
|
||||
Tests cover:
|
||||
- Solo session creation
|
||||
- Session retrieval and listing
|
||||
- Conversation history management
|
||||
- Game state tracking (location, quests, events)
|
||||
- Session validation and limits
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock, patch, MagicMock
|
||||
from datetime import datetime
|
||||
|
||||
from app.services.session_service import (
|
||||
SessionService,
|
||||
SessionNotFound,
|
||||
SessionLimitExceeded,
|
||||
SessionValidationError,
|
||||
MAX_ACTIVE_SESSIONS,
|
||||
)
|
||||
from app.models.session import GameSession, GameState, ConversationEntry
|
||||
from app.models.enums import SessionStatus, SessionType, LocationType
|
||||
from app.models.character import Character
|
||||
from app.models.skills import PlayerClass
|
||||
from app.models.origins import Origin
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_db():
|
||||
"""Create mock database service."""
|
||||
with patch('app.services.session_service.get_database_service') as mock:
|
||||
db = Mock()
|
||||
mock.return_value = db
|
||||
yield db
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_appwrite():
|
||||
"""Create mock Appwrite service."""
|
||||
with patch('app.services.session_service.AppwriteService') as mock:
|
||||
service = Mock()
|
||||
mock.return_value = service
|
||||
yield service
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_character_service():
|
||||
"""Create mock character service."""
|
||||
with patch('app.services.session_service.get_character_service') as mock:
|
||||
service = Mock()
|
||||
mock.return_value = service
|
||||
yield service
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_character():
|
||||
"""Create a sample character for testing."""
|
||||
return Character(
|
||||
character_id="char_123",
|
||||
user_id="user_456",
|
||||
name="Test Hero",
|
||||
player_class=Mock(spec=PlayerClass),
|
||||
origin=Mock(spec=Origin),
|
||||
level=5,
|
||||
experience=1000,
|
||||
base_stats={"strength": 10},
|
||||
unlocked_skills=[],
|
||||
inventory=[],
|
||||
equipped={},
|
||||
gold=100,
|
||||
active_quests=[],
|
||||
discovered_locations=[],
|
||||
current_location="Town"
|
||||
)
|
||||
|
||||
|
||||
class TestSessionServiceCreation:
|
||||
"""Tests for session creation."""
|
||||
|
||||
def test_create_solo_session_success(self, mock_db, mock_appwrite, mock_character_service, sample_character):
|
||||
"""Test successful solo session creation."""
|
||||
mock_character_service.get_character.return_value = sample_character
|
||||
mock_db.count_documents.return_value = 0
|
||||
mock_db.create_document.return_value = None
|
||||
|
||||
service = SessionService()
|
||||
session = service.create_solo_session(
|
||||
user_id="user_456",
|
||||
character_id="char_123"
|
||||
)
|
||||
|
||||
assert session.session_type == SessionType.SOLO
|
||||
assert session.solo_character_id == "char_123"
|
||||
assert session.user_id == "user_456"
|
||||
assert session.turn_number == 0
|
||||
assert session.status == SessionStatus.ACTIVE
|
||||
assert session.game_state.current_location == "Crossroads Village"
|
||||
assert session.game_state.location_type == LocationType.TOWN
|
||||
|
||||
mock_db.create_document.assert_called_once()
|
||||
|
||||
def test_create_solo_session_character_not_found(self, mock_db, mock_appwrite, mock_character_service):
|
||||
"""Test session creation fails when character not found."""
|
||||
from app.services.character_service import CharacterNotFound
|
||||
mock_character_service.get_character.side_effect = CharacterNotFound("Not found")
|
||||
|
||||
service = SessionService()
|
||||
with pytest.raises(CharacterNotFound):
|
||||
service.create_solo_session(
|
||||
user_id="user_456",
|
||||
character_id="char_invalid"
|
||||
)
|
||||
|
||||
def test_create_solo_session_limit_exceeded(self, mock_db, mock_appwrite, mock_character_service, sample_character):
|
||||
"""Test session creation fails when limit exceeded."""
|
||||
mock_character_service.get_character.return_value = sample_character
|
||||
mock_db.count_documents.return_value = MAX_ACTIVE_SESSIONS
|
||||
|
||||
service = SessionService()
|
||||
with pytest.raises(SessionLimitExceeded):
|
||||
service.create_solo_session(
|
||||
user_id="user_456",
|
||||
character_id="char_123"
|
||||
)
|
||||
|
||||
def test_create_solo_session_custom_location(self, mock_db, mock_appwrite, mock_character_service, sample_character):
|
||||
"""Test session creation with custom starting location."""
|
||||
mock_character_service.get_character.return_value = sample_character
|
||||
mock_db.count_documents.return_value = 0
|
||||
|
||||
service = SessionService()
|
||||
session = service.create_solo_session(
|
||||
user_id="user_456",
|
||||
character_id="char_123",
|
||||
starting_location="Dark Forest",
|
||||
starting_location_type=LocationType.WILDERNESS
|
||||
)
|
||||
|
||||
assert session.game_state.current_location == "Dark Forest"
|
||||
assert session.game_state.location_type == LocationType.WILDERNESS
|
||||
assert "Dark Forest" in session.game_state.discovered_locations
|
||||
|
||||
|
||||
class TestSessionServiceRetrieval:
|
||||
"""Tests for session retrieval."""
|
||||
|
||||
def test_get_session_success(self, mock_db, mock_appwrite, mock_character_service):
|
||||
"""Test successful session retrieval."""
|
||||
session_data = {
|
||||
"session_id": "sess_123",
|
||||
"session_type": "solo",
|
||||
"solo_character_id": "char_456",
|
||||
"user_id": "user_789",
|
||||
"party_member_ids": [],
|
||||
"turn_number": 5,
|
||||
"status": "active",
|
||||
"game_state": {
|
||||
"current_location": "Town",
|
||||
"location_type": "town"
|
||||
},
|
||||
"conversation_history": [],
|
||||
"config": {},
|
||||
"turn_order": [],
|
||||
"current_turn": 0
|
||||
}
|
||||
|
||||
mock_document = Mock()
|
||||
mock_document.data = {
|
||||
'userId': 'user_789',
|
||||
'sessionData': __import__('json').dumps(session_data)
|
||||
}
|
||||
mock_db.get_document.return_value = mock_document
|
||||
|
||||
service = SessionService()
|
||||
session = service.get_session("sess_123", "user_789")
|
||||
|
||||
assert session.session_id == "sess_123"
|
||||
assert session.user_id == "user_789"
|
||||
assert session.turn_number == 5
|
||||
|
||||
def test_get_session_not_found(self, mock_db, mock_appwrite, mock_character_service):
|
||||
"""Test session retrieval when not found."""
|
||||
mock_db.get_document.return_value = None
|
||||
|
||||
service = SessionService()
|
||||
with pytest.raises(SessionNotFound):
|
||||
service.get_session("sess_invalid")
|
||||
|
||||
def test_get_session_wrong_user(self, mock_db, mock_appwrite, mock_character_service):
|
||||
"""Test session retrieval with wrong user ID."""
|
||||
mock_document = Mock()
|
||||
mock_document.data = {
|
||||
'userId': 'user_other',
|
||||
'sessionData': '{}'
|
||||
}
|
||||
mock_db.get_document.return_value = mock_document
|
||||
|
||||
service = SessionService()
|
||||
with pytest.raises(SessionNotFound):
|
||||
service.get_session("sess_123", "user_wrong")
|
||||
|
||||
def test_get_user_sessions(self, mock_db, mock_appwrite, mock_character_service):
|
||||
"""Test getting all sessions for a user."""
|
||||
session_data = {
|
||||
"session_id": "sess_1",
|
||||
"session_type": "solo",
|
||||
"user_id": "user_123",
|
||||
"status": "active",
|
||||
"turn_number": 0,
|
||||
"game_state": {},
|
||||
"conversation_history": [],
|
||||
"config": {},
|
||||
"turn_order": [],
|
||||
"current_turn": 0,
|
||||
"party_member_ids": []
|
||||
}
|
||||
|
||||
mock_doc = Mock()
|
||||
mock_doc.data = {'sessionData': __import__('json').dumps(session_data)}
|
||||
mock_doc.id = "sess_1"
|
||||
mock_db.list_rows.return_value = [mock_doc]
|
||||
|
||||
service = SessionService()
|
||||
sessions = service.get_user_sessions("user_123")
|
||||
|
||||
assert len(sessions) == 1
|
||||
assert sessions[0].session_id == "sess_1"
|
||||
|
||||
|
||||
class TestConversationHistory:
|
||||
"""Tests for conversation history management."""
|
||||
|
||||
def test_add_conversation_entry(self, mock_db, mock_appwrite, mock_character_service):
|
||||
"""Test adding a conversation entry."""
|
||||
session_data = {
|
||||
"session_id": "sess_123",
|
||||
"session_type": "solo",
|
||||
"user_id": "user_456",
|
||||
"turn_number": 0,
|
||||
"status": "active",
|
||||
"game_state": {"current_location": "Town", "location_type": "town"},
|
||||
"conversation_history": [],
|
||||
"config": {},
|
||||
"turn_order": [],
|
||||
"current_turn": 0,
|
||||
"party_member_ids": []
|
||||
}
|
||||
|
||||
mock_document = Mock()
|
||||
mock_document.data = {
|
||||
'userId': 'user_456',
|
||||
'sessionData': __import__('json').dumps(session_data)
|
||||
}
|
||||
mock_db.get_document.return_value = mock_document
|
||||
|
||||
service = SessionService()
|
||||
updated = service.add_conversation_entry(
|
||||
session_id="sess_123",
|
||||
character_id="char_789",
|
||||
character_name="Hero",
|
||||
action="I explore the area",
|
||||
dm_response="You find a hidden path..."
|
||||
)
|
||||
|
||||
assert updated.turn_number == 1
|
||||
assert len(updated.conversation_history) == 1
|
||||
assert updated.conversation_history[0].action == "I explore the area"
|
||||
assert updated.conversation_history[0].dm_response == "You find a hidden path..."
|
||||
|
||||
mock_db.update_document.assert_called_once()
|
||||
|
||||
def test_add_conversation_entry_with_quest(self, mock_db, mock_appwrite, mock_character_service):
|
||||
"""Test adding conversation entry with quest offering."""
|
||||
session_data = {
|
||||
"session_id": "sess_123",
|
||||
"session_type": "solo",
|
||||
"user_id": "user_456",
|
||||
"turn_number": 5,
|
||||
"status": "active",
|
||||
"game_state": {},
|
||||
"conversation_history": [],
|
||||
"config": {},
|
||||
"turn_order": [],
|
||||
"current_turn": 0,
|
||||
"party_member_ids": []
|
||||
}
|
||||
|
||||
mock_document = Mock()
|
||||
mock_document.data = {
|
||||
'userId': 'user_456',
|
||||
'sessionData': __import__('json').dumps(session_data)
|
||||
}
|
||||
mock_db.get_document.return_value = mock_document
|
||||
|
||||
service = SessionService()
|
||||
updated = service.add_conversation_entry(
|
||||
session_id="sess_123",
|
||||
character_id="char_789",
|
||||
character_name="Hero",
|
||||
action="Talk to elder",
|
||||
dm_response="The elder offers you a quest...",
|
||||
quest_offered={"quest_id": "quest_goblin", "name": "Clear Goblin Cave"}
|
||||
)
|
||||
|
||||
assert updated.conversation_history[0].quest_offered is not None
|
||||
assert updated.conversation_history[0].quest_offered["quest_id"] == "quest_goblin"
|
||||
|
||||
def test_get_recent_history(self, mock_db, mock_appwrite, mock_character_service):
|
||||
"""Test getting recent conversation history."""
|
||||
# Create session with 5 conversation entries
|
||||
entries = []
|
||||
for i in range(5):
|
||||
entries.append({
|
||||
"turn": i + 1,
|
||||
"character_id": "char_123",
|
||||
"character_name": "Hero",
|
||||
"action": f"Action {i+1}",
|
||||
"dm_response": f"Response {i+1}",
|
||||
"timestamp": "2025-11-21T10:00:00Z",
|
||||
"combat_log": []
|
||||
})
|
||||
|
||||
session_data = {
|
||||
"session_id": "sess_123",
|
||||
"session_type": "solo",
|
||||
"user_id": "user_456",
|
||||
"turn_number": 5,
|
||||
"status": "active",
|
||||
"game_state": {},
|
||||
"conversation_history": entries,
|
||||
"config": {},
|
||||
"turn_order": [],
|
||||
"current_turn": 0,
|
||||
"party_member_ids": []
|
||||
}
|
||||
|
||||
mock_document = Mock()
|
||||
mock_document.data = {
|
||||
'userId': 'user_456',
|
||||
'sessionData': __import__('json').dumps(session_data)
|
||||
}
|
||||
mock_db.get_document.return_value = mock_document
|
||||
|
||||
service = SessionService()
|
||||
recent = service.get_recent_history("sess_123", num_turns=3)
|
||||
|
||||
assert len(recent) == 3
|
||||
assert recent[0].turn == 3 # Last 3 entries
|
||||
assert recent[2].turn == 5
|
||||
|
||||
|
||||
class TestGameStateTracking:
|
||||
"""Tests for game state tracking methods."""
|
||||
|
||||
def test_update_location(self, mock_db, mock_appwrite, mock_character_service):
|
||||
"""Test updating session location."""
|
||||
session_data = {
|
||||
"session_id": "sess_123",
|
||||
"session_type": "solo",
|
||||
"user_id": "user_456",
|
||||
"turn_number": 0,
|
||||
"status": "active",
|
||||
"game_state": {
|
||||
"current_location": "Town",
|
||||
"location_type": "town",
|
||||
"discovered_locations": ["Town"]
|
||||
},
|
||||
"conversation_history": [],
|
||||
"config": {},
|
||||
"turn_order": [],
|
||||
"current_turn": 0,
|
||||
"party_member_ids": []
|
||||
}
|
||||
|
||||
mock_document = Mock()
|
||||
mock_document.data = {
|
||||
'userId': 'user_456',
|
||||
'sessionData': __import__('json').dumps(session_data)
|
||||
}
|
||||
mock_db.get_document.return_value = mock_document
|
||||
|
||||
service = SessionService()
|
||||
updated = service.update_location(
|
||||
session_id="sess_123",
|
||||
new_location="Dark Forest",
|
||||
location_type=LocationType.WILDERNESS
|
||||
)
|
||||
|
||||
assert updated.game_state.current_location == "Dark Forest"
|
||||
assert updated.game_state.location_type == LocationType.WILDERNESS
|
||||
assert "Dark Forest" in updated.game_state.discovered_locations
|
||||
|
||||
def test_add_active_quest(self, mock_db, mock_appwrite, mock_character_service):
|
||||
"""Test adding an active quest."""
|
||||
session_data = {
|
||||
"session_id": "sess_123",
|
||||
"session_type": "solo",
|
||||
"user_id": "user_456",
|
||||
"turn_number": 0,
|
||||
"status": "active",
|
||||
"game_state": {
|
||||
"active_quests": [],
|
||||
"current_location": "Town",
|
||||
"location_type": "town"
|
||||
},
|
||||
"conversation_history": [],
|
||||
"config": {},
|
||||
"turn_order": [],
|
||||
"current_turn": 0,
|
||||
"party_member_ids": []
|
||||
}
|
||||
|
||||
mock_document = Mock()
|
||||
mock_document.data = {
|
||||
'userId': 'user_456',
|
||||
'sessionData': __import__('json').dumps(session_data)
|
||||
}
|
||||
mock_db.get_document.return_value = mock_document
|
||||
|
||||
service = SessionService()
|
||||
updated = service.add_active_quest("sess_123", "quest_goblin")
|
||||
|
||||
assert "quest_goblin" in updated.game_state.active_quests
|
||||
|
||||
def test_add_active_quest_limit(self, mock_db, mock_appwrite, mock_character_service):
|
||||
"""Test adding quest fails when max reached."""
|
||||
session_data = {
|
||||
"session_id": "sess_123",
|
||||
"session_type": "solo",
|
||||
"user_id": "user_456",
|
||||
"turn_number": 0,
|
||||
"status": "active",
|
||||
"game_state": {
|
||||
"active_quests": ["quest_1", "quest_2"], # Already at max
|
||||
"current_location": "Town",
|
||||
"location_type": "town"
|
||||
},
|
||||
"conversation_history": [],
|
||||
"config": {},
|
||||
"turn_order": [],
|
||||
"current_turn": 0,
|
||||
"party_member_ids": []
|
||||
}
|
||||
|
||||
mock_document = Mock()
|
||||
mock_document.data = {
|
||||
'userId': 'user_456',
|
||||
'sessionData': __import__('json').dumps(session_data)
|
||||
}
|
||||
mock_db.get_document.return_value = mock_document
|
||||
|
||||
service = SessionService()
|
||||
with pytest.raises(SessionValidationError):
|
||||
service.add_active_quest("sess_123", "quest_3")
|
||||
|
||||
def test_remove_active_quest(self, mock_db, mock_appwrite, mock_character_service):
|
||||
"""Test removing an active quest."""
|
||||
session_data = {
|
||||
"session_id": "sess_123",
|
||||
"session_type": "solo",
|
||||
"user_id": "user_456",
|
||||
"turn_number": 0,
|
||||
"status": "active",
|
||||
"game_state": {
|
||||
"active_quests": ["quest_1", "quest_2"],
|
||||
"current_location": "Town",
|
||||
"location_type": "town"
|
||||
},
|
||||
"conversation_history": [],
|
||||
"config": {},
|
||||
"turn_order": [],
|
||||
"current_turn": 0,
|
||||
"party_member_ids": []
|
||||
}
|
||||
|
||||
mock_document = Mock()
|
||||
mock_document.data = {
|
||||
'userId': 'user_456',
|
||||
'sessionData': __import__('json').dumps(session_data)
|
||||
}
|
||||
mock_db.get_document.return_value = mock_document
|
||||
|
||||
service = SessionService()
|
||||
updated = service.remove_active_quest("sess_123", "quest_1")
|
||||
|
||||
assert "quest_1" not in updated.game_state.active_quests
|
||||
assert "quest_2" in updated.game_state.active_quests
|
||||
|
||||
def test_add_world_event(self, mock_db, mock_appwrite, mock_character_service):
|
||||
"""Test adding a world event."""
|
||||
session_data = {
|
||||
"session_id": "sess_123",
|
||||
"session_type": "solo",
|
||||
"user_id": "user_456",
|
||||
"turn_number": 0,
|
||||
"status": "active",
|
||||
"game_state": {
|
||||
"world_events": [],
|
||||
"current_location": "Town",
|
||||
"location_type": "town"
|
||||
},
|
||||
"conversation_history": [],
|
||||
"config": {},
|
||||
"turn_order": [],
|
||||
"current_turn": 0,
|
||||
"party_member_ids": []
|
||||
}
|
||||
|
||||
mock_document = Mock()
|
||||
mock_document.data = {
|
||||
'userId': 'user_456',
|
||||
'sessionData': __import__('json').dumps(session_data)
|
||||
}
|
||||
mock_db.get_document.return_value = mock_document
|
||||
|
||||
service = SessionService()
|
||||
updated = service.add_world_event("sess_123", {"type": "festival", "description": "A festival begins"})
|
||||
|
||||
assert len(updated.game_state.world_events) == 1
|
||||
assert updated.game_state.world_events[0]["type"] == "festival"
|
||||
assert "timestamp" in updated.game_state.world_events[0]
|
||||
|
||||
|
||||
class TestSessionLifecycle:
|
||||
"""Tests for session lifecycle management."""
|
||||
|
||||
def test_end_session(self, mock_db, mock_appwrite, mock_character_service):
|
||||
"""Test ending a session."""
|
||||
session_data = {
|
||||
"session_id": "sess_123",
|
||||
"session_type": "solo",
|
||||
"user_id": "user_456",
|
||||
"turn_number": 10,
|
||||
"status": "active",
|
||||
"game_state": {},
|
||||
"conversation_history": [],
|
||||
"config": {},
|
||||
"turn_order": [],
|
||||
"current_turn": 0,
|
||||
"party_member_ids": []
|
||||
}
|
||||
|
||||
mock_document = Mock()
|
||||
mock_document.data = {
|
||||
'userId': 'user_456',
|
||||
'sessionData': __import__('json').dumps(session_data)
|
||||
}
|
||||
mock_db.get_document.return_value = mock_document
|
||||
|
||||
service = SessionService()
|
||||
updated = service.end_session("sess_123", "user_456")
|
||||
|
||||
assert updated.status == SessionStatus.COMPLETED
|
||||
mock_db.update_document.assert_called_once()
|
||||
|
||||
def test_count_user_sessions(self, mock_db, mock_appwrite, mock_character_service):
|
||||
"""Test counting user sessions."""
|
||||
mock_db.count_documents.return_value = 3
|
||||
|
||||
service = SessionService()
|
||||
count = service.count_user_sessions("user_123", active_only=True)
|
||||
|
||||
assert count == 3
|
||||
mock_db.count_documents.assert_called_once()
|
||||
198
api/tests/test_stats.py
Normal file
198
api/tests/test_stats.py
Normal file
@@ -0,0 +1,198 @@
|
||||
"""
|
||||
Unit tests for Stats dataclass.
|
||||
|
||||
Tests computed properties, serialization, and basic operations.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from app.models.stats import Stats
|
||||
|
||||
|
||||
def test_stats_default_values():
|
||||
"""Test that Stats initializes with default values."""
|
||||
stats = Stats()
|
||||
|
||||
assert stats.strength == 10
|
||||
assert stats.dexterity == 10
|
||||
assert stats.constitution == 10
|
||||
assert stats.intelligence == 10
|
||||
assert stats.wisdom == 10
|
||||
assert stats.charisma == 10
|
||||
|
||||
|
||||
def test_stats_custom_values():
|
||||
"""Test creating Stats with custom values."""
|
||||
stats = Stats(
|
||||
strength=15,
|
||||
dexterity=12,
|
||||
constitution=14,
|
||||
intelligence=8,
|
||||
wisdom=10,
|
||||
charisma=11,
|
||||
)
|
||||
|
||||
assert stats.strength == 15
|
||||
assert stats.dexterity == 12
|
||||
assert stats.constitution == 14
|
||||
assert stats.intelligence == 8
|
||||
assert stats.wisdom == 10
|
||||
assert stats.charisma == 11
|
||||
|
||||
|
||||
def test_hit_points_calculation():
|
||||
"""Test HP calculation: 10 + (constitution × 2)."""
|
||||
stats = Stats(constitution=10)
|
||||
assert stats.hit_points == 30 # 10 + (10 × 2)
|
||||
|
||||
stats = Stats(constitution=15)
|
||||
assert stats.hit_points == 40 # 10 + (15 × 2)
|
||||
|
||||
stats = Stats(constitution=20)
|
||||
assert stats.hit_points == 50 # 10 + (20 × 2)
|
||||
|
||||
|
||||
def test_mana_points_calculation():
|
||||
"""Test MP calculation: 10 + (intelligence × 2)."""
|
||||
stats = Stats(intelligence=10)
|
||||
assert stats.mana_points == 30 # 10 + (10 × 2)
|
||||
|
||||
stats = Stats(intelligence=15)
|
||||
assert stats.mana_points == 40 # 10 + (15 × 2)
|
||||
|
||||
stats = Stats(intelligence=8)
|
||||
assert stats.mana_points == 26 # 10 + (8 × 2)
|
||||
|
||||
|
||||
def test_defense_calculation():
|
||||
"""Test defense calculation: constitution // 2."""
|
||||
stats = Stats(constitution=10)
|
||||
assert stats.defense == 5 # 10 // 2
|
||||
|
||||
stats = Stats(constitution=15)
|
||||
assert stats.defense == 7 # 15 // 2
|
||||
|
||||
stats = Stats(constitution=21)
|
||||
assert stats.defense == 10 # 21 // 2
|
||||
|
||||
|
||||
def test_resistance_calculation():
|
||||
"""Test resistance calculation: wisdom // 2."""
|
||||
stats = Stats(wisdom=10)
|
||||
assert stats.resistance == 5 # 10 // 2
|
||||
|
||||
stats = Stats(wisdom=14)
|
||||
assert stats.resistance == 7 # 14 // 2
|
||||
|
||||
stats = Stats(wisdom=9)
|
||||
assert stats.resistance == 4 # 9 // 2
|
||||
|
||||
|
||||
def test_stats_serialization():
|
||||
"""Test to_dict() serialization."""
|
||||
stats = Stats(
|
||||
strength=15,
|
||||
dexterity=12,
|
||||
constitution=14,
|
||||
intelligence=10,
|
||||
wisdom=11,
|
||||
charisma=8,
|
||||
)
|
||||
|
||||
data = stats.to_dict()
|
||||
|
||||
assert data["strength"] == 15
|
||||
assert data["dexterity"] == 12
|
||||
assert data["constitution"] == 14
|
||||
assert data["intelligence"] == 10
|
||||
assert data["wisdom"] == 11
|
||||
assert data["charisma"] == 8
|
||||
|
||||
|
||||
def test_stats_deserialization():
|
||||
"""Test from_dict() deserialization."""
|
||||
data = {
|
||||
"strength": 18,
|
||||
"dexterity": 14,
|
||||
"constitution": 16,
|
||||
"intelligence": 12,
|
||||
"wisdom": 10,
|
||||
"charisma": 9,
|
||||
}
|
||||
|
||||
stats = Stats.from_dict(data)
|
||||
|
||||
assert stats.strength == 18
|
||||
assert stats.dexterity == 14
|
||||
assert stats.constitution == 16
|
||||
assert stats.intelligence == 12
|
||||
assert stats.wisdom == 10
|
||||
assert stats.charisma == 9
|
||||
|
||||
|
||||
def test_stats_deserialization_with_missing_values():
|
||||
"""Test from_dict() with missing values (should use defaults)."""
|
||||
data = {
|
||||
"strength": 15,
|
||||
# Missing other stats
|
||||
}
|
||||
|
||||
stats = Stats.from_dict(data)
|
||||
|
||||
assert stats.strength == 15
|
||||
assert stats.dexterity == 10 # Default
|
||||
assert stats.constitution == 10 # Default
|
||||
assert stats.intelligence == 10 # Default
|
||||
assert stats.wisdom == 10 # Default
|
||||
assert stats.charisma == 10 # Default
|
||||
|
||||
|
||||
def test_stats_round_trip_serialization():
|
||||
"""Test that serialization and deserialization preserve data."""
|
||||
original = Stats(
|
||||
strength=20,
|
||||
dexterity=15,
|
||||
constitution=18,
|
||||
intelligence=10,
|
||||
wisdom=12,
|
||||
charisma=14,
|
||||
)
|
||||
|
||||
# Serialize then deserialize
|
||||
data = original.to_dict()
|
||||
restored = Stats.from_dict(data)
|
||||
|
||||
assert restored.strength == original.strength
|
||||
assert restored.dexterity == original.dexterity
|
||||
assert restored.constitution == original.constitution
|
||||
assert restored.intelligence == original.intelligence
|
||||
assert restored.wisdom == original.wisdom
|
||||
assert restored.charisma == original.charisma
|
||||
|
||||
|
||||
def test_stats_copy():
|
||||
"""Test that copy() creates an independent copy."""
|
||||
original = Stats(strength=15, dexterity=12, constitution=14)
|
||||
copy = original.copy()
|
||||
|
||||
assert copy.strength == original.strength
|
||||
assert copy.dexterity == original.dexterity
|
||||
assert copy.constitution == original.constitution
|
||||
|
||||
# Modify copy
|
||||
copy.strength = 20
|
||||
|
||||
# Original should be unchanged
|
||||
assert original.strength == 15
|
||||
assert copy.strength == 20
|
||||
|
||||
|
||||
def test_stats_repr():
|
||||
"""Test string representation."""
|
||||
stats = Stats(strength=15, constitution=12, intelligence=10)
|
||||
repr_str = repr(stats)
|
||||
|
||||
assert "STR=15" in repr_str
|
||||
assert "CON=12" in repr_str
|
||||
assert "INT=10" in repr_str
|
||||
assert "HP=" in repr_str
|
||||
assert "MP=" in repr_str
|
||||
460
api/tests/test_usage_tracking_service.py
Normal file
460
api/tests/test_usage_tracking_service.py
Normal file
@@ -0,0 +1,460 @@
|
||||
"""
|
||||
Tests for UsageTrackingService.
|
||||
|
||||
These tests verify:
|
||||
- Cost calculation for different models
|
||||
- Usage logging functionality
|
||||
- Daily and monthly usage aggregation
|
||||
- Static helper methods
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from datetime import datetime, date, timezone, timedelta
|
||||
from unittest.mock import Mock, MagicMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
from app.services.usage_tracking_service import (
|
||||
UsageTrackingService,
|
||||
MODEL_COSTS,
|
||||
DEFAULT_COST
|
||||
)
|
||||
from app.models.ai_usage import (
|
||||
AIUsageLog,
|
||||
DailyUsageSummary,
|
||||
MonthlyUsageSummary,
|
||||
TaskType
|
||||
)
|
||||
|
||||
|
||||
class TestAIUsageLogModel:
|
||||
"""Tests for the AIUsageLog dataclass."""
|
||||
|
||||
def test_to_dict(self):
|
||||
"""Test conversion to dictionary."""
|
||||
log = AIUsageLog(
|
||||
log_id="log_123",
|
||||
user_id="user_456",
|
||||
timestamp=datetime(2025, 11, 21, 10, 30, 0, tzinfo=timezone.utc),
|
||||
model="anthropic/claude-3.5-sonnet",
|
||||
tokens_input=100,
|
||||
tokens_output=350,
|
||||
tokens_total=450,
|
||||
estimated_cost=0.00555,
|
||||
task_type=TaskType.STORY_PROGRESSION,
|
||||
session_id="sess_789",
|
||||
character_id="char_abc",
|
||||
request_duration_ms=1500,
|
||||
success=True,
|
||||
error_message=None
|
||||
)
|
||||
|
||||
result = log.to_dict()
|
||||
|
||||
assert result["log_id"] == "log_123"
|
||||
assert result["user_id"] == "user_456"
|
||||
assert result["model"] == "anthropic/claude-3.5-sonnet"
|
||||
assert result["tokens_total"] == 450
|
||||
assert result["task_type"] == "story_progression"
|
||||
assert result["success"] is True
|
||||
|
||||
def test_from_dict(self):
|
||||
"""Test creation from dictionary."""
|
||||
data = {
|
||||
"log_id": "log_123",
|
||||
"user_id": "user_456",
|
||||
"timestamp": "2025-11-21T10:30:00+00:00",
|
||||
"model": "anthropic/claude-3.5-sonnet",
|
||||
"tokens_input": 100,
|
||||
"tokens_output": 350,
|
||||
"tokens_total": 450,
|
||||
"estimated_cost": 0.00555,
|
||||
"task_type": "story_progression",
|
||||
"session_id": "sess_789",
|
||||
"character_id": "char_abc",
|
||||
"request_duration_ms": 1500,
|
||||
"success": True,
|
||||
"error_message": None
|
||||
}
|
||||
|
||||
log = AIUsageLog.from_dict(data)
|
||||
|
||||
assert log.log_id == "log_123"
|
||||
assert log.user_id == "user_456"
|
||||
assert log.task_type == TaskType.STORY_PROGRESSION
|
||||
assert log.tokens_total == 450
|
||||
|
||||
def test_from_dict_with_invalid_task_type(self):
|
||||
"""Test handling of invalid task type."""
|
||||
data = {
|
||||
"log_id": "log_123",
|
||||
"user_id": "user_456",
|
||||
"timestamp": "2025-11-21T10:30:00+00:00",
|
||||
"model": "test-model",
|
||||
"tokens_input": 100,
|
||||
"tokens_output": 200,
|
||||
"tokens_total": 300,
|
||||
"estimated_cost": 0.001,
|
||||
"task_type": "invalid_type"
|
||||
}
|
||||
|
||||
log = AIUsageLog.from_dict(data)
|
||||
|
||||
# Should default to GENERAL
|
||||
assert log.task_type == TaskType.GENERAL
|
||||
|
||||
|
||||
class TestDailyUsageSummary:
|
||||
"""Tests for the DailyUsageSummary dataclass."""
|
||||
|
||||
def test_to_dict(self):
|
||||
"""Test conversion to dictionary."""
|
||||
summary = DailyUsageSummary(
|
||||
date=date(2025, 11, 21),
|
||||
user_id="user_123",
|
||||
total_requests=15,
|
||||
total_tokens=6750,
|
||||
total_input_tokens=2000,
|
||||
total_output_tokens=4750,
|
||||
estimated_cost=0.45,
|
||||
requests_by_task={"story_progression": 10, "combat_narration": 5}
|
||||
)
|
||||
|
||||
result = summary.to_dict()
|
||||
|
||||
assert result["date"] == "2025-11-21"
|
||||
assert result["total_requests"] == 15
|
||||
assert result["estimated_cost"] == 0.45
|
||||
assert result["requests_by_task"]["story_progression"] == 10
|
||||
|
||||
|
||||
class TestCostCalculation:
|
||||
"""Tests for cost calculation functionality."""
|
||||
|
||||
def test_calculate_cost_llama(self):
|
||||
"""Test cost calculation for Llama model."""
|
||||
cost = UsageTrackingService.estimate_cost_for_model(
|
||||
model="meta/meta-llama-3-8b-instruct",
|
||||
tokens_input=1000,
|
||||
tokens_output=1000
|
||||
)
|
||||
|
||||
# Llama: $0.0001 per 1K input + $0.0001 per 1K output
|
||||
expected = 0.0001 + 0.0001
|
||||
assert abs(cost - expected) < 0.000001
|
||||
|
||||
def test_calculate_cost_haiku(self):
|
||||
"""Test cost calculation for Claude Haiku."""
|
||||
cost = UsageTrackingService.estimate_cost_for_model(
|
||||
model="anthropic/claude-3.5-haiku",
|
||||
tokens_input=1000,
|
||||
tokens_output=1000
|
||||
)
|
||||
|
||||
# Haiku: $0.001 per 1K input + $0.005 per 1K output
|
||||
expected = 0.001 + 0.005
|
||||
assert abs(cost - expected) < 0.000001
|
||||
|
||||
def test_calculate_cost_sonnet(self):
|
||||
"""Test cost calculation for Claude Sonnet."""
|
||||
cost = UsageTrackingService.estimate_cost_for_model(
|
||||
model="anthropic/claude-3.5-sonnet",
|
||||
tokens_input=1000,
|
||||
tokens_output=1000
|
||||
)
|
||||
|
||||
# Sonnet: $0.003 per 1K input + $0.015 per 1K output
|
||||
expected = 0.003 + 0.015
|
||||
assert abs(cost - expected) < 0.000001
|
||||
|
||||
def test_calculate_cost_opus(self):
|
||||
"""Test cost calculation for Claude Opus."""
|
||||
cost = UsageTrackingService.estimate_cost_for_model(
|
||||
model="anthropic/claude-3-opus",
|
||||
tokens_input=1000,
|
||||
tokens_output=1000
|
||||
)
|
||||
|
||||
# Opus: $0.015 per 1K input + $0.075 per 1K output
|
||||
expected = 0.015 + 0.075
|
||||
assert abs(cost - expected) < 0.000001
|
||||
|
||||
def test_calculate_cost_unknown_model(self):
|
||||
"""Test cost calculation for unknown model uses default."""
|
||||
cost = UsageTrackingService.estimate_cost_for_model(
|
||||
model="unknown/model",
|
||||
tokens_input=1000,
|
||||
tokens_output=1000
|
||||
)
|
||||
|
||||
# Default: $0.001 per 1K input + $0.005 per 1K output
|
||||
expected = DEFAULT_COST["input"] + DEFAULT_COST["output"]
|
||||
assert abs(cost - expected) < 0.000001
|
||||
|
||||
def test_calculate_cost_fractional_tokens(self):
|
||||
"""Test cost calculation with fractional token counts."""
|
||||
cost = UsageTrackingService.estimate_cost_for_model(
|
||||
model="anthropic/claude-3.5-sonnet",
|
||||
tokens_input=500,
|
||||
tokens_output=250
|
||||
)
|
||||
|
||||
# Sonnet: (500/1000 * 0.003) + (250/1000 * 0.015)
|
||||
expected = 0.0015 + 0.00375
|
||||
assert abs(cost - expected) < 0.000001
|
||||
|
||||
def test_get_model_cost_info(self):
|
||||
"""Test getting cost info for a model."""
|
||||
cost_info = UsageTrackingService.get_model_cost_info(
|
||||
"anthropic/claude-3.5-sonnet"
|
||||
)
|
||||
|
||||
assert cost_info["input"] == 0.003
|
||||
assert cost_info["output"] == 0.015
|
||||
|
||||
def test_get_model_cost_info_unknown(self):
|
||||
"""Test getting cost info for unknown model."""
|
||||
cost_info = UsageTrackingService.get_model_cost_info("unknown/model")
|
||||
|
||||
assert cost_info == DEFAULT_COST
|
||||
|
||||
|
||||
class TestUsageTrackingService:
|
||||
"""Tests for UsageTrackingService class."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_env(self):
|
||||
"""Set up mock environment variables."""
|
||||
with patch.dict('os.environ', {
|
||||
'APPWRITE_ENDPOINT': 'https://cloud.appwrite.io/v1',
|
||||
'APPWRITE_PROJECT_ID': 'test_project',
|
||||
'APPWRITE_API_KEY': 'test_api_key',
|
||||
'APPWRITE_DATABASE_ID': 'test_db'
|
||||
}):
|
||||
yield
|
||||
|
||||
@pytest.fixture
|
||||
def mock_databases(self):
|
||||
"""Create mock Databases service."""
|
||||
with patch('app.services.usage_tracking_service.Databases') as mock:
|
||||
yield mock
|
||||
|
||||
@pytest.fixture
|
||||
def service(self, mock_env, mock_databases):
|
||||
"""Create UsageTrackingService instance with mocked dependencies."""
|
||||
service = UsageTrackingService()
|
||||
return service
|
||||
|
||||
def test_init_missing_env(self):
|
||||
"""Test initialization fails with missing env vars."""
|
||||
with patch.dict('os.environ', {}, clear=True):
|
||||
with pytest.raises(ValueError, match="Appwrite configuration incomplete"):
|
||||
UsageTrackingService()
|
||||
|
||||
def test_log_usage_success(self, service):
|
||||
"""Test logging usage successfully."""
|
||||
# Mock the create_document response
|
||||
service.databases.create_document = Mock(return_value={
|
||||
"$id": "doc_123"
|
||||
})
|
||||
|
||||
result = service.log_usage(
|
||||
user_id="user_123",
|
||||
model="anthropic/claude-3.5-sonnet",
|
||||
tokens_input=100,
|
||||
tokens_output=350,
|
||||
task_type=TaskType.STORY_PROGRESSION,
|
||||
session_id="sess_789"
|
||||
)
|
||||
|
||||
# Verify result
|
||||
assert result.user_id == "user_123"
|
||||
assert result.model == "anthropic/claude-3.5-sonnet"
|
||||
assert result.tokens_input == 100
|
||||
assert result.tokens_output == 350
|
||||
assert result.tokens_total == 450
|
||||
assert result.task_type == TaskType.STORY_PROGRESSION
|
||||
assert result.estimated_cost > 0
|
||||
|
||||
# Verify Appwrite was called
|
||||
service.databases.create_document.assert_called_once()
|
||||
call_args = service.databases.create_document.call_args
|
||||
assert call_args.kwargs["database_id"] == "test_db"
|
||||
assert call_args.kwargs["collection_id"] == "ai_usage_logs"
|
||||
|
||||
def test_log_usage_with_error(self, service):
|
||||
"""Test logging usage when request failed."""
|
||||
service.databases.create_document = Mock(return_value={
|
||||
"$id": "doc_123"
|
||||
})
|
||||
|
||||
result = service.log_usage(
|
||||
user_id="user_123",
|
||||
model="anthropic/claude-3.5-sonnet",
|
||||
tokens_input=100,
|
||||
tokens_output=0,
|
||||
task_type=TaskType.STORY_PROGRESSION,
|
||||
success=False,
|
||||
error_message="API timeout"
|
||||
)
|
||||
|
||||
assert result.success is False
|
||||
assert result.error_message == "API timeout"
|
||||
assert result.tokens_output == 0
|
||||
|
||||
def test_get_daily_usage(self, service):
|
||||
"""Test getting daily usage summary."""
|
||||
# Mock list_rows response
|
||||
service.tables_db.list_rows = Mock(return_value={
|
||||
"rows": [
|
||||
{
|
||||
"tokens_input": 100,
|
||||
"tokens_output": 300,
|
||||
"tokens_total": 400,
|
||||
"estimated_cost": 0.005,
|
||||
"task_type": "story_progression"
|
||||
},
|
||||
{
|
||||
"tokens_input": 150,
|
||||
"tokens_output": 350,
|
||||
"tokens_total": 500,
|
||||
"estimated_cost": 0.006,
|
||||
"task_type": "story_progression"
|
||||
},
|
||||
{
|
||||
"tokens_input": 50,
|
||||
"tokens_output": 200,
|
||||
"tokens_total": 250,
|
||||
"estimated_cost": 0.003,
|
||||
"task_type": "combat_narration"
|
||||
}
|
||||
]
|
||||
})
|
||||
|
||||
result = service.get_daily_usage("user_123", date(2025, 11, 21))
|
||||
|
||||
assert result.user_id == "user_123"
|
||||
assert result.date == date(2025, 11, 21)
|
||||
assert result.total_requests == 3
|
||||
assert result.total_tokens == 1150
|
||||
assert result.total_input_tokens == 300
|
||||
assert result.total_output_tokens == 850
|
||||
assert abs(result.estimated_cost - 0.014) < 0.0001
|
||||
assert result.requests_by_task["story_progression"] == 2
|
||||
assert result.requests_by_task["combat_narration"] == 1
|
||||
|
||||
def test_get_daily_usage_empty(self, service):
|
||||
"""Test getting daily usage when no usage exists."""
|
||||
service.tables_db.list_rows = Mock(return_value={
|
||||
"rows": []
|
||||
})
|
||||
|
||||
result = service.get_daily_usage("user_123", date(2025, 11, 21))
|
||||
|
||||
assert result.total_requests == 0
|
||||
assert result.total_tokens == 0
|
||||
assert result.estimated_cost == 0.0
|
||||
assert result.requests_by_task == {}
|
||||
|
||||
def test_get_monthly_cost(self, service):
|
||||
"""Test getting monthly cost summary."""
|
||||
service.tables_db.list_rows = Mock(return_value={
|
||||
"rows": [
|
||||
{"tokens_total": 1000, "estimated_cost": 0.01},
|
||||
{"tokens_total": 2000, "estimated_cost": 0.02},
|
||||
{"tokens_total": 1500, "estimated_cost": 0.015}
|
||||
]
|
||||
})
|
||||
|
||||
result = service.get_monthly_cost("user_123", 2025, 11)
|
||||
|
||||
assert result.year == 2025
|
||||
assert result.month == 11
|
||||
assert result.user_id == "user_123"
|
||||
assert result.total_requests == 3
|
||||
assert result.total_tokens == 4500
|
||||
assert abs(result.estimated_cost - 0.045) < 0.0001
|
||||
|
||||
def test_get_monthly_cost_invalid_month(self, service):
|
||||
"""Test monthly cost with invalid month raises ValueError."""
|
||||
with pytest.raises(ValueError, match="Invalid month"):
|
||||
service.get_monthly_cost("user_123", 2025, 13)
|
||||
|
||||
with pytest.raises(ValueError, match="Invalid month"):
|
||||
service.get_monthly_cost("user_123", 2025, 0)
|
||||
|
||||
def test_get_total_daily_cost(self, service):
|
||||
"""Test getting total daily cost across all users."""
|
||||
service.tables_db.list_rows = Mock(return_value={
|
||||
"rows": [
|
||||
{"estimated_cost": 0.10},
|
||||
{"estimated_cost": 0.25},
|
||||
{"estimated_cost": 0.15}
|
||||
]
|
||||
})
|
||||
|
||||
result = service.get_total_daily_cost(date(2025, 11, 21))
|
||||
|
||||
assert abs(result - 0.50) < 0.0001
|
||||
|
||||
def test_get_user_request_count_today(self, service):
|
||||
"""Test getting user request count for today."""
|
||||
service.tables_db.list_rows = Mock(return_value={
|
||||
"rows": [
|
||||
{"tokens_total": 100, "tokens_input": 30, "tokens_output": 70, "estimated_cost": 0.001, "task_type": "story_progression"},
|
||||
{"tokens_total": 200, "tokens_input": 50, "tokens_output": 150, "estimated_cost": 0.002, "task_type": "story_progression"}
|
||||
]
|
||||
})
|
||||
|
||||
result = service.get_user_request_count_today("user_123")
|
||||
|
||||
assert result == 2
|
||||
|
||||
|
||||
class TestCostEstimations:
|
||||
"""Tests for realistic cost estimation scenarios."""
|
||||
|
||||
def test_free_tier_daily_cost(self):
|
||||
"""Test estimated daily cost for free tier user with Llama."""
|
||||
# 20 requests per day, average 500 total tokens each
|
||||
total_input = 20 * 200
|
||||
total_output = 20 * 300
|
||||
|
||||
cost = UsageTrackingService.estimate_cost_for_model(
|
||||
model="meta/meta-llama-3-8b-instruct",
|
||||
tokens_input=total_input,
|
||||
tokens_output=total_output
|
||||
)
|
||||
|
||||
# Should be very cheap (essentially free)
|
||||
assert cost < 0.01
|
||||
|
||||
def test_premium_tier_daily_cost(self):
|
||||
"""Test estimated daily cost for premium tier user with Sonnet."""
|
||||
# 100 requests per day, average 1000 total tokens each
|
||||
total_input = 100 * 300
|
||||
total_output = 100 * 700
|
||||
|
||||
cost = UsageTrackingService.estimate_cost_for_model(
|
||||
model="anthropic/claude-3.5-sonnet",
|
||||
tokens_input=total_input,
|
||||
tokens_output=total_output
|
||||
)
|
||||
|
||||
# Should be under $2/day for heavy usage
|
||||
assert cost < 2.0
|
||||
|
||||
def test_elite_tier_monthly_cost(self):
|
||||
"""Test estimated monthly cost for elite tier user."""
|
||||
# 200 requests per day * 30 days = 6000 requests
|
||||
# Average 1500 tokens per request
|
||||
total_input = 6000 * 500
|
||||
total_output = 6000 * 1000
|
||||
|
||||
cost = UsageTrackingService.estimate_cost_for_model(
|
||||
model="anthropic/claude-4.5-sonnet",
|
||||
tokens_input=total_input,
|
||||
tokens_output=total_output
|
||||
)
|
||||
|
||||
# Elite tier should be under $100/month even with heavy usage
|
||||
assert cost < 100.0
|
||||
Reference in New Issue
Block a user