first commit

This commit is contained in:
2025-11-24 23:10:55 -06:00
commit 8315fa51c9
279 changed files with 74600 additions and 0 deletions

3
api/tests/__init__.py Normal file
View File

@@ -0,0 +1,3 @@
"""
Unit tests for Code of Conquest.
"""

View File

@@ -0,0 +1,311 @@
"""
Tests for ActionPrompt model
Tests the action prompt availability logic, tier filtering,
location filtering, and serialization.
"""
import pytest
from app.models.action_prompt import (
ActionPrompt,
ActionCategory,
LocationType,
)
from app.ai.model_selector import UserTier
class TestActionPrompt:
"""Tests for ActionPrompt dataclass."""
@pytest.fixture
def free_action(self):
"""Create a free tier action available in towns."""
return ActionPrompt(
prompt_id="ask_locals",
category=ActionCategory.ASK_QUESTION,
display_text="Ask locals for information",
description="Talk to NPCs to learn about quests and rumors",
tier_required=UserTier.FREE,
context_filter=[LocationType.TOWN, LocationType.TAVERN],
dm_prompt_template="The player asks locals about {{ topic }}.",
)
@pytest.fixture
def premium_action(self):
"""Create a premium tier action available anywhere."""
return ActionPrompt(
prompt_id="investigate",
category=ActionCategory.GATHER_INFO,
display_text="Investigate suspicious activity",
description="Look for clues and hidden details",
tier_required=UserTier.PREMIUM,
context_filter=[LocationType.ANY],
dm_prompt_template="The player investigates the area.",
icon="magnifying_glass",
)
@pytest.fixture
def elite_action(self):
"""Create an elite tier action for libraries."""
return ActionPrompt(
prompt_id="consult_texts",
category=ActionCategory.SPECIAL,
display_text="Consult ancient texts",
description="Study rare manuscripts for hidden knowledge",
tier_required=UserTier.ELITE,
context_filter=[LocationType.LIBRARY, LocationType.TOWN],
dm_prompt_template="The player studies ancient texts.",
cooldown_turns=3,
)
# Availability tests
def test_free_action_available_to_free_user(self, free_action):
"""Free action should be available to free tier users."""
assert free_action.is_available(UserTier.FREE, LocationType.TOWN) is True
def test_free_action_available_to_premium_user(self, free_action):
"""Free action should be available to higher tier users."""
assert free_action.is_available(UserTier.PREMIUM, LocationType.TOWN) is True
assert free_action.is_available(UserTier.ELITE, LocationType.TOWN) is True
def test_premium_action_not_available_to_free_user(self, premium_action):
"""Premium action should not be available to free tier users."""
assert premium_action.is_available(UserTier.FREE, LocationType.TOWN) is False
def test_premium_action_available_to_premium_user(self, premium_action):
"""Premium action should be available to premium tier users."""
assert premium_action.is_available(UserTier.PREMIUM, LocationType.TOWN) is True
def test_elite_action_not_available_to_premium_user(self, elite_action):
"""Elite action should not be available to premium tier users."""
assert elite_action.is_available(UserTier.PREMIUM, LocationType.LIBRARY) is False
def test_elite_action_available_to_elite_user(self, elite_action):
"""Elite action should be available to elite tier users."""
assert elite_action.is_available(UserTier.ELITE, LocationType.LIBRARY) is True
# Location filtering tests
def test_action_available_in_matching_location(self, free_action):
"""Action should be available in matching locations."""
assert free_action.is_available(UserTier.FREE, LocationType.TOWN) is True
assert free_action.is_available(UserTier.FREE, LocationType.TAVERN) is True
def test_action_not_available_in_non_matching_location(self, free_action):
"""Action should not be available in non-matching locations."""
assert free_action.is_available(UserTier.FREE, LocationType.WILDERNESS) is False
assert free_action.is_available(UserTier.FREE, LocationType.DUNGEON) is False
def test_any_location_matches_all(self, premium_action):
"""Action with ANY location should be available everywhere."""
assert premium_action.is_available(UserTier.PREMIUM, LocationType.TOWN) is True
assert premium_action.is_available(UserTier.PREMIUM, LocationType.WILDERNESS) is True
assert premium_action.is_available(UserTier.PREMIUM, LocationType.DUNGEON) is True
assert premium_action.is_available(UserTier.PREMIUM, LocationType.LIBRARY) is True
def test_both_tier_and_location_must_match(self, free_action):
"""Both tier and location requirements must be met."""
# Wrong location, right tier
assert free_action.is_available(UserTier.ELITE, LocationType.DUNGEON) is False
# Lock status tests
def test_is_locked_for_lower_tier(self, premium_action):
"""Action should be locked for lower tier users."""
assert premium_action.is_locked(UserTier.FREE) is True
assert premium_action.is_locked(UserTier.BASIC) is True
def test_is_not_locked_for_sufficient_tier(self, premium_action):
"""Action should not be locked for sufficient tier users."""
assert premium_action.is_locked(UserTier.PREMIUM) is False
assert premium_action.is_locked(UserTier.ELITE) is False
def test_get_lock_reason_returns_message(self, premium_action):
"""Lock reason should explain tier requirement."""
reason = premium_action.get_lock_reason(UserTier.FREE)
assert reason is not None
assert "Premium" in reason
def test_get_lock_reason_returns_none_when_unlocked(self, premium_action):
"""Lock reason should be None when action is unlocked."""
reason = premium_action.get_lock_reason(UserTier.PREMIUM)
assert reason is None
# Tier hierarchy tests
def test_tier_hierarchy_free_to_elite(self):
"""Test full tier hierarchy from FREE to ELITE."""
action = ActionPrompt(
prompt_id="test",
category=ActionCategory.EXPLORE,
display_text="Test",
description="Test action",
tier_required=UserTier.BASIC,
context_filter=[LocationType.ANY],
dm_prompt_template="Test",
)
# FREE < BASIC (should fail)
assert action.is_available(UserTier.FREE, LocationType.TOWN) is False
# BASIC >= BASIC (should pass)
assert action.is_available(UserTier.BASIC, LocationType.TOWN) is True
# PREMIUM > BASIC (should pass)
assert action.is_available(UserTier.PREMIUM, LocationType.TOWN) is True
# ELITE > BASIC (should pass)
assert action.is_available(UserTier.ELITE, LocationType.TOWN) is True
# Serialization tests
def test_to_dict(self, free_action):
"""Test serialization to dictionary."""
data = free_action.to_dict()
assert data["prompt_id"] == "ask_locals"
assert data["category"] == "ask_question"
assert data["display_text"] == "Ask locals for information"
assert data["tier_required"] == "free"
assert data["context_filter"] == ["town", "tavern"]
assert "dm_prompt_template" in data
def test_from_dict(self):
"""Test deserialization from dictionary."""
data = {
"prompt_id": "explore_area",
"category": "explore",
"display_text": "Explore the area",
"description": "Look around for points of interest",
"tier_required": "free",
"context_filter": ["wilderness", "dungeon"],
"dm_prompt_template": "The player explores {{ location }}.",
"icon": "compass",
"cooldown_turns": 2,
}
action = ActionPrompt.from_dict(data)
assert action.prompt_id == "explore_area"
assert action.category == ActionCategory.EXPLORE
assert action.tier_required == UserTier.FREE
assert LocationType.WILDERNESS in action.context_filter
assert action.icon == "compass"
assert action.cooldown_turns == 2
def test_round_trip_serialization(self, free_action):
"""Test that to_dict and from_dict are inverse operations."""
data = free_action.to_dict()
restored = ActionPrompt.from_dict(data)
assert restored.prompt_id == free_action.prompt_id
assert restored.category == free_action.category
assert restored.display_text == free_action.display_text
assert restored.tier_required == free_action.tier_required
assert restored.context_filter == free_action.context_filter
def test_from_dict_invalid_category(self):
"""Test error handling for invalid category."""
data = {
"prompt_id": "test",
"category": "invalid_category",
"display_text": "Test",
"description": "Test",
"tier_required": "free",
"context_filter": ["any"],
"dm_prompt_template": "Test",
}
with pytest.raises(ValueError) as exc_info:
ActionPrompt.from_dict(data)
assert "Invalid action category" in str(exc_info.value)
def test_from_dict_invalid_tier(self):
"""Test error handling for invalid tier."""
data = {
"prompt_id": "test",
"category": "explore",
"display_text": "Test",
"description": "Test",
"tier_required": "super_premium",
"context_filter": ["any"],
"dm_prompt_template": "Test",
}
with pytest.raises(ValueError) as exc_info:
ActionPrompt.from_dict(data)
assert "Invalid user tier" in str(exc_info.value)
def test_from_dict_invalid_location(self):
"""Test error handling for invalid location type."""
data = {
"prompt_id": "test",
"category": "explore",
"display_text": "Test",
"description": "Test",
"tier_required": "free",
"context_filter": ["invalid_location"],
"dm_prompt_template": "Test",
}
with pytest.raises(ValueError) as exc_info:
ActionPrompt.from_dict(data)
assert "Invalid location type" in str(exc_info.value)
# Optional fields tests
def test_optional_icon(self, free_action, premium_action):
"""Test that icon is optional."""
assert free_action.icon is None
assert premium_action.icon == "magnifying_glass"
def test_default_cooldown(self, free_action, elite_action):
"""Test default and custom cooldown values."""
assert free_action.cooldown_turns == 0
assert elite_action.cooldown_turns == 3
# Repr test
def test_repr(self, free_action):
"""Test string representation."""
repr_str = repr(free_action)
assert "ask_locals" in repr_str
assert "ask_question" in repr_str
assert "free" in repr_str
class TestActionCategory:
"""Tests for ActionCategory enum."""
def test_all_categories_defined(self):
"""Verify all expected categories exist."""
categories = [cat.value for cat in ActionCategory]
assert "ask_question" in categories
assert "travel" in categories
assert "gather_info" in categories
assert "rest" in categories
assert "interact" in categories
assert "explore" in categories
assert "special" in categories
class TestLocationType:
"""Tests for LocationType enum."""
def test_all_location_types_defined(self):
"""Verify all expected location types exist."""
locations = [loc.value for loc in LocationType]
assert "town" in locations
assert "tavern" in locations
assert "wilderness" in locations
assert "dungeon" in locations
assert "safe_area" in locations
assert "library" in locations
assert "any" in locations

View File

@@ -0,0 +1,314 @@
"""
Tests for ActionPromptLoader service
Tests loading from YAML, filtering by tier and location,
and error handling.
"""
import pytest
import tempfile
import os
from app.services.action_prompt_loader import (
ActionPromptLoader,
ActionPromptLoaderError,
ActionPromptNotFoundError,
)
from app.models.action_prompt import LocationType
from app.ai.model_selector import UserTier
class TestActionPromptLoader:
"""Tests for ActionPromptLoader service."""
@pytest.fixture(autouse=True)
def reset_singleton(self):
"""Reset singleton before each test."""
ActionPromptLoader.reset_instance()
yield
ActionPromptLoader.reset_instance()
@pytest.fixture
def sample_yaml(self):
"""Create a sample YAML file for testing."""
content = """
action_prompts:
- prompt_id: test_free
category: explore
display_text: Free Action
description: Available to all
tier_required: free
context_filter: [town, tavern]
dm_prompt_template: Test template
- prompt_id: test_premium
category: gather_info
display_text: Premium Action
description: Premium only
tier_required: premium
context_filter: [any]
dm_prompt_template: Premium template
- prompt_id: test_elite
category: special
display_text: Elite Action
description: Elite only
tier_required: elite
context_filter: [library]
dm_prompt_template: Elite template
"""
with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', delete=False) as f:
f.write(content)
filepath = f.name
yield filepath
os.unlink(filepath)
@pytest.fixture
def loader(self, sample_yaml):
"""Create a loader with sample data."""
loader = ActionPromptLoader()
loader.load_from_yaml(sample_yaml)
return loader
# Loading tests
def test_load_from_yaml(self, sample_yaml):
"""Test loading prompts from YAML file."""
loader = ActionPromptLoader()
count = loader.load_from_yaml(sample_yaml)
assert count == 3
assert loader.is_loaded()
def test_load_file_not_found(self):
"""Test error when file doesn't exist."""
loader = ActionPromptLoader()
with pytest.raises(ActionPromptLoaderError) as exc_info:
loader.load_from_yaml("/nonexistent/path.yaml")
assert "not found" in str(exc_info.value)
def test_load_invalid_yaml(self):
"""Test error when YAML is malformed."""
with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', delete=False) as f:
f.write("invalid: yaml: content: [")
filepath = f.name
try:
loader = ActionPromptLoader()
with pytest.raises(ActionPromptLoaderError) as exc_info:
loader.load_from_yaml(filepath)
assert "Invalid YAML" in str(exc_info.value)
finally:
os.unlink(filepath)
def test_load_missing_key(self):
"""Test error when action_prompts key is missing."""
with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', delete=False) as f:
f.write("other_key: value")
filepath = f.name
try:
loader = ActionPromptLoader()
with pytest.raises(ActionPromptLoaderError) as exc_info:
loader.load_from_yaml(filepath)
assert "Missing 'action_prompts'" in str(exc_info.value)
finally:
os.unlink(filepath)
# Get methods tests
def test_get_all_actions(self, loader):
"""Test getting all loaded actions."""
actions = loader.get_all_actions()
assert len(actions) == 3
prompt_ids = [a.prompt_id for a in actions]
assert "test_free" in prompt_ids
assert "test_premium" in prompt_ids
assert "test_elite" in prompt_ids
def test_get_action_by_id(self, loader):
"""Test getting action by ID."""
action = loader.get_action_by_id("test_free")
assert action.prompt_id == "test_free"
assert action.display_text == "Free Action"
def test_get_action_by_id_not_found(self, loader):
"""Test error when action ID not found."""
with pytest.raises(ActionPromptNotFoundError):
loader.get_action_by_id("nonexistent")
# Filtering tests
def test_get_available_actions_free_tier(self, loader):
"""Test filtering for free tier user."""
actions = loader.get_available_actions(UserTier.FREE, LocationType.TOWN)
assert len(actions) == 1
assert actions[0].prompt_id == "test_free"
def test_get_available_actions_premium_tier(self, loader):
"""Test filtering for premium tier user."""
actions = loader.get_available_actions(UserTier.PREMIUM, LocationType.TOWN)
# Premium gets: test_free (town) + test_premium (any)
assert len(actions) == 2
prompt_ids = [a.prompt_id for a in actions]
assert "test_free" in prompt_ids
assert "test_premium" in prompt_ids
def test_get_available_actions_elite_tier(self, loader):
"""Test filtering for elite tier user."""
actions = loader.get_available_actions(UserTier.ELITE, LocationType.LIBRARY)
# Elite in library gets: test_premium (any) + test_elite (library)
assert len(actions) == 2
prompt_ids = [a.prompt_id for a in actions]
assert "test_premium" in prompt_ids
assert "test_elite" in prompt_ids
def test_get_available_actions_location_filter(self, loader):
"""Test that location filtering works correctly."""
# In town, no elite actions available
actions = loader.get_available_actions(UserTier.ELITE, LocationType.TOWN)
prompt_ids = [a.prompt_id for a in actions]
assert "test_elite" not in prompt_ids # Only in library
def test_get_actions_by_tier(self, loader):
"""Test getting actions by tier without location filter."""
free_actions = loader.get_actions_by_tier(UserTier.FREE)
premium_actions = loader.get_actions_by_tier(UserTier.PREMIUM)
elite_actions = loader.get_actions_by_tier(UserTier.ELITE)
assert len(free_actions) == 1
assert len(premium_actions) == 2
assert len(elite_actions) == 3
def test_get_actions_by_category(self, loader):
"""Test getting actions by category."""
explore_actions = loader.get_actions_by_category("explore")
special_actions = loader.get_actions_by_category("special")
assert len(explore_actions) == 1
assert explore_actions[0].prompt_id == "test_free"
assert len(special_actions) == 1
assert special_actions[0].prompt_id == "test_elite"
def test_get_locked_actions(self, loader):
"""Test getting locked actions for upgrade prompts."""
# Free user in library sees elite action as locked
locked = loader.get_locked_actions(UserTier.FREE, LocationType.LIBRARY)
# test_premium (any) and test_elite (library) are locked for free
assert len(locked) == 2
prompt_ids = [a.prompt_id for a in locked]
assert "test_premium" in prompt_ids
assert "test_elite" in prompt_ids
# Singleton and reload tests
def test_singleton_pattern(self, sample_yaml):
"""Test that loader is singleton."""
loader1 = ActionPromptLoader()
loader1.load_from_yaml(sample_yaml)
loader2 = ActionPromptLoader()
assert loader1 is loader2
assert loader2.is_loaded()
def test_reload(self, sample_yaml):
"""Test reloading prompts."""
loader = ActionPromptLoader()
loader.load_from_yaml(sample_yaml)
# Modify and reload
count = loader.reload(sample_yaml)
assert count == 3
def test_get_prompt_count(self, loader):
"""Test getting prompt count."""
assert loader.get_prompt_count() == 3
class TestActionPromptLoaderIntegration:
"""Integration tests with actual YAML file."""
@pytest.fixture(autouse=True)
def reset_singleton(self):
"""Reset singleton before each test."""
ActionPromptLoader.reset_instance()
yield
ActionPromptLoader.reset_instance()
def test_load_actual_yaml(self):
"""Test loading the actual action_prompts.yaml file."""
loader = ActionPromptLoader()
filepath = os.path.join(
os.path.dirname(__file__),
'..', 'app', 'data', 'action_prompts.yaml'
)
if os.path.exists(filepath):
count = loader.load_from_yaml(filepath)
# Should have 10 actions
assert count == 10
# Verify tier distribution
free_actions = loader.get_actions_by_tier(UserTier.FREE)
premium_actions = loader.get_actions_by_tier(UserTier.PREMIUM)
elite_actions = loader.get_actions_by_tier(UserTier.ELITE)
assert len(free_actions) == 4 # Only free tier
assert len(premium_actions) == 7 # Free + premium
assert len(elite_actions) == 10 # All
def test_free_tier_town_actions(self):
"""Test free tier actions in town location."""
loader = ActionPromptLoader()
filepath = os.path.join(
os.path.dirname(__file__),
'..', 'app', 'data', 'action_prompts.yaml'
)
if os.path.exists(filepath):
loader.load_from_yaml(filepath)
actions = loader.get_available_actions(UserTier.FREE, LocationType.TOWN)
# Free user in town should have:
# - ask_locals (town/tavern)
# - search_supplies (any)
# - rest_recover (town/tavern/safe_area)
prompt_ids = [a.prompt_id for a in actions]
assert "ask_locals" in prompt_ids
assert "search_supplies" in prompt_ids
assert "rest_recover" in prompt_ids
def test_premium_tier_wilderness_actions(self):
"""Test premium tier actions in wilderness location."""
loader = ActionPromptLoader()
filepath = os.path.join(
os.path.dirname(__file__),
'..', 'app', 'data', 'action_prompts.yaml'
)
if os.path.exists(filepath):
loader.load_from_yaml(filepath)
actions = loader.get_available_actions(UserTier.PREMIUM, LocationType.WILDERNESS)
prompt_ids = [a.prompt_id for a in actions]
# Should include wilderness actions
assert "explore_area" in prompt_ids
assert "make_camp" in prompt_ids
assert "search_supplies" in prompt_ids

571
api/tests/test_ai_tasks.py Normal file
View File

@@ -0,0 +1,571 @@
"""
Unit tests for AI Task Jobs.
These tests verify the job enqueueing, status tracking,
and result storage functionality.
"""
import pytest
from unittest.mock import Mock, patch, MagicMock
import json
from app.tasks.ai_tasks import (
enqueue_ai_task,
process_ai_task,
get_job_status,
get_job_result,
JobStatus,
TaskType,
TaskPriority,
_store_job_status,
_update_job_status,
_store_job_result,
)
class TestEnqueueAITask:
"""Test AI task enqueueing."""
@patch('app.tasks.ai_tasks.get_queue')
@patch('app.tasks.ai_tasks._store_job_status')
def test_enqueue_narrative_task(self, mock_store_status, mock_get_queue):
"""Test enqueueing a narrative task."""
# Setup mock queue
mock_queue = MagicMock()
mock_job = MagicMock()
mock_job.id = "test_job_123"
mock_queue.enqueue.return_value = mock_job
mock_get_queue.return_value = mock_queue
# Enqueue task
result = enqueue_ai_task(
task_type="narrative",
user_id="user_123",
context={"action": "explore"},
priority="normal",
)
# Verify
assert "job_id" in result
assert result["status"] == "queued"
mock_queue.enqueue.assert_called_once()
mock_store_status.assert_called_once()
@patch('app.tasks.ai_tasks.get_queue')
@patch('app.tasks.ai_tasks._store_job_status')
def test_enqueue_high_priority_task(self, mock_store_status, mock_get_queue):
"""Test high priority task goes to front of queue."""
mock_queue = MagicMock()
mock_job = MagicMock()
mock_queue.enqueue.return_value = mock_job
mock_get_queue.return_value = mock_queue
enqueue_ai_task(
task_type="narrative",
user_id="user_123",
context={},
priority="high",
)
# Verify at_front=True for high priority
call_kwargs = mock_queue.enqueue.call_args[1]
assert call_kwargs["at_front"] is True
@patch('app.tasks.ai_tasks.get_queue')
@patch('app.tasks.ai_tasks._store_job_status')
def test_enqueue_with_session_and_character(self, mock_store_status, mock_get_queue):
"""Test enqueueing with session and character IDs."""
mock_queue = MagicMock()
mock_job = MagicMock()
mock_queue.enqueue.return_value = mock_job
mock_get_queue.return_value = mock_queue
result = enqueue_ai_task(
task_type="combat",
user_id="user_123",
context={"enemy": "goblin"},
session_id="sess_456",
character_id="char_789",
)
assert "job_id" in result
# Verify kwargs passed to enqueue include session and character
call_kwargs = mock_queue.enqueue.call_args[1]
job_kwargs = call_kwargs["kwargs"]
assert job_kwargs["session_id"] == "sess_456"
assert job_kwargs["character_id"] == "char_789"
def test_enqueue_invalid_task_type(self):
"""Test enqueueing with invalid task type raises error."""
with pytest.raises(ValueError) as exc_info:
enqueue_ai_task(
task_type="invalid_type",
user_id="user_123",
context={},
)
assert "Invalid task_type" in str(exc_info.value)
def test_enqueue_invalid_priority(self):
"""Test enqueueing with invalid priority raises error."""
with pytest.raises(ValueError) as exc_info:
enqueue_ai_task(
task_type="narrative",
user_id="user_123",
context={},
priority="invalid_priority",
)
assert "Invalid priority" in str(exc_info.value)
class TestProcessAITask:
"""Test AI task processing with mocked NarrativeGenerator."""
@pytest.fixture
def sample_character(self):
"""Sample character data for tests."""
return {
"name": "Aldric",
"level": 3,
"player_class": "Fighter",
"current_hp": 25,
"max_hp": 30,
"stats": {"strength": 16, "dexterity": 12},
"skills": [],
"effects": []
}
@pytest.fixture
def sample_game_state(self):
"""Sample game state for tests."""
return {
"current_location": "Tavern",
"location_type": "TAVERN",
"discovered_locations": [],
"active_quests": []
}
@pytest.fixture
def mock_narrative_response(self):
"""Mock NarrativeResponse for tests."""
from app.ai.narrative_generator import NarrativeResponse
return NarrativeResponse(
narrative="You enter the tavern...",
tokens_used=150,
model="meta/meta-llama-3-8b-instruct",
context_type="story_progression",
generation_time=2.5
)
@patch('app.tasks.ai_tasks._get_user_tier')
@patch('app.tasks.ai_tasks._update_game_session')
@patch('app.tasks.ai_tasks.NarrativeGenerator')
@patch('app.tasks.ai_tasks._update_job_status')
@patch('app.tasks.ai_tasks._store_job_result')
def test_process_narrative_task(
self, mock_store_result, mock_update_status, mock_generator_class,
mock_update_session, mock_get_tier, sample_character, sample_game_state, mock_narrative_response
):
"""Test processing a narrative task with NarrativeGenerator."""
from app.ai.model_selector import UserTier
# Setup mocks
mock_get_tier.return_value = UserTier.FREE
mock_generator = MagicMock()
mock_generator.generate_story_response.return_value = mock_narrative_response
mock_generator_class.return_value = mock_generator
# Process task
result = process_ai_task(
task_type="narrative",
user_id="user_123",
context={
"action": "I explore the tavern",
"character": sample_character,
"game_state": sample_game_state
},
job_id="job_123",
session_id="sess_456",
character_id="char_789"
)
# Verify result structure
assert "narrative" in result
assert result["narrative"] == "You enter the tavern..."
assert result["tokens_used"] == 150
assert result["model"] == "meta/meta-llama-3-8b-instruct"
# Verify generator was called
mock_generator.generate_story_response.assert_called_once()
# Verify session update was called
mock_update_session.assert_called_once()
@patch('app.tasks.ai_tasks._get_user_tier')
@patch('app.tasks.ai_tasks.NarrativeGenerator')
@patch('app.tasks.ai_tasks._update_job_status')
@patch('app.tasks.ai_tasks._store_job_result')
def test_process_combat_task(
self, mock_store_result, mock_update_status, mock_generator_class,
mock_get_tier, sample_character, mock_narrative_response
):
"""Test processing a combat task."""
from app.ai.model_selector import UserTier
mock_get_tier.return_value = UserTier.BASIC
mock_generator = MagicMock()
mock_generator.generate_combat_narration.return_value = mock_narrative_response
mock_generator_class.return_value = mock_generator
result = process_ai_task(
task_type="combat",
user_id="user_123",
context={
"character": sample_character,
"combat_state": {"round_number": 1, "enemies": [], "current_turn": "player"},
"action": "swings sword",
"action_result": {"hit": True, "damage": 10}
},
job_id="job_123",
)
assert "combat_narrative" in result
mock_generator.generate_combat_narration.assert_called_once()
@patch('app.tasks.ai_tasks._get_user_tier')
@patch('app.tasks.ai_tasks.NarrativeGenerator')
@patch('app.tasks.ai_tasks._update_job_status')
@patch('app.tasks.ai_tasks._store_job_result')
def test_process_quest_selection_task(
self, mock_store_result, mock_update_status, mock_generator_class,
mock_get_tier, sample_character, sample_game_state
):
"""Test processing a quest selection task."""
from app.ai.model_selector import UserTier
mock_get_tier.return_value = UserTier.FREE
mock_generator = MagicMock()
mock_generator.generate_quest_selection.return_value = "goblin_cave"
mock_generator_class.return_value = mock_generator
result = process_ai_task(
task_type="quest_selection",
user_id="user_123",
context={
"character": sample_character,
"eligible_quests": [{"quest_id": "goblin_cave"}],
"game_context": sample_game_state
},
job_id="job_123",
)
assert result["selected_quest_id"] == "goblin_cave"
mock_generator.generate_quest_selection.assert_called_once()
@patch('app.tasks.ai_tasks._get_user_tier')
@patch('app.tasks.ai_tasks.NarrativeGenerator')
@patch('app.tasks.ai_tasks._update_job_status')
@patch('app.tasks.ai_tasks._store_job_result')
def test_process_npc_dialogue_task(
self, mock_store_result, mock_update_status, mock_generator_class,
mock_get_tier, sample_character, sample_game_state, mock_narrative_response
):
"""Test processing an NPC dialogue task."""
from app.ai.model_selector import UserTier
mock_get_tier.return_value = UserTier.PREMIUM
mock_generator = MagicMock()
mock_generator.generate_npc_dialogue.return_value = mock_narrative_response
mock_generator_class.return_value = mock_generator
result = process_ai_task(
task_type="npc_dialogue",
user_id="user_123",
context={
"character": sample_character,
"npc": {"name": "Innkeeper", "role": "Barkeeper"},
"conversation_topic": "What's the news?",
"game_state": sample_game_state
},
job_id="job_123",
)
assert "dialogue" in result
mock_generator.generate_npc_dialogue.assert_called_once()
@patch('app.tasks.ai_tasks._update_job_status')
def test_process_task_failure_invalid_type(self, mock_update_status):
"""Test task processing failure with invalid task type."""
with pytest.raises(ValueError):
process_ai_task(
task_type="invalid",
user_id="user_123",
context={},
job_id="job_123",
)
# Verify status was updated to PROCESSING then FAILED
calls = mock_update_status.call_args_list
assert any(call[0][1] == JobStatus.FAILED for call in calls)
@patch('app.tasks.ai_tasks._get_user_tier')
@patch('app.tasks.ai_tasks._update_job_status')
def test_process_task_missing_context_fields(self, mock_update_status, mock_get_tier):
"""Test task processing failure with missing required context fields."""
from app.ai.model_selector import UserTier
mock_get_tier.return_value = UserTier.FREE
with pytest.raises(ValueError) as exc_info:
process_ai_task(
task_type="narrative",
user_id="user_123",
context={"action": "test"}, # Missing character and game_state
job_id="job_123",
)
assert "Missing required context field" in str(exc_info.value)
class TestJobStatus:
"""Test job status retrieval."""
@patch('app.tasks.ai_tasks.RedisService')
def test_get_job_status_from_cache(self, mock_redis_class):
"""Test getting job status from Redis cache."""
mock_redis = MagicMock()
mock_redis.get_json.return_value = {
"job_id": "job_123",
"status": "completed",
"created_at": "2025-11-21T10:00:00Z",
}
mock_redis_class.return_value = mock_redis
result = get_job_status("job_123")
assert result["job_id"] == "job_123"
assert result["status"] == "completed"
@patch('app.tasks.ai_tasks.RedisService')
@patch('app.tasks.ai_tasks.get_redis_connection')
@patch('app.tasks.ai_tasks.Job')
def test_get_job_status_from_rq(self, mock_job_class, mock_get_conn, mock_redis_class):
"""Test getting job status from RQ when not in cache."""
# No cached status
mock_redis = MagicMock()
mock_redis.get_json.return_value = None
mock_redis_class.return_value = mock_redis
# RQ job
mock_job = MagicMock()
mock_job.is_finished = True
mock_job.is_failed = False
mock_job.is_started = False
mock_job.created_at = None
mock_job.started_at = None
mock_job.ended_at = None
mock_job_class.fetch.return_value = mock_job
result = get_job_status("job_123")
assert result["status"] == "completed"
@patch('app.tasks.ai_tasks.RedisService')
@patch('app.tasks.ai_tasks.get_redis_connection')
@patch('app.tasks.ai_tasks.Job')
def test_get_job_status_not_found(self, mock_job_class, mock_get_conn, mock_redis_class):
"""Test getting status of non-existent job."""
mock_redis = MagicMock()
mock_redis.get_json.return_value = None
mock_redis_class.return_value = mock_redis
mock_job_class.fetch.side_effect = Exception("Job not found")
result = get_job_status("nonexistent_job")
assert result["status"] == "unknown"
assert "error" in result
class TestJobResult:
"""Test job result retrieval."""
@patch('app.tasks.ai_tasks.RedisService')
def test_get_job_result_from_cache(self, mock_redis_class):
"""Test getting job result from Redis cache."""
expected_result = {
"narrative": "You enter the tavern...",
"tokens_used": 450,
}
mock_redis = MagicMock()
mock_redis.get_json.return_value = expected_result
mock_redis_class.return_value = mock_redis
result = get_job_result("job_123")
assert result == expected_result
@patch('app.tasks.ai_tasks.RedisService')
@patch('app.tasks.ai_tasks.get_redis_connection')
@patch('app.tasks.ai_tasks.Job')
def test_get_job_result_from_rq(self, mock_job_class, mock_get_conn, mock_redis_class):
"""Test getting job result from RQ when not in cache."""
expected_result = {"narrative": "You find a sword..."}
mock_redis = MagicMock()
mock_redis.get_json.return_value = None
mock_redis_class.return_value = mock_redis
mock_job = MagicMock()
mock_job.is_finished = True
mock_job.result = expected_result
mock_job_class.fetch.return_value = mock_job
result = get_job_result("job_123")
assert result == expected_result
@patch('app.tasks.ai_tasks.RedisService')
@patch('app.tasks.ai_tasks.get_redis_connection')
@patch('app.tasks.ai_tasks.Job')
def test_get_job_result_not_found(self, mock_job_class, mock_get_conn, mock_redis_class):
"""Test getting result of non-existent job."""
mock_redis = MagicMock()
mock_redis.get_json.return_value = None
mock_redis_class.return_value = mock_redis
mock_job_class.fetch.side_effect = Exception("Job not found")
result = get_job_result("nonexistent_job")
assert result is None
class TestStatusStorage:
"""Test job status storage functions."""
@patch('app.tasks.ai_tasks.RedisService')
def test_store_job_status(self, mock_redis_class):
"""Test storing initial job status."""
mock_redis = MagicMock()
mock_redis_class.return_value = mock_redis
_store_job_status(
job_id="job_123",
status=JobStatus.QUEUED,
task_type="narrative",
user_id="user_456",
)
mock_redis.set_json.assert_called_once()
call_args = mock_redis.set_json.call_args
stored_data = call_args[0][1]
assert stored_data["job_id"] == "job_123"
assert stored_data["status"] == "queued"
assert stored_data["task_type"] == "narrative"
assert stored_data["user_id"] == "user_456"
@patch('app.tasks.ai_tasks.RedisService')
def test_update_job_status_to_processing(self, mock_redis_class):
"""Test updating job status to processing."""
mock_redis = MagicMock()
mock_redis.get_json.return_value = {
"job_id": "job_123",
"status": "queued",
"created_at": "2025-11-21T10:00:00Z",
}
mock_redis_class.return_value = mock_redis
_update_job_status("job_123", JobStatus.PROCESSING)
mock_redis.set_json.assert_called_once()
call_args = mock_redis.set_json.call_args
updated_data = call_args[0][1]
assert updated_data["status"] == "processing"
assert updated_data["started_at"] is not None
@patch('app.tasks.ai_tasks.RedisService')
def test_update_job_status_to_completed(self, mock_redis_class):
"""Test updating job status to completed with result."""
mock_redis = MagicMock()
mock_redis.get_json.return_value = {
"job_id": "job_123",
"status": "processing",
}
mock_redis_class.return_value = mock_redis
result_data = {"narrative": "Test result"}
_update_job_status("job_123", JobStatus.COMPLETED, result=result_data)
call_args = mock_redis.set_json.call_args
updated_data = call_args[0][1]
assert updated_data["status"] == "completed"
assert updated_data["completed_at"] is not None
assert updated_data["result"] == result_data
@patch('app.tasks.ai_tasks.RedisService')
def test_update_job_status_to_failed(self, mock_redis_class):
"""Test updating job status to failed with error."""
mock_redis = MagicMock()
mock_redis.get_json.return_value = {
"job_id": "job_123",
"status": "processing",
}
mock_redis_class.return_value = mock_redis
_update_job_status("job_123", JobStatus.FAILED, error="Something went wrong")
call_args = mock_redis.set_json.call_args
updated_data = call_args[0][1]
assert updated_data["status"] == "failed"
assert updated_data["error"] == "Something went wrong"
@patch('app.tasks.ai_tasks.RedisService')
def test_store_job_result(self, mock_redis_class):
"""Test storing job result."""
mock_redis = MagicMock()
mock_redis_class.return_value = mock_redis
result_data = {
"narrative": "You discover a hidden passage...",
"tokens_used": 500,
}
_store_job_result("job_123", result_data)
mock_redis.set_json.assert_called_once()
call_args = mock_redis.set_json.call_args
assert call_args[0][0] == "ai_job_result:job_123"
assert call_args[0][1] == result_data
assert call_args[1]["ttl"] == 3600 # 1 hour
class TestTaskTypes:
"""Test task type and priority enums."""
def test_task_type_values(self):
"""Test all task types are defined."""
assert TaskType.NARRATIVE.value == "narrative"
assert TaskType.COMBAT.value == "combat"
assert TaskType.QUEST_SELECTION.value == "quest_selection"
assert TaskType.NPC_DIALOGUE.value == "npc_dialogue"
def test_task_priority_values(self):
"""Test all priorities are defined."""
assert TaskPriority.LOW.value == "low"
assert TaskPriority.NORMAL.value == "normal"
assert TaskPriority.HIGH.value == "high"
def test_job_status_values(self):
"""Test all job statuses are defined."""
assert JobStatus.QUEUED.value == "queued"
assert JobStatus.PROCESSING.value == "processing"
assert JobStatus.COMPLETED.value == "completed"
assert JobStatus.FAILED.value == "failed"

View File

@@ -0,0 +1,579 @@
"""
Integration tests for Character API endpoints.
These tests verify the complete character management API flow including:
- List characters
- Get character details
- Create character (with tier limits)
- Delete character
- Unlock skills
- Respec skills
- Get classes and origins (reference data)
Tests use Flask test client with mocked authentication and database layers.
"""
import pytest
import json
from unittest.mock import Mock, patch, MagicMock
from datetime import datetime
from app.models.character import Character
from app.models.stats import Stats
from app.models.skills import PlayerClass, SkillTree, SkillNode
from app.models.origins import Origin, StartingLocation, StartingBonus
from app.services.character_service import (
CharacterLimitExceeded,
CharacterNotFound,
SkillUnlockError,
InsufficientGold
)
from app.services.database_service import DatabaseDocument
class TestCharacterAPIIntegration:
"""Integration tests for Character API endpoints."""
@pytest.fixture
def app(self):
"""Create Flask application for testing."""
# Import here to avoid circular imports
from app import create_app
app = create_app()
app.config['TESTING'] = True
return app
@pytest.fixture
def client(self, app):
"""Create test client."""
return app.test_client()
@pytest.fixture(autouse=True)
def mock_auth_decorator(self):
"""Mock the require_auth decorator to bypass authentication."""
def pass_through_decorator(func):
"""Decorator that does nothing - passes through to the function."""
return func
with patch('app.api.characters.require_auth', side_effect=pass_through_decorator):
yield
@pytest.fixture
def mock_user(self):
"""Create a mock authenticated user."""
user = Mock()
user.id = "test_user_123"
user.email = "test@example.com"
user.name = "Test User"
user.tier = "free"
user.email_verified = True
return user
@pytest.fixture
def sample_class(self):
"""Create a sample player class."""
base_stats = Stats(strength=14, dexterity=10, constitution=14,
intelligence=8, wisdom=10, charisma=9)
skill_nodes = [
SkillNode(
skill_id="shield_wall",
name="Shield Wall",
description="Increase armor by 5",
tier=1,
prerequisites=[],
effects={"armor": 5}
),
SkillNode(
skill_id="toughness",
name="Toughness",
description="Increase HP by 10",
tier=2,
prerequisites=["shield_wall"],
effects={"hit_points": 10}
)
]
skill_tree = SkillTree(
tree_id="shield_bearer",
name="Shield Bearer",
description="Defensive techniques",
nodes=skill_nodes
)
return PlayerClass(
class_id="vanguard",
name="Vanguard",
description="Armored warrior",
base_stats=base_stats,
skill_trees=[skill_tree],
starting_equipment=["Rusty Sword", "Tattered Cloth Armor"],
starting_abilities=[]
)
@pytest.fixture
def sample_origin(self):
"""Create a sample origin."""
starting_location = StartingLocation(
id="forgotten_crypt",
name="The Forgotten Crypt",
region="The Deadlands",
description="A crumbling stone tomb beneath a dead forest"
)
starting_bonus = StartingBonus(
type="stat",
value={"constitution": 1}
)
return Origin(
id="soul_revenant",
name="The Soul Revenant",
description="You died. That much you remember...",
starting_location=starting_location,
narrative_hooks=["Who brought you back?"],
starting_bonus=starting_bonus
)
@pytest.fixture
def sample_character(self, sample_class, sample_origin):
"""Create a sample character."""
return Character(
character_id="char_123",
user_id="test_user_123",
name="Thorin Ironforge",
player_class=sample_class,
origin=sample_origin,
level=5,
experience=2400,
base_stats=Stats(strength=16, dexterity=10, constitution=14,
intelligence=8, wisdom=12, charisma=10),
unlocked_skills=["shield_wall"],
inventory=[],
equipped={},
gold=500,
active_quests=[],
discovered_locations=["forgotten_crypt"],
current_location="forgotten_crypt"
)
# ===== LIST CHARACTERS =====
@patch('app.api.characters.get_current_user')
@patch('app.api.characters.get_character_service')
def test_list_characters_success(self, mock_get_service, mock_get_user,
client, mock_user, sample_character):
"""Test listing user's characters."""
# Setup
mock_get_user.return_value = mock_user
mock_service = Mock()
mock_service.get_user_characters.return_value = [sample_character]
mock_get_service.return_value = mock_service
# Execute
response = client.get('/api/v1/characters')
# Verify
assert response.status_code == 200
data = json.loads(response.data)
assert data['status'] == 200
assert len(data['result']['characters']) == 1
assert data['result']['characters'][0]['name'] == "Thorin Ironforge"
assert data['result']['characters'][0]['class'] == "vanguard"
assert data['result']['count'] == 1
assert data['result']['tier'] == "free"
assert data['result']['limit'] == 1
@patch('app.api.characters.get_current_user')
def test_list_characters_unauthorized(self, mock_get_user, client):
"""Test listing characters without authentication."""
# Setup - simulate no user logged in
mock_get_user.return_value = None
# Execute
response = client.get('/api/v1/characters')
# Verify
assert response.status_code == 401
# ===== GET CHARACTER =====
@patch('app.api.characters.get_current_user')
@patch('app.api.characters.get_character_service')
def test_get_character_success(self, mock_get_service, mock_get_user,
client, mock_user, sample_character):
"""Test getting a single character."""
# Setup
mock_get_user.return_value = mock_user
mock_service = Mock()
mock_service.get_character.return_value = sample_character
mock_get_service.return_value = mock_service
# Execute
response = client.get('/api/v1/characters/char_123')
# Verify
assert response.status_code == 200
data = json.loads(response.data)
assert data['status'] == 200
assert data['result']['name'] == "Thorin Ironforge"
assert data['result']['level'] == 5
assert data['result']['gold'] == 500
@patch('app.api.characters.get_current_user')
@patch('app.api.characters.get_character_service')
def test_get_character_not_found(self, mock_get_service, mock_get_user,
client, mock_user):
"""Test getting a non-existent character."""
# Setup
mock_get_user.return_value = mock_user
mock_service = Mock()
mock_service.get_character.side_effect = CharacterNotFound("Character not found")
mock_get_service.return_value = mock_service
# Execute
response = client.get('/api/v1/characters/char_999')
# Verify
assert response.status_code == 404
data = json.loads(response.data)
assert 'error' in data
# ===== CREATE CHARACTER =====
@patch('app.api.characters.get_current_user')
@patch('app.api.characters.get_character_service')
def test_create_character_success(self, mock_get_service, mock_get_user,
client, mock_user, sample_character):
"""Test creating a new character."""
# Setup
mock_get_user.return_value = mock_user
mock_service = Mock()
mock_service.create_character.return_value = sample_character
mock_get_service.return_value = mock_service
# Execute
response = client.post('/api/v1/characters',
json={
'name': 'Thorin Ironforge',
'class_id': 'vanguard',
'origin_id': 'soul_revenant'
})
# Verify
assert response.status_code == 201
data = json.loads(response.data)
assert data['status'] == 201
assert data['result']['name'] == "Thorin Ironforge"
assert data['result']['class'] == "vanguard"
assert data['result']['origin'] == "soul_revenant"
@patch('app.api.characters.get_current_user')
@patch('app.api.characters.get_character_service')
def test_create_character_limit_exceeded(self, mock_get_service, mock_get_user,
client, mock_user):
"""Test creating a character when tier limit is reached."""
# Setup
mock_get_user.return_value = mock_user
mock_service = Mock()
mock_service.create_character.side_effect = CharacterLimitExceeded(
"Character limit reached for free tier (1/1)"
)
mock_get_service.return_value = mock_service
# Execute
response = client.post('/api/v1/characters',
json={
'name': 'Thorin Ironforge',
'class_id': 'vanguard',
'origin_id': 'soul_revenant'
})
# Verify
assert response.status_code == 400
data = json.loads(response.data)
assert 'error' in data
assert data['error']['code'] == 'CHARACTER_LIMIT_EXCEEDED'
def test_create_character_validation_errors(self, client, mock_user):
"""Test character creation with invalid input."""
with patch('app.api.characters.get_current_user', return_value=mock_user):
# Test missing name
response = client.post('/api/v1/characters',
json={
'class_id': 'vanguard',
'origin_id': 'soul_revenant'
})
assert response.status_code == 400
# Test invalid class_id
response = client.post('/api/v1/characters',
json={
'name': 'Test',
'class_id': 'invalid_class',
'origin_id': 'soul_revenant'
})
assert response.status_code == 400
# Test invalid origin_id
response = client.post('/api/v1/characters',
json={
'name': 'Test',
'class_id': 'vanguard',
'origin_id': 'invalid_origin'
})
assert response.status_code == 400
# Test name too short
response = client.post('/api/v1/characters',
json={
'name': 'T',
'class_id': 'vanguard',
'origin_id': 'soul_revenant'
})
assert response.status_code == 400
# ===== DELETE CHARACTER =====
@patch('app.api.characters.get_current_user')
@patch('app.api.characters.get_character_service')
def test_delete_character_success(self, mock_get_service, mock_get_user,
client, mock_user):
"""Test deleting a character."""
# Setup
mock_get_user.return_value = mock_user
mock_service = Mock()
mock_service.delete_character.return_value = True
mock_get_service.return_value = mock_service
# Execute
response = client.delete('/api/v1/characters/char_123')
# Verify
assert response.status_code == 200
data = json.loads(response.data)
assert data['status'] == 200
assert 'deleted successfully' in data['result']['message']
@patch('app.api.characters.get_current_user')
@patch('app.api.characters.get_character_service')
def test_delete_character_not_found(self, mock_get_service, mock_get_user,
client, mock_user):
"""Test deleting a non-existent character."""
# Setup
mock_get_user.return_value = mock_user
mock_service = Mock()
mock_service.delete_character.side_effect = CharacterNotFound("Character not found")
mock_get_service.return_value = mock_service
# Execute
response = client.delete('/api/v1/characters/char_999')
# Verify
assert response.status_code == 404
# ===== UNLOCK SKILL =====
@patch('app.api.characters.get_current_user')
@patch('app.api.characters.get_character_service')
def test_unlock_skill_success(self, mock_get_service, mock_get_user,
client, mock_user, sample_character):
"""Test unlocking a skill."""
# Setup
mock_get_user.return_value = mock_user
sample_character.unlocked_skills = ["shield_wall", "toughness"]
mock_service = Mock()
mock_service.unlock_skill.return_value = sample_character
mock_get_service.return_value = mock_service
# Execute
response = client.post('/api/v1/characters/char_123/skills/unlock',
json={'skill_id': 'toughness'})
# Verify
assert response.status_code == 200
data = json.loads(response.data)
assert data['status'] == 200
assert data['result']['skill_id'] == 'toughness'
assert 'toughness' in data['result']['unlocked_skills']
@patch('app.api.characters.get_current_user')
@patch('app.api.characters.get_character_service')
def test_unlock_skill_prerequisites_not_met(self, mock_get_service, mock_get_user,
client, mock_user):
"""Test unlocking a skill without meeting prerequisites."""
# Setup
mock_get_user.return_value = mock_user
mock_service = Mock()
mock_service.unlock_skill.side_effect = SkillUnlockError(
"Prerequisite not met: shield_wall required for toughness"
)
mock_get_service.return_value = mock_service
# Execute
response = client.post('/api/v1/characters/char_123/skills/unlock',
json={'skill_id': 'toughness'})
# Verify
assert response.status_code == 400
data = json.loads(response.data)
assert 'error' in data
assert data['error']['code'] == 'SKILL_UNLOCK_ERROR'
@patch('app.api.characters.get_current_user')
@patch('app.api.characters.get_character_service')
def test_unlock_skill_no_points(self, mock_get_service, mock_get_user,
client, mock_user):
"""Test unlocking a skill without available skill points."""
# Setup
mock_get_user.return_value = mock_user
mock_service = Mock()
mock_service.unlock_skill.side_effect = SkillUnlockError(
"No skill points available (Level 1, 1 skills unlocked)"
)
mock_get_service.return_value = mock_service
# Execute
response = client.post('/api/v1/characters/char_123/skills/unlock',
json={'skill_id': 'shield_wall'})
# Verify
assert response.status_code == 400
data = json.loads(response.data)
assert 'error' in data
# ===== RESPEC SKILLS =====
@patch('app.api.characters.get_current_user')
@patch('app.api.characters.get_character_service')
def test_respec_skills_success(self, mock_get_service, mock_get_user,
client, mock_user, sample_character):
"""Test respeccing character skills."""
# Setup
mock_get_user.return_value = mock_user
# Get character returns current state
char_before = sample_character
char_before.unlocked_skills = ["shield_wall"]
char_before.gold = 500
# After respec, skills cleared and gold reduced
char_after = sample_character
char_after.unlocked_skills = []
char_after.gold = 0 # 500 - (5 * 100)
mock_service = Mock()
mock_service.get_character.return_value = char_before
mock_service.respec_skills.return_value = char_after
mock_get_service.return_value = mock_service
# Execute
response = client.post('/api/v1/characters/char_123/skills/respec')
# Verify
assert response.status_code == 200
data = json.loads(response.data)
assert data['status'] == 200
assert data['result']['cost'] == 500 # level 5 * 100
assert data['result']['remaining_gold'] == 0
assert data['result']['available_points'] == 5
@patch('app.api.characters.get_current_user')
@patch('app.api.characters.get_character_service')
def test_respec_skills_insufficient_gold(self, mock_get_service, mock_get_user,
client, mock_user, sample_character):
"""Test respeccing without enough gold."""
# Setup
mock_get_user.return_value = mock_user
sample_character.gold = 100 # Not enough for level 5 respec (needs 500)
mock_service = Mock()
mock_service.get_character.return_value = sample_character
mock_service.respec_skills.side_effect = InsufficientGold(
"Insufficient gold for respec. Cost: 500, Available: 100"
)
mock_get_service.return_value = mock_service
# Execute
response = client.post('/api/v1/characters/char_123/skills/respec')
# Verify
assert response.status_code == 400
data = json.loads(response.data)
assert 'error' in data
assert data['error']['code'] == 'INSUFFICIENT_GOLD'
# ===== CLASSES ENDPOINTS (REFERENCE DATA) =====
@patch('app.api.characters.get_class_loader')
def test_list_classes(self, mock_get_loader, client, sample_class):
"""Test listing all character classes."""
# Setup
mock_loader = Mock()
mock_loader.get_all_class_ids.return_value = ['vanguard', 'arcanist']
mock_loader.load_class.return_value = sample_class
mock_get_loader.return_value = mock_loader
# Execute
response = client.get('/api/v1/classes')
# Verify
assert response.status_code == 200
data = json.loads(response.data)
assert data['status'] == 200
assert len(data['result']['classes']) == 2
assert data['result']['count'] == 2
@patch('app.api.characters.get_class_loader')
def test_get_class_details(self, mock_get_loader, client, sample_class):
"""Test getting details of a specific class."""
# Setup
mock_loader = Mock()
mock_loader.load_class.return_value = sample_class
mock_get_loader.return_value = mock_loader
# Execute
response = client.get('/api/v1/classes/vanguard')
# Verify
assert response.status_code == 200
data = json.loads(response.data)
assert data['status'] == 200
assert data['result']['class_id'] == 'vanguard'
assert data['result']['name'] == 'Vanguard'
assert len(data['result']['skill_trees']) == 1
@patch('app.api.characters.get_class_loader')
def test_get_class_not_found(self, mock_get_loader, client):
"""Test getting a non-existent class."""
# Setup
mock_loader = Mock()
mock_loader.load_class.return_value = None
mock_get_loader.return_value = mock_loader
# Execute
response = client.get('/api/v1/classes/invalid_class')
# Verify
assert response.status_code == 404
# ===== ORIGINS ENDPOINTS (REFERENCE DATA) =====
@patch('app.api.characters.get_origin_service')
def test_list_origins(self, mock_get_service, client, sample_origin):
"""Test listing all character origins."""
# Setup
mock_service = Mock()
mock_service.get_all_origin_ids.return_value = ['soul_revenant', 'memory_thief']
mock_service.load_origin.return_value = sample_origin
mock_get_service.return_value = mock_service
# Execute
response = client.get('/api/v1/origins')
# Verify
assert response.status_code == 200
data = json.loads(response.data)
assert data['status'] == 200
assert len(data['result']['origins']) == 2
assert data['result']['count'] == 2

454
api/tests/test_character.py Normal file
View File

@@ -0,0 +1,454 @@
"""
Unit tests for Character dataclass.
Tests the critical get_effective_stats() method which combines all stat modifiers,
as well as inventory, equipment, experience, and serialization.
"""
import pytest
from app.models.character import Character
from app.models.stats import Stats
from app.models.items import Item
from app.models.effects import Effect
from app.models.skills import PlayerClass, SkillTree, SkillNode
from app.models.origins import Origin, StartingLocation, StartingBonus
from app.models.enums import ItemType, EffectType, StatType, DamageType
@pytest.fixture
def basic_player_class():
"""Create a basic player class for testing."""
base_stats = Stats(strength=12, dexterity=10, constitution=14, intelligence=8, wisdom=10, charisma=11)
# Create a simple skill tree
skill_tree = SkillTree(
tree_id="warrior_offense",
name="Warrior Offense",
description="Offensive combat skills",
nodes=[
SkillNode(
skill_id="power_strike",
name="Power Strike",
description="+5 Strength",
tier=1,
effects={"strength": 5},
),
],
)
return PlayerClass(
class_id="warrior",
name="Warrior",
description="Strong melee fighter",
base_stats=base_stats,
skill_trees=[skill_tree],
starting_equipment=["basic_sword"],
starting_abilities=["basic_attack"],
)
@pytest.fixture
def basic_origin():
"""Create a basic origin for testing."""
starting_location = StartingLocation(
id="test_location",
name="Test Village",
region="Test Region",
description="A simple test location"
)
starting_bonus = StartingBonus(
trait="Test Trait",
description="A test trait for testing",
effect="+1 to all stats"
)
return Origin(
id="test_origin",
name="Test Origin",
description="A test origin for character testing",
starting_location=starting_location,
narrative_hooks=["Test hook 1", "Test hook 2"],
starting_bonus=starting_bonus
)
@pytest.fixture
def basic_character(basic_player_class, basic_origin):
"""Create a basic character for testing."""
return Character(
character_id="char_001",
user_id="user_001",
name="Test Hero",
player_class=basic_player_class,
origin=basic_origin,
level=1,
experience=0,
base_stats=basic_player_class.base_stats.copy(),
)
def test_character_creation(basic_character):
"""Test creating a Character instance."""
assert basic_character.character_id == "char_001"
assert basic_character.user_id == "user_001"
assert basic_character.name == "Test Hero"
assert basic_character.level == 1
assert basic_character.experience == 0
assert basic_character.gold == 0
def test_get_effective_stats_base_only(basic_character):
"""Test get_effective_stats() with only base stats (no modifiers)."""
effective = basic_character.get_effective_stats()
# Should match base stats exactly
assert effective.strength == 12
assert effective.dexterity == 10
assert effective.constitution == 14
assert effective.intelligence == 8
assert effective.wisdom == 10
assert effective.charisma == 11
def test_get_effective_stats_with_equipment(basic_character):
"""Test get_effective_stats() with equipped items."""
# Create a weapon with +5 strength
weapon = Item(
item_id="iron_sword",
name="Iron Sword",
item_type=ItemType.WEAPON,
description="A sturdy iron sword",
stat_bonuses={"strength": 5},
damage=10,
)
basic_character.equipped["weapon"] = weapon
effective = basic_character.get_effective_stats()
# Strength should be base (12) + weapon (5) = 17
assert effective.strength == 17
assert effective.dexterity == 10 # Unchanged
def test_get_effective_stats_with_skill_bonuses(basic_character):
"""Test get_effective_stats() with skill tree bonuses."""
# Unlock the "power_strike" skill which gives +5 strength
basic_character.unlocked_skills.append("power_strike")
effective = basic_character.get_effective_stats()
# Strength should be base (12) + skill (5) = 17
assert effective.strength == 17
def test_get_effective_stats_with_all_modifiers(basic_character):
"""Test get_effective_stats() with equipment + skills + active effects."""
# Add equipment: +5 strength
weapon = Item(
item_id="iron_sword",
name="Iron Sword",
item_type=ItemType.WEAPON,
description="A sturdy iron sword",
stat_bonuses={"strength": 5},
damage=10,
)
basic_character.equipped["weapon"] = weapon
# Unlock skill: +5 strength
basic_character.unlocked_skills.append("power_strike")
# Add buff effect: +3 strength
buff = Effect(
effect_id="str_buff",
name="Strength Boost",
effect_type=EffectType.BUFF,
duration=3,
power=3,
stat_affected=StatType.STRENGTH,
)
effective = basic_character.get_effective_stats([buff])
# Total strength: 12 (base) + 5 (weapon) + 5 (skill) + 3 (buff) = 25
assert effective.strength == 25
def test_get_effective_stats_with_debuff(basic_character):
"""Test get_effective_stats() with a debuff."""
debuff = Effect(
effect_id="weakened",
name="Weakened",
effect_type=EffectType.DEBUFF,
duration=2,
power=5,
stat_affected=StatType.STRENGTH,
)
effective = basic_character.get_effective_stats([debuff])
# Strength: 12 (base) - 5 (debuff) = 7
assert effective.strength == 7
def test_get_effective_stats_debuff_minimum(basic_character):
"""Test that debuffs cannot reduce stats below 1."""
# Massive debuff
debuff = Effect(
effect_id="weakened",
name="Weakened",
effect_type=EffectType.DEBUFF,
duration=2,
power=20, # More than base strength
stat_affected=StatType.STRENGTH,
)
effective = basic_character.get_effective_stats([debuff])
# Should be clamped at 1, not 0 or negative
assert effective.strength == 1
def test_inventory_management(basic_character):
"""Test adding and removing items from inventory."""
item = Item(
item_id="potion",
name="Health Potion",
item_type=ItemType.CONSUMABLE,
description="Restores 50 HP",
value=25,
)
# Add item
basic_character.add_item(item)
assert len(basic_character.inventory) == 1
assert basic_character.inventory[0].item_id == "potion"
# Remove item
removed = basic_character.remove_item("potion")
assert removed.item_id == "potion"
assert len(basic_character.inventory) == 0
def test_equip_and_unequip(basic_character):
"""Test equipping and unequipping items."""
weapon = Item(
item_id="iron_sword",
name="Iron Sword",
item_type=ItemType.WEAPON,
description="A sturdy sword",
stat_bonuses={"strength": 5},
damage=10,
)
# Add to inventory first
basic_character.add_item(weapon)
assert len(basic_character.inventory) == 1
# Equip weapon
basic_character.equip_item(weapon, "weapon")
assert "weapon" in basic_character.equipped
assert basic_character.equipped["weapon"].item_id == "iron_sword"
assert len(basic_character.inventory) == 0 # Removed from inventory
# Unequip weapon
unequipped = basic_character.unequip_item("weapon")
assert unequipped.item_id == "iron_sword"
assert "weapon" not in basic_character.equipped
assert len(basic_character.inventory) == 1 # Back in inventory
def test_equip_replaces_existing(basic_character):
"""Test that equipping a new item replaces the old one."""
weapon1 = Item(
item_id="iron_sword",
name="Iron Sword",
item_type=ItemType.WEAPON,
description="A sturdy sword",
damage=10,
)
weapon2 = Item(
item_id="steel_sword",
name="Steel Sword",
item_type=ItemType.WEAPON,
description="A better sword",
damage=15,
)
# Equip first weapon
basic_character.add_item(weapon1)
basic_character.equip_item(weapon1, "weapon")
# Equip second weapon
basic_character.add_item(weapon2)
previous = basic_character.equip_item(weapon2, "weapon")
assert previous.item_id == "iron_sword" # Old weapon returned
assert basic_character.equipped["weapon"].item_id == "steel_sword"
assert len(basic_character.inventory) == 1 # Old weapon back in inventory
def test_gold_management(basic_character):
"""Test adding and removing gold."""
assert basic_character.gold == 0
# Add gold
basic_character.add_gold(100)
assert basic_character.gold == 100
# Remove gold
success = basic_character.remove_gold(50)
assert success == True
assert basic_character.gold == 50
# Try to remove more than available
success = basic_character.remove_gold(100)
assert success == False
assert basic_character.gold == 50 # Unchanged
def test_can_afford(basic_character):
"""Test can_afford() method."""
basic_character.gold = 100
assert basic_character.can_afford(50) == True
assert basic_character.can_afford(100) == True
assert basic_character.can_afford(101) == False
def test_add_experience_no_level_up(basic_character):
"""Test adding experience without leveling up."""
leveled_up = basic_character.add_experience(50)
assert leveled_up == False
assert basic_character.level == 1
assert basic_character.experience == 50
def test_add_experience_with_level_up(basic_character):
"""Test adding enough experience to level up."""
# Level 1 requires 100 XP for level 2
leveled_up = basic_character.add_experience(100)
assert leveled_up == True
assert basic_character.level == 2
assert basic_character.experience == 0 # Reset
def test_add_experience_with_overflow(basic_character):
"""Test leveling up with overflow experience."""
# Level 1 requires 100 XP, give 150
leveled_up = basic_character.add_experience(150)
assert leveled_up == True
assert basic_character.level == 2
assert basic_character.experience == 50 # Overflow
def test_xp_calculation(basic_origin):
"""Test XP required for each level."""
char = Character(
character_id="test",
user_id="user",
name="Test",
player_class=PlayerClass(
class_id="test",
name="Test",
description="Test",
base_stats=Stats(),
),
origin=basic_origin,
)
# Formula: 100 * (level ^ 1.5)
assert char._calculate_xp_for_next_level() == 100 # Level 1→2
char.level = 2
assert char._calculate_xp_for_next_level() == 282 # Level 2→3
char.level = 3
assert char._calculate_xp_for_next_level() == 519 # Level 3→4
def test_get_unlocked_abilities(basic_character):
"""Test getting abilities from class + unlocked skills."""
# Should have starting abilities
abilities = basic_character.get_unlocked_abilities()
assert "basic_attack" in abilities
# TODO: When skills unlock abilities, test that here
def test_character_serialization(basic_character):
"""Test character to_dict() serialization."""
# Add some data
basic_character.gold = 500
basic_character.level = 3
basic_character.experience = 100
data = basic_character.to_dict()
assert data["character_id"] == "char_001"
assert data["user_id"] == "user_001"
assert data["name"] == "Test Hero"
assert data["level"] == 3
assert data["experience"] == 100
assert data["gold"] == 500
def test_character_deserialization(basic_player_class, basic_origin):
"""Test character from_dict() deserialization."""
data = {
"character_id": "char_002",
"user_id": "user_002",
"name": "Restored Hero",
"player_class": basic_player_class.to_dict(),
"origin": basic_origin.to_dict(),
"level": 5,
"experience": 200,
"base_stats": Stats(strength=15).to_dict(),
"unlocked_skills": ["power_strike"],
"inventory": [],
"equipped": {},
"gold": 1000,
"active_quests": ["quest_1"],
"discovered_locations": ["town_1"],
}
char = Character.from_dict(data)
assert char.character_id == "char_002"
assert char.name == "Restored Hero"
assert char.level == 5
assert char.gold == 1000
assert "power_strike" in char.unlocked_skills
def test_character_round_trip_serialization(basic_character):
"""Test that serialization and deserialization preserve all data."""
# Add complex state
basic_character.gold = 500
basic_character.level = 3
basic_character.unlocked_skills = ["power_strike"]
weapon = Item(
item_id="sword",
name="Sword",
item_type=ItemType.WEAPON,
description="A sword",
damage=10,
)
basic_character.equipped["weapon"] = weapon
# Serialize and deserialize
data = basic_character.to_dict()
restored = Character.from_dict(data)
assert restored.character_id == basic_character.character_id
assert restored.name == basic_character.name
assert restored.level == basic_character.level
assert restored.gold == basic_character.gold
assert restored.unlocked_skills == basic_character.unlocked_skills
assert "weapon" in restored.equipped
assert restored.equipped["weapon"].item_id == "sword"

View File

@@ -0,0 +1,547 @@
"""
Unit tests for CharacterService - character CRUD operations and tier limits.
These tests verify character creation, retrieval, deletion, skill unlock,
and respec functionality with proper validation and error handling.
"""
import pytest
import json
from unittest.mock import Mock, MagicMock, patch
from datetime import datetime, timezone
from app.services.character_service import (
CharacterService,
CharacterLimitExceeded,
CharacterNotFound,
SkillUnlockError,
InsufficientGold,
CHARACTER_LIMITS
)
from app.models.character import Character
from app.models.stats import Stats
from app.models.skills import PlayerClass, SkillTree, SkillNode
from app.models.origins import Origin, StartingLocation, StartingBonus
from app.services.database_service import DatabaseDocument
class TestCharacterService:
"""Test suite for CharacterService."""
@pytest.fixture
def mock_db(self):
"""Mock database service."""
return Mock()
@pytest.fixture
def mock_appwrite(self):
"""Mock Appwrite service."""
return Mock()
@pytest.fixture
def mock_class_loader(self):
"""Mock class loader."""
return Mock()
@pytest.fixture
def mock_origin_service(self):
"""Mock origin service."""
return Mock()
@pytest.fixture
def character_service(self, mock_db, mock_appwrite, mock_class_loader, mock_origin_service):
"""Create CharacterService with mocked dependencies."""
# Patch the singleton getters before instantiation
with patch('app.services.character_service.get_database_service', return_value=mock_db), \
patch('app.services.character_service.AppwriteService', return_value=mock_appwrite), \
patch('app.services.character_service.get_class_loader', return_value=mock_class_loader), \
patch('app.services.character_service.get_origin_service', return_value=mock_origin_service):
service = CharacterService()
# Ensure mocks are still assigned
service.db = mock_db
service.appwrite = mock_appwrite
service.class_loader = mock_class_loader
service.origin_service = mock_origin_service
return service
@pytest.fixture
def sample_class(self):
"""Create a sample player class."""
base_stats = Stats(strength=12, dexterity=10, constitution=14)
skill_tree = SkillTree(
tree_id="warrior_offense",
name="Warrior Offense",
description="Offensive skills",
nodes=[
SkillNode(
skill_id="power_strike",
name="Power Strike",
description="+5 Strength",
tier=1,
effects={"strength": 5},
),
SkillNode(
skill_id="heavy_blow",
name="Heavy Blow",
description="+10 Strength",
tier=2,
prerequisites=["power_strike"],
effects={"strength": 10},
),
],
)
return PlayerClass(
class_id="warrior",
name="Warrior",
description="Strong fighter",
base_stats=base_stats,
skill_trees=[skill_tree],
starting_equipment=["basic_sword"],
starting_abilities=["basic_attack"],
)
@pytest.fixture
def sample_origin(self):
"""Create a sample origin."""
starting_location = StartingLocation(
id="test_crypt",
name="Test Crypt",
region="Test Region",
description="A test location"
)
starting_bonus = StartingBonus(
trait="Test Trait",
description="Test bonus",
effect="+1 to all stats"
)
return Origin(
id="test_origin",
name="Test Origin",
description="A test origin",
starting_location=starting_location,
narrative_hooks=["Test hook"],
starting_bonus=starting_bonus
)
def test_create_character_success(
self, character_service, mock_appwrite, mock_class_loader,
mock_origin_service, mock_db, sample_class, sample_origin
):
"""Test successful character creation."""
# Setup mocks
mock_appwrite.get_user_tier.return_value = 'free'
mock_class_loader.load_class.return_value = sample_class
mock_origin_service.load_origin.return_value = sample_origin
# Mock count_user_characters to return 0
character_service.count_user_characters = Mock(return_value=0)
# Create character
with patch('app.services.character_service.ID') as mock_id:
mock_id.unique.return_value = 'char_123'
character = character_service.create_character(
user_id='user_001',
name='Test Hero',
class_id='warrior',
origin_id='test_origin'
)
# Assertions
assert character.character_id == 'char_123'
assert character.user_id == 'user_001'
assert character.name == 'Test Hero'
assert character.level == 1
assert character.experience == 0
assert character.gold == 0
assert character.current_location == 'test_crypt'
# Verify database was called
mock_db.create_document.assert_called_once()
def test_create_character_exceeds_limit(
self, character_service, mock_appwrite, mock_class_loader, mock_origin_service
):
"""Test character creation fails when tier limit exceeded."""
# Setup: user on free tier (limit 1) with 1 existing character
mock_appwrite.get_user_tier.return_value = 'free'
character_service.count_user_characters = Mock(return_value=1)
# Attempt to create second character
with pytest.raises(CharacterLimitExceeded) as exc_info:
character_service.create_character(
user_id='user_001',
name='Second Hero',
class_id='warrior',
origin_id='test_origin'
)
assert 'free tier' in str(exc_info.value)
assert '1/1' in str(exc_info.value)
def test_create_character_tier_limits(self):
"""Test that tier limits are correctly defined."""
assert CHARACTER_LIMITS['free'] == 1
assert CHARACTER_LIMITS['basic'] == 3
assert CHARACTER_LIMITS['premium'] == 5
assert CHARACTER_LIMITS['elite'] == 10
def test_create_character_invalid_class(
self, character_service, mock_appwrite, mock_class_loader, mock_origin_service
):
"""Test character creation fails with invalid class."""
mock_appwrite.get_user_tier.return_value = 'free'
character_service.count_user_characters = Mock(return_value=0)
mock_class_loader.load_class.return_value = None
with pytest.raises(ValueError) as exc_info:
character_service.create_character(
user_id='user_001',
name='Test Hero',
class_id='invalid_class',
origin_id='test_origin'
)
assert 'Class not found' in str(exc_info.value)
def test_create_character_invalid_origin(
self, character_service, mock_appwrite, mock_class_loader,
mock_origin_service, sample_class
):
"""Test character creation fails with invalid origin."""
mock_appwrite.get_user_tier.return_value = 'free'
character_service.count_user_characters = Mock(return_value=0)
mock_class_loader.load_class.return_value = sample_class
mock_origin_service.load_origin.return_value = None
with pytest.raises(ValueError) as exc_info:
character_service.create_character(
user_id='user_001',
name='Test Hero',
class_id='warrior',
origin_id='invalid_origin'
)
assert 'Origin not found' in str(exc_info.value)
def test_get_character_success(self, character_service, mock_db, sample_class, sample_origin):
"""Test successfully retrieving a character."""
# Create test character data
character = Character(
character_id='char_123',
user_id='user_001',
name='Test Hero',
player_class=sample_class,
origin=sample_origin,
level=5,
gold=500
)
# Mock database response
mock_doc = Mock(spec=DatabaseDocument)
mock_doc.id = 'char_123'
mock_doc.data = {
'userId': 'user_001',
'characterData': json.dumps(character.to_dict()),
'is_active': True
}
mock_db.get_document.return_value = mock_doc
# Get character
result = character_service.get_character('char_123', 'user_001')
# Assertions
assert result.character_id == 'char_123'
assert result.name == 'Test Hero'
assert result.level == 5
assert result.gold == 500
def test_get_character_not_found(self, character_service, mock_db):
"""Test getting non-existent character raises error."""
mock_db.get_document.return_value = None
with pytest.raises(CharacterNotFound):
character_service.get_character('nonexistent', 'user_001')
def test_get_character_wrong_owner(self, character_service, mock_db):
"""Test getting character owned by different user raises error."""
mock_doc = Mock(spec=DatabaseDocument)
mock_doc.data = {'userId': 'user_002'} # Different user
mock_db.get_document.return_value = mock_doc
with pytest.raises(CharacterNotFound):
character_service.get_character('char_123', 'user_001')
def test_get_user_characters(self, character_service, mock_db, sample_class, sample_origin):
"""Test getting all characters for a user."""
# Create test character data
char1_data = Character(
character_id='char_1',
user_id='user_001',
name='Hero 1',
player_class=sample_class,
origin=sample_origin
).to_dict()
char2_data = Character(
character_id='char_2',
user_id='user_001',
name='Hero 2',
player_class=sample_class,
origin=sample_origin
).to_dict()
# Mock database response
mock_doc1 = Mock(spec=DatabaseDocument)
mock_doc1.id = 'char_1'
mock_doc1.data = {'characterData': json.dumps(char1_data)}
mock_doc2 = Mock(spec=DatabaseDocument)
mock_doc2.id = 'char_2'
mock_doc2.data = {'characterData': json.dumps(char2_data)}
mock_db.list_rows.return_value = [mock_doc1, mock_doc2]
# Get characters
characters = character_service.get_user_characters('user_001')
# Assertions
assert len(characters) == 2
assert characters[0].name == 'Hero 1'
assert characters[1].name == 'Hero 2'
def test_count_user_characters(self, character_service, mock_db):
"""Test counting user's characters."""
mock_db.count_documents.return_value = 3
count = character_service.count_user_characters('user_001')
assert count == 3
mock_db.count_documents.assert_called_once()
def test_delete_character_success(self, character_service, mock_db):
"""Test successfully deleting a character."""
# Mock get_character to return a valid character
character_service.get_character = Mock(return_value=Mock(character_id='char_123'))
# Delete character
result = character_service.delete_character('char_123', 'user_001')
# Assertions
assert result is True
mock_db.update_document.assert_called_once()
# Verify it's a soft delete (is_active set to False)
call_args = mock_db.update_document.call_args
assert call_args[1]['data']['is_active'] is False
def test_delete_character_not_found(self, character_service):
"""Test deleting non-existent character raises error."""
character_service.get_character = Mock(side_effect=CharacterNotFound("Not found"))
with pytest.raises(CharacterNotFound):
character_service.delete_character('nonexistent', 'user_001')
def test_unlock_skill_success(self, character_service, sample_class, sample_origin):
"""Test successfully unlocking a skill."""
# Create character with level 2 (1 skill point available)
character = Character(
character_id='char_123',
user_id='user_001',
name='Test Hero',
player_class=sample_class,
origin=sample_origin,
level=2,
unlocked_skills=[]
)
# Mock get_character
character_service.get_character = Mock(return_value=character)
# Mock _save_character
character_service._save_character = Mock()
# Unlock skill
result = character_service.unlock_skill('char_123', 'user_001', 'power_strike')
# Assertions
assert 'power_strike' in result.unlocked_skills
character_service._save_character.assert_called_once()
def test_unlock_skill_already_unlocked(self, character_service, sample_class, sample_origin):
"""Test unlocking already unlocked skill raises error."""
character = Character(
character_id='char_123',
user_id='user_001',
name='Test Hero',
player_class=sample_class,
origin=sample_origin,
level=2,
unlocked_skills=['power_strike'] # Already unlocked
)
character_service.get_character = Mock(return_value=character)
with pytest.raises(SkillUnlockError) as exc_info:
character_service.unlock_skill('char_123', 'user_001', 'power_strike')
assert 'already unlocked' in str(exc_info.value)
def test_unlock_skill_not_in_class(self, character_service, sample_class, sample_origin):
"""Test unlocking skill not in class raises error."""
character = Character(
character_id='char_123',
user_id='user_001',
name='Test Hero',
player_class=sample_class,
origin=sample_origin,
level=2,
unlocked_skills=[]
)
character_service.get_character = Mock(return_value=character)
with pytest.raises(SkillUnlockError) as exc_info:
character_service.unlock_skill('char_123', 'user_001', 'invalid_skill')
assert 'not found in class' in str(exc_info.value)
def test_unlock_skill_missing_prerequisite(self, character_service, sample_class, sample_origin):
"""Test unlocking skill without prerequisite raises error."""
character = Character(
character_id='char_123',
user_id='user_001',
name='Test Hero',
player_class=sample_class,
origin=sample_origin,
level=2,
unlocked_skills=[] # Missing 'power_strike' prerequisite
)
character_service.get_character = Mock(return_value=character)
with pytest.raises(SkillUnlockError) as exc_info:
character_service.unlock_skill('char_123', 'user_001', 'heavy_blow')
assert 'Prerequisite not met' in str(exc_info.value)
assert 'power_strike' in str(exc_info.value)
def test_unlock_skill_no_points_available(self, character_service, sample_class, sample_origin):
"""Test unlocking skill without available points raises error."""
# Add a tier 3 skill to test with
tier3_skill = SkillNode(
skill_id="master_strike",
name="Master Strike",
description="+15 Strength",
tier=3,
effects={"strength": 15},
)
sample_class.skill_trees[0].nodes.append(tier3_skill)
# Level 1 character with 1 skill already unlocked = 0 points remaining
character = Character(
character_id='char_123',
user_id='user_001',
name='Test Hero',
player_class=sample_class,
origin=sample_origin,
level=1,
unlocked_skills=['power_strike'] # Used the 1 point from level 1
)
character_service.get_character = Mock(return_value=character)
# Try to unlock another skill (master_strike exists in class)
with pytest.raises(SkillUnlockError) as exc_info:
character_service.unlock_skill('char_123', 'user_001', 'master_strike')
assert 'No skill points available' in str(exc_info.value)
def test_respec_skills_success(self, character_service, sample_class, sample_origin):
"""Test successfully respecing character skills."""
# Level 5 character with 3 skills and 500 gold
# Respec cost = 5 * 100 = 500 gold
character = Character(
character_id='char_123',
user_id='user_001',
name='Test Hero',
player_class=sample_class,
origin=sample_origin,
level=5,
gold=500,
unlocked_skills=['skill1', 'skill2', 'skill3']
)
character_service.get_character = Mock(return_value=character)
character_service._save_character = Mock()
# Respec skills
result = character_service.respec_skills('char_123', 'user_001')
# Assertions
assert len(result.unlocked_skills) == 0 # Skills cleared
assert result.gold == 0 # 500 - 500 = 0
character_service._save_character.assert_called_once()
def test_respec_skills_insufficient_gold(self, character_service, sample_class, sample_origin):
"""Test respec fails with insufficient gold."""
# Level 5 character with only 100 gold (needs 500)
character = Character(
character_id='char_123',
user_id='user_001',
name='Test Hero',
player_class=sample_class,
origin=sample_origin,
level=5,
gold=100, # Not enough
unlocked_skills=['skill1']
)
character_service.get_character = Mock(return_value=character)
with pytest.raises(InsufficientGold) as exc_info:
character_service.respec_skills('char_123', 'user_001')
assert '500' in str(exc_info.value) # Cost
assert '100' in str(exc_info.value) # Available
def test_update_character_success(self, character_service, sample_class, sample_origin):
"""Test successfully updating a character."""
character = Character(
character_id='char_123',
user_id='user_001',
name='Test Hero',
player_class=sample_class,
origin=sample_origin,
gold=1000 # Updated gold
)
# Mock get_character to verify ownership
character_service.get_character = Mock(return_value=character)
character_service._save_character = Mock()
# Update character
result = character_service.update_character(character, 'user_001')
# Assertions
assert result.gold == 1000
character_service._save_character.assert_called_once()
def test_update_character_not_found(self, character_service, sample_class, sample_origin):
"""Test updating non-existent character raises error."""
character = Character(
character_id='nonexistent',
user_id='user_001',
name='Test Hero',
player_class=sample_class,
origin=sample_origin
)
character_service.get_character = Mock(side_effect=CharacterNotFound("Not found"))
with pytest.raises(CharacterNotFound):
character_service.update_character(character, 'user_001')

View File

@@ -0,0 +1,256 @@
"""
Unit tests for ClassLoader service.
Tests loading player class definitions from YAML files.
"""
import pytest
from pathlib import Path
from app.services.class_loader import ClassLoader
from app.models.skills import PlayerClass, SkillTree, SkillNode
from app.models.stats import Stats
class TestClassLoader:
"""Test ClassLoader functionality."""
@pytest.fixture
def loader(self):
"""Create a ClassLoader instance for testing."""
return ClassLoader()
def test_load_vanguard_class(self, loader):
"""Test loading the Vanguard class."""
vanguard = loader.load_class("vanguard")
# Verify class loaded successfully
assert vanguard is not None
assert isinstance(vanguard, PlayerClass)
# Check basic properties
assert vanguard.class_id == "vanguard"
assert vanguard.name == "Vanguard"
assert "melee combat" in vanguard.description.lower()
def test_vanguard_base_stats(self, loader):
"""Test Vanguard base stats are correct."""
vanguard = loader.load_class("vanguard")
assert isinstance(vanguard.base_stats, Stats)
assert vanguard.base_stats.strength == 14
assert vanguard.base_stats.dexterity == 10
assert vanguard.base_stats.constitution == 14
assert vanguard.base_stats.intelligence == 8
assert vanguard.base_stats.wisdom == 10
assert vanguard.base_stats.charisma == 9
def test_vanguard_skill_trees(self, loader):
"""Test Vanguard has 2 skill trees."""
vanguard = loader.load_class("vanguard")
# Should have exactly 2 skill trees
assert len(vanguard.skill_trees) == 2
# Get trees by ID
shield_bearer = vanguard.get_skill_tree("shield_bearer")
weapon_master = vanguard.get_skill_tree("weapon_master")
assert shield_bearer is not None
assert weapon_master is not None
# Check tree names
assert shield_bearer.name == "Shield Bearer"
assert weapon_master.name == "Weapon Master"
def test_shield_bearer_tree_structure(self, loader):
"""Test Shield Bearer tree has correct structure."""
vanguard = loader.load_class("vanguard")
shield_bearer = vanguard.get_skill_tree("shield_bearer")
# Should have 10 nodes (5 tiers × 2 nodes)
assert len(shield_bearer.nodes) == 10
# Check tier distribution
tier_counts = {}
for node in shield_bearer.nodes:
tier_counts[node.tier] = tier_counts.get(node.tier, 0) + 1
# Should have 2 nodes per tier for tiers 1-5
assert tier_counts == {1: 2, 2: 2, 3: 2, 4: 2, 5: 2}
def test_weapon_master_tree_structure(self, loader):
"""Test Weapon Master tree has correct structure."""
vanguard = loader.load_class("vanguard")
weapon_master = vanguard.get_skill_tree("weapon_master")
# Should have 10 nodes (5 tiers × 2 nodes)
assert len(weapon_master.nodes) == 10
# Check tier distribution
tier_counts = {}
for node in weapon_master.nodes:
tier_counts[node.tier] = tier_counts.get(node.tier, 0) + 1
# Should have 2 nodes per tier for tiers 1-5
assert tier_counts == {1: 2, 2: 2, 3: 2, 4: 2, 5: 2}
def test_skill_node_prerequisites(self, loader):
"""Test skill nodes have correct prerequisites."""
vanguard = loader.load_class("vanguard")
shield_bearer = vanguard.get_skill_tree("shield_bearer")
# Find tier 1 and tier 2 nodes
tier1_nodes = [n for n in shield_bearer.nodes if n.tier == 1]
tier2_nodes = [n for n in shield_bearer.nodes if n.tier == 2]
# Tier 1 nodes should have no prerequisites
for node in tier1_nodes:
assert len(node.prerequisites) == 0
# Tier 2 nodes should have prerequisites
for node in tier2_nodes:
assert len(node.prerequisites) > 0
# Prerequisites should reference tier 1 skills
for prereq_id in node.prerequisites:
prereq_found = any(n.skill_id == prereq_id for n in tier1_nodes)
assert prereq_found, f"Prerequisite {prereq_id} not found in tier 1"
def test_skill_node_effects(self, loader):
"""Test skill nodes have proper effects defined."""
vanguard = loader.load_class("vanguard")
shield_bearer = vanguard.get_skill_tree("shield_bearer")
# Find the "fortify" skill (passive defense bonus)
fortify = next((n for n in shield_bearer.nodes if n.skill_id == "fortify"), None)
assert fortify is not None
# Should have stat bonuses in effects
assert "stat_bonuses" in fortify.effects
assert "defense" in fortify.effects["stat_bonuses"]
assert fortify.effects["stat_bonuses"]["defense"] == 5
def test_skill_node_abilities(self, loader):
"""Test skill nodes with ability unlocks."""
vanguard = loader.load_class("vanguard")
shield_bearer = vanguard.get_skill_tree("shield_bearer")
# Find the "shield_bash" skill (active ability)
shield_bash = next((n for n in shield_bearer.nodes if n.skill_id == "shield_bash"), None)
assert shield_bash is not None
# Should have abilities in effects
assert "abilities" in shield_bash.effects
assert "shield_bash" in shield_bash.effects["abilities"]
def test_starting_equipment(self, loader):
"""Test Vanguard starting equipment."""
vanguard = loader.load_class("vanguard")
# Should have starting equipment
assert len(vanguard.starting_equipment) > 0
assert "rusty_sword" in vanguard.starting_equipment
assert "cloth_armor" in vanguard.starting_equipment
assert "rusty_knife" in vanguard.starting_equipment
def test_starting_abilities(self, loader):
"""Test Vanguard starting abilities."""
vanguard = loader.load_class("vanguard")
# Should have basic_attack
assert len(vanguard.starting_abilities) > 0
assert "basic_attack" in vanguard.starting_abilities
def test_cache_functionality(self, loader):
"""Test that classes are cached after first load."""
# Load class twice
vanguard1 = loader.load_class("vanguard")
vanguard2 = loader.load_class("vanguard")
# Should be the same object (cached)
assert vanguard1 is vanguard2
def test_reload_class(self, loader):
"""Test reload_class forces a fresh load."""
# Load class
vanguard1 = loader.load_class("vanguard")
# Reload class
vanguard2 = loader.reload_class("vanguard")
# Should still be equal but different objects
assert vanguard1.class_id == vanguard2.class_id
# Note: May or may not be same object depending on implementation
def test_load_nonexistent_class(self, loader):
"""Test loading a class that doesn't exist."""
result = loader.load_class("nonexistent_class")
assert result is None
def test_get_all_class_ids(self, loader):
"""Test getting list of all class IDs."""
class_ids = loader.get_all_class_ids()
# Should include vanguard
assert "vanguard" in class_ids
# Should be a list of strings
assert all(isinstance(cid, str) for cid in class_ids)
def test_load_all_classes(self, loader):
"""Test loading all classes at once."""
classes = loader.load_all_classes()
# Should have at least 1 class (vanguard)
assert len(classes) >= 1
# All should be PlayerClass instances
assert all(isinstance(c, PlayerClass) for c in classes)
# Vanguard should be in the list
vanguard_found = any(c.class_id == "vanguard" for c in classes)
assert vanguard_found
def test_get_all_skills_method(self, loader):
"""Test PlayerClass.get_all_skills() method."""
vanguard = loader.load_class("vanguard")
all_skills = vanguard.get_all_skills()
# Should have 20 skills (2 trees × 10 nodes)
assert len(all_skills) == 20
# All should be SkillNode instances
assert all(isinstance(skill, SkillNode) for skill in all_skills)
# Should include skills from both trees
skill_ids = [s.skill_id for s in all_skills]
assert "shield_bash" in skill_ids # Shield Bearer tree
assert "power_strike" in skill_ids # Weapon Master tree
def test_tier_5_ultimate_skills(self, loader):
"""Test that tier 5 skills exist and have powerful effects."""
vanguard = loader.load_class("vanguard")
all_skills = vanguard.get_all_skills()
# Get tier 5 skills
tier5_skills = [s for s in all_skills if s.tier == 5]
# Should have 4 tier 5 skills (2 per tree)
assert len(tier5_skills) == 4
# Each tier 5 skill should have prerequisites
for skill in tier5_skills:
assert len(skill.prerequisites) > 0
# At least one tier 5 skill should have significant stat bonuses
has_major_bonuses = any(
"stat_bonuses" in s.effects and
any(v >= 10 for v in s.effects.get("stat_bonuses", {}).values())
for s in tier5_skills
)
assert has_major_bonuses
if __name__ == "__main__":
pytest.main([__file__, "-v"])

View File

@@ -0,0 +1,509 @@
"""
Integration tests for combat system.
Tests the complete combat flow including damage calculation, effects, and turn order.
"""
import pytest
from app.models.stats import Stats
from app.models.items import Item
from app.models.effects import Effect
from app.models.abilities import Ability
from app.models.combat import Combatant, CombatEncounter
from app.models.character import Character
from app.models.skills import PlayerClass
from app.models.enums import (
ItemType,
DamageType,
EffectType,
StatType,
AbilityType,
CombatStatus,
)
@pytest.fixture
def warrior_combatant():
"""Create a warrior combatant for testing."""
stats = Stats(strength=15, dexterity=10, constitution=14, intelligence=8, wisdom=10, charisma=11)
return Combatant(
combatant_id="warrior_1",
name="Test Warrior",
is_player=True,
current_hp=stats.hit_points,
max_hp=stats.hit_points,
current_mp=stats.mana_points,
max_mp=stats.mana_points,
stats=stats,
abilities=["basic_attack", "power_strike"],
)
@pytest.fixture
def goblin_combatant():
"""Create a goblin enemy for testing."""
stats = Stats(strength=8, dexterity=12, constitution=10, intelligence=6, wisdom=8, charisma=6)
return Combatant(
combatant_id="goblin_1",
name="Goblin",
is_player=False,
current_hp=stats.hit_points,
max_hp=stats.hit_points,
current_mp=stats.mana_points,
max_mp=stats.mana_points,
stats=stats,
abilities=["basic_attack"],
)
def test_combatant_creation(warrior_combatant):
"""Test creating a Combatant."""
assert warrior_combatant.combatant_id == "warrior_1"
assert warrior_combatant.name == "Test Warrior"
assert warrior_combatant.is_player == True
assert warrior_combatant.is_alive() == True
assert warrior_combatant.is_stunned() == False
def test_combatant_take_damage(warrior_combatant):
"""Test taking damage."""
initial_hp = warrior_combatant.current_hp
damage_dealt = warrior_combatant.take_damage(10)
assert damage_dealt == 10
assert warrior_combatant.current_hp == initial_hp - 10
def test_combatant_take_damage_with_shield(warrior_combatant):
"""Test taking damage with shield absorption."""
# Add a shield effect
shield = Effect(
effect_id="shield_1",
name="Shield",
effect_type=EffectType.SHIELD,
duration=3,
power=15,
)
warrior_combatant.add_effect(shield)
initial_hp = warrior_combatant.current_hp
# Deal 10 damage - should be fully absorbed by shield
damage_dealt = warrior_combatant.take_damage(10)
assert damage_dealt == 0 # No HP damage
assert warrior_combatant.current_hp == initial_hp
def test_combatant_death(warrior_combatant):
"""Test combatant death."""
assert warrior_combatant.is_alive() == True
# Deal massive damage
warrior_combatant.take_damage(1000)
assert warrior_combatant.is_alive() == False
assert warrior_combatant.is_dead() == True
def test_combatant_healing(warrior_combatant):
"""Test healing."""
# Take some damage first
warrior_combatant.take_damage(20)
damaged_hp = warrior_combatant.current_hp
# Heal
healed = warrior_combatant.heal(10)
assert healed == 10
assert warrior_combatant.current_hp == damaged_hp + 10
def test_combatant_healing_capped_at_max(warrior_combatant):
"""Test that healing cannot exceed max HP."""
max_hp = warrior_combatant.max_hp
# Try to heal beyond max
healed = warrior_combatant.heal(1000)
assert warrior_combatant.current_hp == max_hp
def test_combatant_stun_effect(warrior_combatant):
"""Test stun effect prevents actions."""
assert warrior_combatant.is_stunned() == False
# Add stun effect
stun = Effect(
effect_id="stun_1",
name="Stunned",
effect_type=EffectType.STUN,
duration=1,
power=0,
)
warrior_combatant.add_effect(stun)
assert warrior_combatant.is_stunned() == True
def test_combatant_tick_effects(warrior_combatant):
"""Test that ticking effects deals damage/healing."""
# Add a DOT effect
poison = Effect(
effect_id="poison_1",
name="Poison",
effect_type=EffectType.DOT,
duration=3,
power=5,
)
warrior_combatant.add_effect(poison)
initial_hp = warrior_combatant.current_hp
# Tick effects
results = warrior_combatant.tick_effects()
# Should have taken 5 poison damage
assert len(results) == 1
assert results[0]["effect_type"] == "dot"
assert results[0]["value"] == 5
assert warrior_combatant.current_hp == initial_hp - 5
def test_combatant_effect_expiration(warrior_combatant):
"""Test that expired effects are removed."""
# Add effect with 1 turn duration
dot = Effect(
effect_id="burn_1",
name="Burning",
effect_type=EffectType.DOT,
duration=1,
power=5,
)
warrior_combatant.add_effect(dot)
assert len(warrior_combatant.active_effects) == 1
# Tick - effect should expire
results = warrior_combatant.tick_effects()
assert results[0]["expired"] == True
assert len(warrior_combatant.active_effects) == 0 # Removed
def test_ability_mana_cost(warrior_combatant):
"""Test ability mana cost and usage."""
ability = Ability(
ability_id="fireball",
name="Fireball",
description="Fiery explosion",
ability_type=AbilityType.SPELL,
base_power=30,
damage_type=DamageType.FIRE,
mana_cost=15,
)
initial_mp = warrior_combatant.current_mp
# Check if can use
assert warrior_combatant.can_use_ability("fireball", ability) == False # Not in ability list
warrior_combatant.abilities.append("fireball")
assert warrior_combatant.can_use_ability("fireball", ability) == True
# Use ability
warrior_combatant.use_ability_cost(ability, "fireball")
assert warrior_combatant.current_mp == initial_mp - 15
def test_ability_cooldown(warrior_combatant):
"""Test ability cooldowns."""
ability = Ability(
ability_id="power_strike",
name="Power Strike",
description="Powerful attack",
ability_type=AbilityType.SKILL,
base_power=20,
cooldown=3,
)
warrior_combatant.abilities.append("power_strike")
# Can use initially
assert warrior_combatant.can_use_ability("power_strike", ability) == True
# Use ability
warrior_combatant.use_ability_cost(ability, "power_strike")
# Now on cooldown
assert "power_strike" in warrior_combatant.cooldowns
assert warrior_combatant.cooldowns["power_strike"] == 3
assert warrior_combatant.can_use_ability("power_strike", ability) == False
# Tick cooldown
warrior_combatant.tick_cooldowns()
assert warrior_combatant.cooldowns["power_strike"] == 2
# Tick more
warrior_combatant.tick_cooldowns()
warrior_combatant.tick_cooldowns()
# Should be available again
assert "power_strike" not in warrior_combatant.cooldowns
assert warrior_combatant.can_use_ability("power_strike", ability) == True
def test_combat_encounter_initialization(warrior_combatant, goblin_combatant):
"""Test initializing a combat encounter."""
encounter = CombatEncounter(
encounter_id="combat_001",
combatants=[warrior_combatant, goblin_combatant],
)
encounter.initialize_combat()
# Should have turn order
assert len(encounter.turn_order) == 2
assert encounter.round_number == 1
assert encounter.status == CombatStatus.ACTIVE
# Both combatants should have initiative
assert warrior_combatant.initiative > 0
assert goblin_combatant.initiative > 0
def test_combat_turn_advancement(warrior_combatant, goblin_combatant):
"""Test advancing turns in combat."""
encounter = CombatEncounter(
encounter_id="combat_001",
combatants=[warrior_combatant, goblin_combatant],
)
encounter.initialize_combat()
# Get first combatant
first = encounter.get_current_combatant()
assert first is not None
# Advance turn
encounter.advance_turn()
# Should be second combatant now
second = encounter.get_current_combatant()
assert second is not None
assert second.combatant_id != first.combatant_id
# Advance again - should cycle back to first and increment round
encounter.advance_turn()
assert encounter.round_number == 2
third = encounter.get_current_combatant()
assert third.combatant_id == first.combatant_id
def test_combat_victory_condition(warrior_combatant, goblin_combatant):
"""Test victory condition detection."""
encounter = CombatEncounter(
encounter_id="combat_001",
combatants=[warrior_combatant, goblin_combatant],
)
encounter.initialize_combat()
# Kill the goblin
goblin_combatant.current_hp = 0
# Check end condition
status = encounter.check_end_condition()
assert status == CombatStatus.VICTORY
assert encounter.status == CombatStatus.VICTORY
def test_combat_defeat_condition(warrior_combatant, goblin_combatant):
"""Test defeat condition detection."""
encounter = CombatEncounter(
encounter_id="combat_001",
combatants=[warrior_combatant, goblin_combatant],
)
encounter.initialize_combat()
# Kill the warrior
warrior_combatant.current_hp = 0
# Check end condition
status = encounter.check_end_condition()
assert status == CombatStatus.DEFEAT
assert encounter.status == CombatStatus.DEFEAT
def test_combat_start_turn_processing(warrior_combatant):
"""Test start_turn() processes effects and cooldowns."""
encounter = CombatEncounter(
encounter_id="combat_001",
combatants=[warrior_combatant],
)
# Initialize combat to set turn order
encounter.initialize_combat()
# Add a DOT effect
poison = Effect(
effect_id="poison_1",
name="Poison",
effect_type=EffectType.DOT,
duration=3,
power=5,
)
warrior_combatant.add_effect(poison)
# Add a cooldown
warrior_combatant.cooldowns["power_strike"] = 2
initial_hp = warrior_combatant.current_hp
# Start turn
results = encounter.start_turn()
# Effects should have ticked
assert len(results) == 1
assert warrior_combatant.current_hp == initial_hp - 5
# Cooldown should have decreased
assert warrior_combatant.cooldowns["power_strike"] == 1
def test_combat_logging(warrior_combatant, goblin_combatant):
"""Test combat log entries."""
encounter = CombatEncounter(
encounter_id="combat_001",
combatants=[warrior_combatant, goblin_combatant],
)
encounter.log_action("attack", "warrior_1", "Warrior attacks Goblin for 10 damage")
assert len(encounter.combat_log) == 1
assert encounter.combat_log[0]["action_type"] == "attack"
assert encounter.combat_log[0]["combatant_id"] == "warrior_1"
assert "Warrior attacks Goblin" in encounter.combat_log[0]["message"]
def test_ability_damage_calculation():
"""Test ability power calculation with stat scaling."""
stats = Stats(strength=20, intelligence=16)
# Physical ability scaling with strength
physical = Ability(
ability_id="cleave",
name="Cleave",
description="Powerful strike",
ability_type=AbilityType.SKILL,
base_power=15,
scaling_stat=StatType.STRENGTH,
scaling_factor=0.5,
)
power = physical.calculate_power(stats)
# 15 (base) + (20 strength × 0.5) = 15 + 10 = 25
assert power == 25
# Magical ability scaling with intelligence
magical = Ability(
ability_id="fireball",
name="Fireball",
description="Fire spell",
ability_type=AbilityType.SPELL,
base_power=20,
scaling_stat=StatType.INTELLIGENCE,
scaling_factor=0.5,
)
power = magical.calculate_power(stats)
# 20 (base) + (16 intelligence × 0.5) = 20 + 8 = 28
assert power == 28
def test_full_combat_simulation():
"""Integration test: Full combat simulation with all systems."""
# Create warrior
warrior_stats = Stats(strength=15, constitution=14)
warrior = Combatant(
combatant_id="hero",
name="Hero",
is_player=True,
current_hp=warrior_stats.hit_points,
max_hp=warrior_stats.hit_points,
current_mp=warrior_stats.mana_points,
max_mp=warrior_stats.mana_points,
stats=warrior_stats,
)
# Create goblin
goblin_stats = Stats(strength=8, constitution=10)
goblin = Combatant(
combatant_id="goblin",
name="Goblin",
is_player=False,
current_hp=goblin_stats.hit_points,
max_hp=goblin_stats.hit_points,
current_mp=goblin_stats.mana_points,
max_mp=goblin_stats.mana_points,
stats=goblin_stats,
)
# Create encounter
encounter = CombatEncounter(
encounter_id="test_combat",
combatants=[warrior, goblin],
)
encounter.initialize_combat()
# Verify setup
assert encounter.status == CombatStatus.ACTIVE
assert len(encounter.turn_order) == 2
assert warrior.is_alive() and goblin.is_alive()
# Simulate turns until combat ends
max_turns = 50 # Increased to ensure combat completes
turn_count = 0
while encounter.status == CombatStatus.ACTIVE and turn_count < max_turns:
# Get current combatant
current = encounter.get_current_combatant()
# Start turn (tick effects)
encounter.start_turn()
if current and current.is_alive() and not current.is_stunned():
# Simple AI: deal damage to opponent
if current.combatant_id == "hero":
target = goblin
else:
target = warrior
# Calculate simple attack damage: strength / 2 - target defense
damage = max(1, (current.stats.strength // 2) - target.stats.defense)
target.take_damage(damage)
encounter.log_action(
"attack",
current.combatant_id,
f"{current.name} attacks {target.name} for {damage} damage",
)
# Check for combat end
encounter.check_end_condition()
# Advance turn
encounter.advance_turn()
turn_count += 1
# Combat should have ended
assert encounter.status in [CombatStatus.VICTORY, CombatStatus.DEFEAT]
assert len(encounter.combat_log) > 0

361
api/tests/test_effects.py Normal file
View File

@@ -0,0 +1,361 @@
"""
Unit tests for Effect dataclass.
Tests all effect types, tick() method, stacking, and serialization.
"""
import pytest
from app.models.effects import Effect
from app.models.enums import EffectType, StatType
def test_effect_creation():
"""Test creating an Effect instance."""
effect = Effect(
effect_id="burn_1",
name="Burning",
effect_type=EffectType.DOT,
duration=3,
power=5,
)
assert effect.effect_id == "burn_1"
assert effect.name == "Burning"
assert effect.effect_type == EffectType.DOT
assert effect.duration == 3
assert effect.power == 5
assert effect.stacks == 1
assert effect.max_stacks == 5
def test_dot_effect_tick():
"""Test DOT (damage over time) effect ticking."""
effect = Effect(
effect_id="poison_1",
name="Poisoned",
effect_type=EffectType.DOT,
duration=3,
power=10,
stacks=2,
)
result = effect.tick()
assert result["effect_type"] == "dot"
assert result["value"] == 20 # 10 power × 2 stacks
assert result["expired"] == False
assert effect.duration == 2 # Reduced by 1
def test_hot_effect_tick():
"""Test HOT (heal over time) effect ticking."""
effect = Effect(
effect_id="regen_1",
name="Regeneration",
effect_type=EffectType.HOT,
duration=5,
power=8,
stacks=1,
)
result = effect.tick()
assert result["effect_type"] == "hot"
assert result["value"] == 8 # 8 power × 1 stack
assert result["expired"] == False
assert effect.duration == 4
def test_stun_effect_tick():
"""Test STUN effect ticking."""
effect = Effect(
effect_id="stun_1",
name="Stunned",
effect_type=EffectType.STUN,
duration=1,
power=0,
)
result = effect.tick()
assert result["effect_type"] == "stun"
assert result.get("stunned") == True
assert effect.duration == 0
def test_shield_effect_tick():
"""Test SHIELD effect ticking."""
effect = Effect(
effect_id="shield_1",
name="Shield",
effect_type=EffectType.SHIELD,
duration=3,
power=50,
)
result = effect.tick()
assert result["effect_type"] == "shield"
assert result["shield_remaining"] == 50
assert effect.duration == 2
def test_buff_effect_tick():
"""Test BUFF effect ticking."""
effect = Effect(
effect_id="str_buff_1",
name="Strength Boost",
effect_type=EffectType.BUFF,
duration=4,
power=5,
stat_affected=StatType.STRENGTH,
stacks=2,
)
result = effect.tick()
assert result["effect_type"] == "buff"
assert result["stat_affected"] == "strength"
assert result["stat_modifier"] == 10 # 5 power × 2 stacks
assert effect.duration == 3
def test_debuff_effect_tick():
"""Test DEBUFF effect ticking."""
effect = Effect(
effect_id="weak_1",
name="Weakened",
effect_type=EffectType.DEBUFF,
duration=2,
power=3,
stat_affected=StatType.STRENGTH,
)
result = effect.tick()
assert result["effect_type"] == "debuff"
assert result["stat_affected"] == "strength"
assert result["stat_modifier"] == 3
assert effect.duration == 1
def test_effect_expiration():
"""Test that effect expires when duration reaches 0."""
effect = Effect(
effect_id="burn_1",
name="Burning",
effect_type=EffectType.DOT,
duration=1,
power=5,
)
result = effect.tick()
assert result["expired"] == True
assert effect.duration == 0
def test_effect_stacking():
"""Test apply_stack() increases stacks up to max."""
effect = Effect(
effect_id="poison_1",
name="Poison",
effect_type=EffectType.DOT,
duration=3,
power=5,
max_stacks=5,
)
assert effect.stacks == 1
effect.apply_stack()
assert effect.stacks == 2
effect.apply_stack()
assert effect.stacks == 3
# Apply 3 more to reach max
effect.apply_stack()
effect.apply_stack()
effect.apply_stack()
assert effect.stacks == 5 # Capped at max_stacks
# Try to apply one more - should still be 5
effect.apply_stack()
assert effect.stacks == 5
def test_shield_damage_absorption():
"""Test reduce_shield() method."""
effect = Effect(
effect_id="shield_1",
name="Shield",
effect_type=EffectType.SHIELD,
duration=3,
power=50,
stacks=1,
)
# Shield absorbs 20 damage
remaining = effect.reduce_shield(20)
assert remaining == 0 # All damage absorbed
assert effect.power == 30 # Shield reduced by 20
# Shield absorbs 20 more
remaining = effect.reduce_shield(20)
assert remaining == 0
assert effect.power == 10
# Shield takes 15 damage, breaks completely
remaining = effect.reduce_shield(15)
assert remaining == 5 # 5 damage passes through
assert effect.power == 0
assert effect.duration == 0 # Effect expires
def test_shield_with_stacks():
"""Test shield absorption with multiple stacks."""
effect = Effect(
effect_id="shield_1",
name="Shield",
effect_type=EffectType.SHIELD,
duration=3,
power=20,
stacks=3,
)
# Total shield = 20 × 3 = 60
# Apply 50 damage
remaining = effect.reduce_shield(50)
assert remaining == 0 # All absorbed
# Shield reduced: 50 / 3 stacks = 16.67 per stack
assert effect.power < 20 # Power reduced
def test_effect_serialization():
"""Test to_dict() serialization."""
effect = Effect(
effect_id="burn_1",
name="Burning",
effect_type=EffectType.DOT,
duration=3,
power=5,
stacks=2,
max_stacks=5,
source="fireball",
)
data = effect.to_dict()
assert data["effect_id"] == "burn_1"
assert data["name"] == "Burning"
assert data["effect_type"] == "dot"
assert data["duration"] == 3
assert data["power"] == 5
assert data["stacks"] == 2
assert data["max_stacks"] == 5
assert data["source"] == "fireball"
def test_effect_deserialization():
"""Test from_dict() deserialization."""
data = {
"effect_id": "regen_1",
"name": "Regeneration",
"effect_type": "hot",
"duration": 5,
"power": 10,
"stat_affected": None,
"stacks": 1,
"max_stacks": 5,
"source": "potion",
}
effect = Effect.from_dict(data)
assert effect.effect_id == "regen_1"
assert effect.name == "Regeneration"
assert effect.effect_type == EffectType.HOT
assert effect.duration == 5
assert effect.power == 10
assert effect.stacks == 1
def test_effect_with_stat_deserialization():
"""Test deserializing effect with stat_affected."""
data = {
"effect_id": "buff_1",
"name": "Strength Boost",
"effect_type": "buff",
"duration": 3,
"power": 5,
"stat_affected": "strength",
"stacks": 1,
"max_stacks": 5,
"source": "spell",
}
effect = Effect.from_dict(data)
assert effect.stat_affected == StatType.STRENGTH
assert effect.effect_type == EffectType.BUFF
def test_effect_round_trip_serialization():
"""Test that serialization and deserialization preserve data."""
original = Effect(
effect_id="test_effect",
name="Test Effect",
effect_type=EffectType.DEBUFF,
duration=10,
power=15,
stat_affected=StatType.CONSTITUTION,
stacks=3,
max_stacks=5,
source="enemy_spell",
)
# Serialize then deserialize
data = original.to_dict()
restored = Effect.from_dict(data)
assert restored.effect_id == original.effect_id
assert restored.name == original.name
assert restored.effect_type == original.effect_type
assert restored.duration == original.duration
assert restored.power == original.power
assert restored.stat_affected == original.stat_affected
assert restored.stacks == original.stacks
assert restored.max_stacks == original.max_stacks
assert restored.source == original.source
def test_effect_repr():
"""Test string representation."""
effect = Effect(
effect_id="poison_1",
name="Poison",
effect_type=EffectType.DOT,
duration=3,
power=5,
stacks=2,
)
repr_str = repr(effect)
assert "Poison" in repr_str
assert "dot" in repr_str
def test_non_shield_reduce_shield():
"""Test that reduce_shield() on non-SHIELD effects returns full damage."""
effect = Effect(
effect_id="burn_1",
name="Burning",
effect_type=EffectType.DOT,
duration=3,
power=5,
)
remaining = effect.reduce_shield(50)
assert remaining == 50 # All damage passes through for non-shield effects

View File

@@ -0,0 +1,294 @@
"""
Unit tests for model selector module.
"""
import pytest
from app.ai import (
ModelSelector,
ModelConfig,
UserTier,
ContextType,
ModelType,
)
class TestModelSelector:
"""Tests for ModelSelector class."""
def setup_method(self):
"""Set up test fixtures."""
self.selector = ModelSelector()
def test_initialization(self):
"""Test ModelSelector initializes correctly."""
assert self.selector is not None
# Test tier to model mapping
def test_free_tier_gets_llama(self):
"""Free tier should get Llama-3 8B."""
config = self.selector.select_model(UserTier.FREE)
assert config.model_type == ModelType.LLAMA_3_8B
def test_basic_tier_gets_haiku(self):
"""Basic tier should get Claude Haiku."""
config = self.selector.select_model(UserTier.BASIC)
assert config.model_type == ModelType.CLAUDE_HAIKU
def test_premium_tier_gets_sonnet(self):
"""Premium tier should get Claude Sonnet."""
config = self.selector.select_model(UserTier.PREMIUM)
assert config.model_type == ModelType.CLAUDE_SONNET
def test_elite_tier_gets_opus(self):
"""Elite tier should get Claude Opus."""
config = self.selector.select_model(UserTier.ELITE)
assert config.model_type == ModelType.CLAUDE_SONNET_4
# Test token limits by tier (using STORY_PROGRESSION for full allocation)
def test_free_tier_token_limit(self):
"""Free tier should have 256 base tokens."""
config = self.selector.select_model(UserTier.FREE, ContextType.STORY_PROGRESSION)
assert config.max_tokens == 256
def test_basic_tier_token_limit(self):
"""Basic tier should have 512 base tokens."""
config = self.selector.select_model(UserTier.BASIC, ContextType.STORY_PROGRESSION)
assert config.max_tokens == 512
def test_premium_tier_token_limit(self):
"""Premium tier should have 1024 base tokens."""
config = self.selector.select_model(UserTier.PREMIUM, ContextType.STORY_PROGRESSION)
assert config.max_tokens == 1024
def test_elite_tier_token_limit(self):
"""Elite tier should have 2048 base tokens."""
config = self.selector.select_model(UserTier.ELITE, ContextType.STORY_PROGRESSION)
assert config.max_tokens == 2048
# Test context-based token adjustments
def test_story_progression_full_tokens(self):
"""Story progression should use full token allocation."""
config = self.selector.select_model(
UserTier.PREMIUM,
ContextType.STORY_PROGRESSION
)
# Full allocation = 1024 tokens for premium
assert config.max_tokens == 1024
def test_combat_narration_reduced_tokens(self):
"""Combat narration should use 75% of tokens."""
config = self.selector.select_model(
UserTier.PREMIUM,
ContextType.COMBAT_NARRATION
)
# 75% of 1024 = 768
assert config.max_tokens == 768
def test_quest_selection_half_tokens(self):
"""Quest selection should use 50% of tokens."""
config = self.selector.select_model(
UserTier.PREMIUM,
ContextType.QUEST_SELECTION
)
# 50% of 1024 = 512
assert config.max_tokens == 512
def test_npc_dialogue_reduced_tokens(self):
"""NPC dialogue should use 75% of tokens."""
config = self.selector.select_model(
UserTier.PREMIUM,
ContextType.NPC_DIALOGUE
)
# 75% of 1024 = 768
assert config.max_tokens == 768
def test_simple_response_half_tokens(self):
"""Simple response should use 50% of tokens."""
config = self.selector.select_model(
UserTier.PREMIUM,
ContextType.SIMPLE_RESPONSE
)
# 50% of 1024 = 512
assert config.max_tokens == 512
# Test context-based temperature settings
def test_story_progression_high_temperature(self):
"""Story progression should have high temperature (0.9)."""
config = self.selector.select_model(
UserTier.PREMIUM,
ContextType.STORY_PROGRESSION
)
assert config.temperature == 0.9
def test_combat_narration_medium_high_temperature(self):
"""Combat narration should have medium-high temperature (0.8)."""
config = self.selector.select_model(
UserTier.PREMIUM,
ContextType.COMBAT_NARRATION
)
assert config.temperature == 0.8
def test_quest_selection_low_temperature(self):
"""Quest selection should have low temperature (0.5)."""
config = self.selector.select_model(
UserTier.PREMIUM,
ContextType.QUEST_SELECTION
)
assert config.temperature == 0.5
def test_npc_dialogue_medium_temperature(self):
"""NPC dialogue should have medium temperature (0.85)."""
config = self.selector.select_model(
UserTier.PREMIUM,
ContextType.NPC_DIALOGUE
)
assert config.temperature == 0.85
def test_simple_response_balanced_temperature(self):
"""Simple response should have balanced temperature (0.7)."""
config = self.selector.select_model(
UserTier.PREMIUM,
ContextType.SIMPLE_RESPONSE
)
assert config.temperature == 0.7
# Test ModelConfig properties
def test_model_config_model_property(self):
"""ModelConfig.model should return model identifier string."""
config = self.selector.select_model(UserTier.PREMIUM)
assert config.model == "anthropic/claude-3.5-sonnet"
# Test get_model_for_tier method
def test_get_model_for_tier_free(self):
"""get_model_for_tier should return correct model for free tier."""
model = self.selector.get_model_for_tier(UserTier.FREE)
assert model == ModelType.LLAMA_3_8B
def test_get_model_for_tier_elite(self):
"""get_model_for_tier should return correct model for elite tier."""
model = self.selector.get_model_for_tier(UserTier.ELITE)
assert model == ModelType.CLAUDE_SONNET_4
# Test get_tier_info method
def test_get_tier_info_structure(self):
"""get_tier_info should return complete tier information."""
info = self.selector.get_tier_info(UserTier.PREMIUM)
assert "tier" in info
assert "model" in info
assert "model_name" in info
assert "base_tokens" in info
assert "quality" in info
def test_get_tier_info_premium_values(self):
"""get_tier_info should return correct values for premium tier."""
info = self.selector.get_tier_info(UserTier.PREMIUM)
assert info["tier"] == "premium"
assert info["model"] == "anthropic/claude-3.5-sonnet"
assert info["model_name"] == "Claude 3.5 Sonnet"
assert info["base_tokens"] == 1024
def test_get_tier_info_free_values(self):
"""get_tier_info should return correct values for free tier."""
info = self.selector.get_tier_info(UserTier.FREE)
assert info["tier"] == "free"
assert info["model_name"] == "Llama 3 8B"
assert info["base_tokens"] == 256
# Test estimate_cost_per_request method
def test_free_tier_zero_cost(self):
"""Free tier should have zero cost."""
cost = self.selector.estimate_cost_per_request(UserTier.FREE)
assert cost == 0.0
def test_basic_tier_has_cost(self):
"""Basic tier should have non-zero cost."""
cost = self.selector.estimate_cost_per_request(UserTier.BASIC)
assert cost > 0
def test_premium_tier_higher_cost(self):
"""Premium tier should have higher cost than basic."""
basic_cost = self.selector.estimate_cost_per_request(UserTier.BASIC)
premium_cost = self.selector.estimate_cost_per_request(UserTier.PREMIUM)
assert premium_cost > basic_cost
def test_elite_tier_highest_cost(self):
"""Elite tier should have highest cost."""
premium_cost = self.selector.estimate_cost_per_request(UserTier.PREMIUM)
elite_cost = self.selector.estimate_cost_per_request(UserTier.ELITE)
assert elite_cost > premium_cost
# Test all tier combinations
def test_all_tiers_return_valid_config(self):
"""All tiers should return valid ModelConfig objects."""
for tier in UserTier:
config = self.selector.select_model(tier)
assert isinstance(config, ModelConfig)
assert config.model_type in ModelType
assert config.max_tokens > 0
assert 0 <= config.temperature <= 1
# Test all context combinations
def test_all_contexts_return_valid_config(self):
"""All context types should return valid ModelConfig objects."""
for context in ContextType:
config = self.selector.select_model(UserTier.PREMIUM, context)
assert isinstance(config, ModelConfig)
assert config.max_tokens > 0
assert 0 <= config.temperature <= 1
class TestUserTierEnum:
"""Tests for UserTier enum."""
def test_tier_values(self):
"""Test UserTier enum values are correct strings."""
assert UserTier.FREE.value == "free"
assert UserTier.BASIC.value == "basic"
assert UserTier.PREMIUM.value == "premium"
assert UserTier.ELITE.value == "elite"
def test_tier_string_conversion(self):
"""Test UserTier can be converted to string."""
assert str(UserTier.FREE) == "UserTier.FREE"
class TestContextTypeEnum:
"""Tests for ContextType enum."""
def test_context_values(self):
"""Test ContextType enum values are correct strings."""
assert ContextType.STORY_PROGRESSION.value == "story_progression"
assert ContextType.COMBAT_NARRATION.value == "combat_narration"
assert ContextType.QUEST_SELECTION.value == "quest_selection"
assert ContextType.NPC_DIALOGUE.value == "npc_dialogue"
assert ContextType.SIMPLE_RESPONSE.value == "simple_response"
class TestModelConfig:
"""Tests for ModelConfig dataclass."""
def test_model_config_creation(self):
"""Test ModelConfig can be created with valid data."""
config = ModelConfig(
model_type=ModelType.CLAUDE_SONNET,
max_tokens=1024,
temperature=0.9
)
assert config.model_type == ModelType.CLAUDE_SONNET
assert config.max_tokens == 1024
assert config.temperature == 0.9
def test_model_property(self):
"""Test model property returns model identifier."""
config = ModelConfig(
model_type=ModelType.LLAMA_3_8B,
max_tokens=256,
temperature=0.7
)
assert config.model == "meta/meta-llama-3-8b-instruct"

View File

@@ -0,0 +1,583 @@
"""
Tests for the NarrativeGenerator wrapper.
These tests use mocked AI clients to verify the generator's
behavior without making actual API calls.
"""
import pytest
from unittest.mock import MagicMock, patch
from app.ai.narrative_generator import (
NarrativeGenerator,
NarrativeResponse,
NarrativeGeneratorError,
)
from app.ai.replicate_client import ReplicateResponse, ReplicateClientError
from app.ai.model_selector import UserTier, ContextType, ModelSelector, ModelConfig, ModelType
from app.ai.prompt_templates import PromptTemplates
# Test fixtures
@pytest.fixture
def mock_replicate_client():
"""Create a mock Replicate client."""
client = MagicMock()
client.generate.return_value = ReplicateResponse(
text="The tavern falls silent as you step through the doorway...",
tokens_used=150,
model="meta/meta-llama-3-8b-instruct",
generation_time=2.5
)
return client
@pytest.fixture
def generator(mock_replicate_client):
"""Create a NarrativeGenerator with mocked dependencies."""
return NarrativeGenerator(replicate_client=mock_replicate_client)
@pytest.fixture
def sample_character():
"""Sample character data for tests."""
return {
"name": "Aldric",
"level": 3,
"player_class": "Fighter",
"current_hp": 25,
"max_hp": 30,
"stats": {
"strength": 16,
"dexterity": 12,
"constitution": 14,
"intelligence": 10,
"wisdom": 11,
"charisma": 8
},
"skills": [
{"name": "Sword Mastery", "level": 2},
{"name": "Shield Block", "level": 1}
],
"effects": [],
"completed_quests": []
}
@pytest.fixture
def sample_game_state():
"""Sample game state for tests."""
return {
"current_location": "The Rusty Anchor Tavern",
"location_type": "TAVERN",
"discovered_locations": ["Crossroads Village", "Dark Forest"],
"active_quests": [],
"time_of_day": "Evening"
}
@pytest.fixture
def sample_combat_state():
"""Sample combat state for tests."""
return {
"round_number": 2,
"current_turn": "player",
"enemies": [
{
"name": "Goblin Scout",
"current_hp": 8,
"max_hp": 12,
"effects": []
}
]
}
@pytest.fixture
def sample_npc():
"""Sample NPC data for tests."""
return {
"name": "Barkeep Magnus",
"role": "Tavern Owner",
"personality": "Gruff but kind-hearted",
"speaking_style": "Short sentences, occasional jokes",
"goals": "Keep the tavern running, help adventurers"
}
@pytest.fixture
def sample_quests():
"""Sample eligible quests for tests."""
return [
{
"quest_id": "goblin_cave",
"name": "Clear the Goblin Cave",
"difficulty": "EASY",
"quest_giver": "Village Elder",
"description": "A nearby cave has been overrun by goblins.",
"narrative_hooks": [
"Farmers complain about stolen livestock",
"Goblin tracks spotted on the road"
]
},
{
"quest_id": "missing_merchant",
"name": "Find the Missing Merchant",
"difficulty": "MEDIUM",
"quest_giver": "Guild Master",
"description": "A merchant caravan has gone missing.",
"narrative_hooks": [
"The merchant was carrying valuable goods",
"His family is worried sick"
]
}
]
class TestNarrativeGeneratorInit:
"""Tests for NarrativeGenerator initialization."""
def test_init_with_defaults(self):
"""Test initialization with default dependencies."""
with patch('app.ai.narrative_generator.ReplicateClient'):
generator = NarrativeGenerator()
assert generator.model_selector is not None
assert generator.prompt_templates is not None
def test_init_with_custom_selector(self):
"""Test initialization with custom model selector."""
custom_selector = ModelSelector()
generator = NarrativeGenerator(model_selector=custom_selector)
assert generator.model_selector is custom_selector
def test_init_with_custom_client(self, mock_replicate_client):
"""Test initialization with custom Replicate client."""
generator = NarrativeGenerator(replicate_client=mock_replicate_client)
assert generator.replicate_client is mock_replicate_client
class TestGenerateStoryResponse:
"""Tests for generate_story_response method."""
def test_basic_story_generation(
self, generator, mock_replicate_client, sample_character, sample_game_state
):
"""Test basic story response generation."""
response = generator.generate_story_response(
character=sample_character,
action="I search the room for hidden doors",
game_state=sample_game_state,
user_tier=UserTier.FREE
)
assert isinstance(response, NarrativeResponse)
assert response.narrative is not None
assert response.tokens_used > 0
assert response.context_type == "story_progression"
assert response.generation_time > 0
# Verify client was called
mock_replicate_client.generate.assert_called_once()
def test_story_with_conversation_history(
self, generator, mock_replicate_client, sample_character, sample_game_state
):
"""Test story generation with conversation history."""
history = [
{
"turn": 1,
"action": "I enter the tavern",
"dm_response": "The tavern is warm and inviting..."
},
{
"turn": 2,
"action": "I approach the bar",
"dm_response": "The barkeep nods in greeting..."
}
]
response = generator.generate_story_response(
character=sample_character,
action="I ask about local rumors",
game_state=sample_game_state,
user_tier=UserTier.PREMIUM,
conversation_history=history
)
assert isinstance(response, NarrativeResponse)
mock_replicate_client.generate.assert_called_once()
def test_story_with_world_context(
self, generator, mock_replicate_client, sample_character, sample_game_state
):
"""Test story generation with additional world context."""
response = generator.generate_story_response(
character=sample_character,
action="I look around",
game_state=sample_game_state,
user_tier=UserTier.ELITE,
world_context="A festival is being celebrated in the village"
)
assert isinstance(response, NarrativeResponse)
def test_story_uses_correct_model_for_tier(
self, generator, mock_replicate_client, sample_character, sample_game_state
):
"""Test that correct model is selected based on tier."""
for tier in UserTier:
mock_replicate_client.generate.reset_mock()
generator.generate_story_response(
character=sample_character,
action="Test action",
game_state=sample_game_state,
user_tier=tier
)
# Verify generate was called with appropriate model
call_kwargs = mock_replicate_client.generate.call_args
assert call_kwargs is not None
def test_story_handles_client_error(
self, generator, mock_replicate_client, sample_character, sample_game_state
):
"""Test error handling when client fails."""
mock_replicate_client.generate.side_effect = ReplicateClientError("API error")
with pytest.raises(NarrativeGeneratorError) as exc_info:
generator.generate_story_response(
character=sample_character,
action="Test action",
game_state=sample_game_state,
user_tier=UserTier.FREE
)
assert "AI generation failed" in str(exc_info.value)
class TestGenerateCombatNarration:
"""Tests for generate_combat_narration method."""
def test_basic_combat_narration(
self, generator, mock_replicate_client, sample_character, sample_combat_state
):
"""Test basic combat narration generation."""
action_result = {
"hit": True,
"damage": 8,
"target": "Goblin Scout"
}
response = generator.generate_combat_narration(
character=sample_character,
combat_state=sample_combat_state,
action="swings their sword",
action_result=action_result,
user_tier=UserTier.BASIC
)
assert isinstance(response, NarrativeResponse)
assert response.context_type == "combat_narration"
def test_critical_hit_narration(
self, generator, mock_replicate_client, sample_character, sample_combat_state
):
"""Test combat narration for critical hit."""
action_result = {
"hit": True,
"damage": 16,
"target": "Goblin Scout"
}
response = generator.generate_combat_narration(
character=sample_character,
combat_state=sample_combat_state,
action="strikes with precision",
action_result=action_result,
user_tier=UserTier.PREMIUM,
is_critical=True
)
assert isinstance(response, NarrativeResponse)
def test_finishing_blow_narration(
self, generator, mock_replicate_client, sample_character, sample_combat_state
):
"""Test combat narration for finishing blow."""
action_result = {
"hit": True,
"damage": 10,
"target": "Goblin Scout"
}
response = generator.generate_combat_narration(
character=sample_character,
combat_state=sample_combat_state,
action="delivers the final blow",
action_result=action_result,
user_tier=UserTier.ELITE,
is_finishing_blow=True
)
assert isinstance(response, NarrativeResponse)
def test_miss_narration(
self, generator, mock_replicate_client, sample_character, sample_combat_state
):
"""Test combat narration for a miss."""
action_result = {
"hit": False,
"damage": 0,
"target": "Goblin Scout"
}
response = generator.generate_combat_narration(
character=sample_character,
combat_state=sample_combat_state,
action="swings wildly",
action_result=action_result,
user_tier=UserTier.FREE
)
assert isinstance(response, NarrativeResponse)
class TestGenerateQuestSelection:
"""Tests for generate_quest_selection method."""
def test_basic_quest_selection(
self, generator, mock_replicate_client, sample_character, sample_game_state, sample_quests
):
"""Test basic quest selection."""
# Mock response to return a valid quest_id
mock_replicate_client.generate.return_value = ReplicateResponse(
text="goblin_cave",
tokens_used=50,
model="meta/meta-llama-3-8b-instruct",
generation_time=1.0
)
quest_id = generator.generate_quest_selection(
character=sample_character,
eligible_quests=sample_quests,
game_context=sample_game_state,
user_tier=UserTier.FREE
)
assert quest_id == "goblin_cave"
def test_quest_selection_with_recent_actions(
self, generator, mock_replicate_client, sample_character, sample_game_state, sample_quests
):
"""Test quest selection with recent actions."""
mock_replicate_client.generate.return_value = ReplicateResponse(
text="missing_merchant",
tokens_used=50,
model="meta/meta-llama-3-8b-instruct",
generation_time=1.0
)
recent_actions = [
"Asked about missing traders",
"Investigated the market square"
]
quest_id = generator.generate_quest_selection(
character=sample_character,
eligible_quests=sample_quests,
game_context=sample_game_state,
user_tier=UserTier.PREMIUM,
recent_actions=recent_actions
)
assert quest_id == "missing_merchant"
def test_quest_selection_invalid_response_fallback(
self, generator, mock_replicate_client, sample_character, sample_game_state, sample_quests
):
"""Test fallback when AI returns invalid quest_id."""
# Mock response with invalid quest_id
mock_replicate_client.generate.return_value = ReplicateResponse(
text="invalid_quest_id",
tokens_used=50,
model="meta/meta-llama-3-8b-instruct",
generation_time=1.0
)
quest_id = generator.generate_quest_selection(
character=sample_character,
eligible_quests=sample_quests,
game_context=sample_game_state,
user_tier=UserTier.FREE
)
# Should fall back to first eligible quest
assert quest_id == "goblin_cave"
def test_quest_selection_no_quests_error(
self, generator, sample_character, sample_game_state
):
"""Test error when no eligible quests provided."""
with pytest.raises(NarrativeGeneratorError) as exc_info:
generator.generate_quest_selection(
character=sample_character,
eligible_quests=[],
game_context=sample_game_state,
user_tier=UserTier.FREE
)
assert "No eligible quests" in str(exc_info.value)
class TestGenerateNPCDialogue:
"""Tests for generate_npc_dialogue method."""
def test_basic_npc_dialogue(
self, generator, mock_replicate_client, sample_character, sample_npc, sample_game_state
):
"""Test basic NPC dialogue generation."""
mock_replicate_client.generate.return_value = ReplicateResponse(
text='*wipes down the bar* "Aye, rumors aplenty around here..."',
tokens_used=100,
model="anthropic/claude-3.5-haiku",
generation_time=1.5
)
response = generator.generate_npc_dialogue(
character=sample_character,
npc=sample_npc,
conversation_topic="What rumors have you heard?",
game_state=sample_game_state,
user_tier=UserTier.BASIC
)
assert isinstance(response, NarrativeResponse)
assert response.context_type == "npc_dialogue"
def test_npc_dialogue_with_relationship(
self, generator, mock_replicate_client, sample_character, sample_npc, sample_game_state
):
"""Test NPC dialogue with established relationship."""
response = generator.generate_npc_dialogue(
character=sample_character,
npc=sample_npc,
conversation_topic="Hello old friend",
game_state=sample_game_state,
user_tier=UserTier.PREMIUM,
npc_relationship="Friendly - helped defend the tavern last month"
)
assert isinstance(response, NarrativeResponse)
def test_npc_dialogue_with_previous_conversation(
self, generator, mock_replicate_client, sample_character, sample_npc, sample_game_state
):
"""Test NPC dialogue with previous conversation history."""
previous = [
{
"player_line": "What's on tap tonight?",
"npc_response": "Got some fine ale from the southern vineyards."
}
]
response = generator.generate_npc_dialogue(
character=sample_character,
npc=sample_npc,
conversation_topic="I'll take one",
game_state=sample_game_state,
user_tier=UserTier.FREE,
previous_dialogue=previous
)
assert isinstance(response, NarrativeResponse)
def test_npc_dialogue_with_special_knowledge(
self, generator, mock_replicate_client, sample_character, sample_npc, sample_game_state
):
"""Test NPC dialogue with special knowledge."""
response = generator.generate_npc_dialogue(
character=sample_character,
npc=sample_npc,
conversation_topic="Have you seen anything strange?",
game_state=sample_game_state,
user_tier=UserTier.ELITE,
npc_knowledge=["Secret passage in the cellar", "Hidden treasure map"]
)
assert isinstance(response, NarrativeResponse)
class TestModelSelection:
"""Tests for model selection behavior."""
def test_free_tier_uses_llama(
self, generator, mock_replicate_client, sample_character, sample_game_state
):
"""Test that free tier uses Llama model."""
generator.generate_story_response(
character=sample_character,
action="Test",
game_state=sample_game_state,
user_tier=UserTier.FREE
)
call_kwargs = mock_replicate_client.generate.call_args
assert call_kwargs is not None
# Model type should be Llama for free tier
model_arg = call_kwargs.kwargs.get('model')
if model_arg:
assert "llama" in str(model_arg).lower() or model_arg == ModelType.LLAMA_3_8B
def test_different_contexts_use_different_temperatures(self, generator):
"""Test that different contexts have different temperature settings."""
# Get model configs for different contexts
config_story = generator.model_selector.select_model(
UserTier.FREE, ContextType.STORY_PROGRESSION
)
config_quest = generator.model_selector.select_model(
UserTier.FREE, ContextType.QUEST_SELECTION
)
# Story should have higher temperature (more creative)
assert config_story.temperature > config_quest.temperature
class TestErrorHandling:
"""Tests for error handling behavior."""
def test_template_error_handling(self, mock_replicate_client):
"""Test handling of template errors."""
from app.ai.prompt_templates import PromptTemplateError
# Create generator with bad template path
with patch.object(PromptTemplates, 'render') as mock_render:
mock_render.side_effect = PromptTemplateError("Template not found")
generator = NarrativeGenerator(replicate_client=mock_replicate_client)
with pytest.raises(NarrativeGeneratorError) as exc_info:
generator.generate_story_response(
character={"name": "Test"},
action="Test",
game_state={"current_location": "Test"},
user_tier=UserTier.FREE
)
assert "Prompt template error" in str(exc_info.value)
def test_api_error_handling(
self, generator, mock_replicate_client, sample_character, sample_game_state
):
"""Test handling of API errors."""
mock_replicate_client.generate.side_effect = ReplicateClientError("Connection failed")
with pytest.raises(NarrativeGeneratorError) as exc_info:
generator.generate_story_response(
character=sample_character,
action="Test",
game_state=sample_game_state,
user_tier=UserTier.FREE
)
assert "AI generation failed" in str(exc_info.value)

View File

@@ -0,0 +1,200 @@
"""
Unit tests for OriginService - character origin loading and validation.
Tests verify that origins load correctly from YAML, contain all required
data, and provide proper narrative hooks for the AI DM.
"""
import pytest
from app.services.origin_service import OriginService, get_origin_service
from app.models.origins import Origin, StartingLocation, StartingBonus
class TestOriginService:
"""Test suite for OriginService functionality."""
@pytest.fixture
def origin_service(self):
"""Create a fresh OriginService instance for each test."""
service = OriginService()
service.clear_cache()
return service
def test_service_initializes(self, origin_service):
"""Test that OriginService initializes correctly."""
assert origin_service is not None
assert origin_service.data_file.exists()
def test_singleton_pattern(self):
"""Test that get_origin_service returns a singleton."""
service1 = get_origin_service()
service2 = get_origin_service()
assert service1 is service2
def test_load_all_origins(self, origin_service):
"""Test loading all origins from YAML file."""
origins = origin_service.load_all_origins()
assert len(origins) == 4
origin_ids = [origin.id for origin in origins]
assert "soul_revenant" in origin_ids
assert "memory_thief" in origin_ids
assert "shadow_apprentice" in origin_ids
assert "escaped_captive" in origin_ids
def test_load_origin_by_id(self, origin_service):
"""Test loading a specific origin by ID."""
origin = origin_service.load_origin("soul_revenant")
assert origin is not None
assert origin.id == "soul_revenant"
assert origin.name == "Soul Revenant"
assert isinstance(origin.description, str)
assert len(origin.description) > 0
def test_origin_not_found(self, origin_service):
"""Test that loading non-existent origin returns None."""
origin = origin_service.load_origin("nonexistent_origin")
assert origin is None
def test_origin_has_starting_location(self, origin_service):
"""Test that all origins have valid starting locations."""
origins = origin_service.load_all_origins()
for origin in origins:
assert origin.starting_location is not None
assert isinstance(origin.starting_location, StartingLocation)
assert origin.starting_location.id
assert origin.starting_location.name
assert origin.starting_location.region
assert origin.starting_location.description
def test_origin_has_narrative_hooks(self, origin_service):
"""Test that all origins have narrative hooks for AI DM."""
origins = origin_service.load_all_origins()
for origin in origins:
assert origin.narrative_hooks is not None
assert isinstance(origin.narrative_hooks, list)
assert len(origin.narrative_hooks) > 0, f"{origin.id} has no narrative hooks"
def test_origin_has_starting_bonus(self, origin_service):
"""Test that all origins have starting bonuses."""
origins = origin_service.load_all_origins()
for origin in origins:
assert origin.starting_bonus is not None
assert isinstance(origin.starting_bonus, StartingBonus)
assert origin.starting_bonus.trait
assert origin.starting_bonus.description
assert origin.starting_bonus.effect
def test_soul_revenant_details(self, origin_service):
"""Test Soul Revenant origin has correct details."""
origin = origin_service.load_origin("soul_revenant")
assert origin.name == "Soul Revenant"
assert "centuries" in origin.description.lower()
assert origin.starting_location.id == "forgotten_crypt"
assert "Deathless Resolve" in origin.starting_bonus.trait
assert any("past" in hook.lower() for hook in origin.narrative_hooks)
def test_memory_thief_details(self, origin_service):
"""Test Memory Thief origin has correct details."""
origin = origin_service.load_origin("memory_thief")
assert origin.name == "Memory Thief"
assert "memory" in origin.description.lower()
assert origin.starting_location.id == "thornfield_plains"
assert "Blank Slate" in origin.starting_bonus.trait
assert any("memory" in hook.lower() for hook in origin.narrative_hooks)
def test_shadow_apprentice_details(self, origin_service):
"""Test Shadow Apprentice origin has correct details."""
origin = origin_service.load_origin("shadow_apprentice")
assert origin.name == "Shadow Apprentice"
assert "master" in origin.description.lower()
assert origin.starting_location.id == "shadowfen"
assert "Trained in Shadows" in origin.starting_bonus.trait
assert any("master" in hook.lower() for hook in origin.narrative_hooks)
def test_escaped_captive_details(self, origin_service):
"""Test Escaped Captive origin has correct details."""
origin = origin_service.load_origin("escaped_captive")
assert origin.name == "The Escaped Captive"
assert "prison" in origin.description.lower() or "ironpeak" in origin.description.lower()
assert origin.starting_location.id == "ironpeak_pass"
assert "Hardened Survivor" in origin.starting_bonus.trait
assert any("prison" in hook.lower() or "past" in hook.lower() for hook in origin.narrative_hooks)
def test_origin_serialization(self, origin_service):
"""Test that origins can be serialized to dict and back."""
original = origin_service.load_origin("soul_revenant")
# Serialize to dict
origin_dict = original.to_dict()
assert isinstance(origin_dict, dict)
assert origin_dict["id"] == "soul_revenant"
assert origin_dict["name"] == "Soul Revenant"
# Deserialize from dict
restored = Origin.from_dict(origin_dict)
assert restored.id == original.id
assert restored.name == original.name
assert restored.description == original.description
assert restored.starting_location.id == original.starting_location.id
assert restored.starting_bonus.trait == original.starting_bonus.trait
def test_caching_works(self, origin_service):
"""Test that caching improves performance on repeated loads."""
# First load
origin1 = origin_service.load_origin("soul_revenant")
# Second load should come from cache
origin2 = origin_service.load_origin("soul_revenant")
# Should be the exact same instance from cache
assert origin1 is origin2
def test_cache_clear(self, origin_service):
"""Test that cache clearing works correctly."""
# Load origins to populate cache
origin_service.load_all_origins()
assert len(origin_service._origins_cache) > 0
# Clear cache
origin_service.clear_cache()
assert len(origin_service._origins_cache) == 0
assert origin_service._all_origins_loaded is False
def test_reload_origins(self, origin_service):
"""Test that reloading clears cache and reloads from file."""
# Load origins
origins1 = origin_service.load_all_origins()
# Reload
origins2 = origin_service.reload_origins()
# Should have same content but be fresh instances
assert len(origins1) == len(origins2)
assert all(o1.id == o2.id for o1, o2 in zip(origins1, origins2))
def test_get_all_origin_ids(self, origin_service):
"""Test getting list of all origin IDs."""
origin_ids = origin_service.get_all_origin_ids()
assert isinstance(origin_ids, list)
assert len(origin_ids) == 4
assert "soul_revenant" in origin_ids
assert "memory_thief" in origin_ids
assert "shadow_apprentice" in origin_ids
assert "escaped_captive" in origin_ids
def test_get_origin_by_id_alias(self, origin_service):
"""Test that get_origin_by_id is an alias for load_origin."""
origin1 = origin_service.load_origin("soul_revenant")
origin2 = origin_service.get_origin_by_id("soul_revenant")
assert origin1 is origin2

View File

@@ -0,0 +1,321 @@
"""
Unit tests for prompt templates module.
"""
import pytest
from pathlib import Path
from app.ai import (
PromptTemplates,
PromptTemplateError,
get_prompt_templates,
render_prompt,
)
class TestPromptTemplates:
"""Tests for PromptTemplates class."""
def setup_method(self):
"""Set up test fixtures."""
self.templates = PromptTemplates()
def test_initialization(self):
"""Test PromptTemplates initializes correctly."""
assert self.templates is not None
assert self.templates.env is not None
def test_template_directory_exists(self):
"""Test that template directory is created."""
assert self.templates.template_dir.exists()
def test_get_template_names(self):
"""Test listing available templates."""
names = self.templates.get_template_names()
assert isinstance(names, list)
# Should have our 4 core templates
assert 'story_action.j2' in names
assert 'combat_action.j2' in names
assert 'quest_offering.j2' in names
assert 'npc_dialogue.j2' in names
def test_render_string_simple(self):
"""Test rendering a simple template string."""
result = self.templates.render_string(
"Hello, {{ name }}!",
name="Player"
)
assert result == "Hello, Player!"
def test_render_string_with_filter(self):
"""Test rendering with custom filter."""
result = self.templates.render_string(
"{{ items | format_inventory }}",
items=[
{"name": "Sword", "quantity": 1},
{"name": "Potion", "quantity": 3}
]
)
assert "Sword" in result
assert "Potion (x3)" in result
def test_render_story_action_template(self):
"""Test rendering story_action template."""
result = self.templates.render(
"story_action.j2",
character={
"name": "Aldric",
"level": 5,
"player_class": "Warrior",
"current_hp": 45,
"max_hp": 50,
"stats": {"strength": 16, "dexterity": 12},
"skills": [],
"effects": []
},
game_state={
"current_location": "The Rusty Anchor",
"location_type": "TAVERN",
"active_quests": [],
"discovered_locations": []
},
action="I look around the tavern for anyone suspicious"
)
assert "Aldric" in result
assert "Warrior" in result
assert "Rusty Anchor" in result
assert "suspicious" in result
def test_render_combat_action_template(self):
"""Test rendering combat_action template."""
result = self.templates.render(
"combat_action.j2",
character={
"name": "Aldric",
"level": 5,
"player_class": "Warrior",
"current_hp": 45,
"max_hp": 50,
"effects": []
},
combat_state={
"round_number": 2,
"current_turn": "Player",
"enemies": [
{
"name": "Goblin",
"current_hp": 8,
"max_hp": 15,
"effects": []
}
]
},
action="swings their sword at the Goblin",
action_result={
"hit": True,
"damage": 7,
"effects_applied": []
},
is_critical=False,
is_finishing_blow=False
)
assert "Aldric" in result
assert "Goblin" in result
assert "sword" in result
def test_render_quest_offering_template(self):
"""Test rendering quest_offering template."""
result = self.templates.render(
"quest_offering.j2",
character={
"name": "Aldric",
"level": 3,
"player_class": "Warrior",
"completed_quests": []
},
game_context={
"current_location": "Village Square",
"location_type": "TOWN",
"active_quests": [],
"world_events": []
},
eligible_quests=[
{
"quest_id": "quest_goblin_cave",
"name": "Clear the Goblin Cave",
"difficulty": "EASY",
"quest_giver": "Village Elder",
"description": "Goblins have been raiding farms",
"narrative_hooks": [
"Farmers complaining about lost livestock"
]
}
],
recent_actions=["Talked to locals"]
)
assert "quest_goblin_cave" in result
assert "Clear the Goblin Cave" in result
assert "Village Elder" in result
def test_render_npc_dialogue_template(self):
"""Test rendering npc_dialogue template."""
result = self.templates.render(
"npc_dialogue.j2",
character={
"name": "Aldric",
"level": 5,
"player_class": "Warrior"
},
npc={
"name": "Grizzled Bartender",
"role": "Tavern Owner",
"personality": "Gruff but kind",
"speaking_style": "Short sentences, common slang"
},
conversation_topic="What's the latest news around here?",
game_state={
"current_location": "The Rusty Anchor",
"time_of_day": "Evening",
"active_quests": []
}
)
assert "Grizzled Bartender" in result
assert "Aldric" in result
assert "news" in result
def test_format_inventory_filter_empty(self):
"""Test format_inventory filter with empty list."""
result = PromptTemplates._format_inventory([])
assert result == "Empty inventory"
def test_format_inventory_filter_single(self):
"""Test format_inventory filter with single item."""
result = PromptTemplates._format_inventory([
{"name": "Sword", "quantity": 1}
])
assert result == "Sword"
def test_format_inventory_filter_multiple(self):
"""Test format_inventory filter with multiple items."""
result = PromptTemplates._format_inventory([
{"name": "Sword", "quantity": 1},
{"name": "Shield", "quantity": 1},
{"name": "Potion", "quantity": 5}
])
assert "Sword" in result
assert "Shield" in result
assert "Potion (x5)" in result
def test_format_inventory_filter_truncation(self):
"""Test format_inventory filter truncates long lists."""
items = [{"name": f"Item{i}", "quantity": 1} for i in range(15)]
result = PromptTemplates._format_inventory(items, max_items=10)
assert "and 5 more items" in result
def test_format_stats_filter(self):
"""Test format_stats filter."""
result = PromptTemplates._format_stats({
"strength": 16,
"dexterity": 14
})
assert "Strength: 16" in result
assert "Dexterity: 14" in result
def test_format_stats_filter_empty(self):
"""Test format_stats filter with empty dict."""
result = PromptTemplates._format_stats({})
assert result == "No stats available"
def test_format_skills_filter(self):
"""Test format_skills filter."""
result = PromptTemplates._format_skills([
{"name": "Sword Mastery", "level": 3},
{"name": "Shield Block", "level": 2}
])
assert "Sword Mastery (Lv.3)" in result
assert "Shield Block (Lv.2)" in result
def test_format_skills_filter_empty(self):
"""Test format_skills filter with empty list."""
result = PromptTemplates._format_skills([])
assert result == "No skills"
def test_format_effects_filter(self):
"""Test format_effects filter."""
result = PromptTemplates._format_effects([
{"name": "Blessed", "remaining_turns": 3},
{"name": "Strength Buff"}
])
assert "Blessed (3 turns)" in result
assert "Strength Buff" in result
def test_format_effects_filter_empty(self):
"""Test format_effects filter with empty list."""
result = PromptTemplates._format_effects([])
assert result == "No active effects"
def test_truncate_text_filter_short(self):
"""Test truncate_text filter with short text."""
result = PromptTemplates._truncate_text("Hello", 100)
assert result == "Hello"
def test_truncate_text_filter_long(self):
"""Test truncate_text filter with long text."""
long_text = "A" * 150
result = PromptTemplates._truncate_text(long_text, 100)
assert len(result) == 100
assert result.endswith("...")
def test_format_gold_filter(self):
"""Test format_gold filter."""
assert PromptTemplates._format_gold(1000) == "1,000 gold"
assert PromptTemplates._format_gold(1000000) == "1,000,000 gold"
assert PromptTemplates._format_gold(50) == "50 gold"
def test_invalid_template_raises_error(self):
"""Test that invalid template raises PromptTemplateError."""
with pytest.raises(PromptTemplateError):
self.templates.render("nonexistent_template.j2")
def test_invalid_template_string_raises_error(self):
"""Test that invalid template string raises PromptTemplateError."""
with pytest.raises(PromptTemplateError):
self.templates.render_string("{{ invalid syntax")
class TestPromptTemplateConvenienceFunctions:
"""Tests for module-level convenience functions."""
def test_get_prompt_templates_singleton(self):
"""Test get_prompt_templates returns singleton."""
templates1 = get_prompt_templates()
templates2 = get_prompt_templates()
assert templates1 is templates2
def test_render_prompt_function(self):
"""Test render_prompt convenience function."""
result = render_prompt(
"story_action.j2",
character={
"name": "Test",
"level": 1,
"player_class": "Warrior",
"current_hp": 10,
"max_hp": 10,
"stats": {},
"skills": [],
"effects": []
},
game_state={
"current_location": "Test Location",
"location_type": "TOWN",
"active_quests": []
},
action="test action"
)
assert "Test" in result
assert "test action" in result

View File

@@ -0,0 +1,342 @@
"""
Tests for RateLimiterService
Tests the tier-based rate limiting functionality including:
- Daily limits per tier
- Usage tracking and incrementing
- Rate limit checks and exceptions
- Reset functionality
"""
import pytest
from unittest.mock import MagicMock, patch
from datetime import datetime, timezone, timedelta
from app.services.rate_limiter_service import (
RateLimiterService,
RateLimitExceeded,
)
from app.ai.model_selector import UserTier
class TestRateLimiterService:
"""Tests for RateLimiterService."""
@pytest.fixture
def mock_redis(self):
"""Create a mock Redis service."""
mock = MagicMock()
mock.get.return_value = None
mock.incr.return_value = 1
mock.expire.return_value = True
mock.delete.return_value = 1
return mock
@pytest.fixture
def rate_limiter(self, mock_redis):
"""Create a RateLimiterService with mock Redis."""
return RateLimiterService(redis_service=mock_redis)
def test_tier_limits(self, rate_limiter):
"""Test that tier limits are correctly defined."""
assert rate_limiter.get_limit_for_tier(UserTier.FREE) == 20
assert rate_limiter.get_limit_for_tier(UserTier.BASIC) == 50
assert rate_limiter.get_limit_for_tier(UserTier.PREMIUM) == 100
assert rate_limiter.get_limit_for_tier(UserTier.ELITE) == 200
def test_get_daily_key(self, rate_limiter):
"""Test Redis key generation."""
from datetime import date
key = rate_limiter._get_daily_key("user_123", date(2025, 1, 15))
assert key == "rate_limit:daily:user_123:2025-01-15"
def test_get_current_usage_no_usage(self, rate_limiter, mock_redis):
"""Test getting current usage when no usage exists."""
mock_redis.get.return_value = None
usage = rate_limiter.get_current_usage("user_123")
assert usage == 0
mock_redis.get.assert_called_once()
def test_get_current_usage_with_usage(self, rate_limiter, mock_redis):
"""Test getting current usage when usage exists."""
mock_redis.get.return_value = "15"
usage = rate_limiter.get_current_usage("user_123")
assert usage == 15
def test_check_rate_limit_under_limit(self, rate_limiter, mock_redis):
"""Test that check passes when under limit."""
mock_redis.get.return_value = "10"
# Should not raise
rate_limiter.check_rate_limit("user_123", UserTier.FREE)
def test_check_rate_limit_at_limit(self, rate_limiter, mock_redis):
"""Test that check fails when at limit."""
mock_redis.get.return_value = "20" # Free tier limit
with pytest.raises(RateLimitExceeded) as exc_info:
rate_limiter.check_rate_limit("user_123", UserTier.FREE)
exc = exc_info.value
assert exc.user_id == "user_123"
assert exc.user_tier == UserTier.FREE
assert exc.limit == 20
assert exc.current_usage == 20
def test_check_rate_limit_over_limit(self, rate_limiter, mock_redis):
"""Test that check fails when over limit."""
mock_redis.get.return_value = "25" # Over free tier limit
with pytest.raises(RateLimitExceeded):
rate_limiter.check_rate_limit("user_123", UserTier.FREE)
def test_check_rate_limit_premium_tier(self, rate_limiter, mock_redis):
"""Test that premium tier has higher limit."""
mock_redis.get.return_value = "50" # Over free limit, under premium
# Should not raise for premium
rate_limiter.check_rate_limit("user_123", UserTier.PREMIUM)
# Should raise for free
with pytest.raises(RateLimitExceeded):
rate_limiter.check_rate_limit("user_123", UserTier.FREE)
def test_increment_usage_first_time(self, rate_limiter, mock_redis):
"""Test incrementing usage for the first time (sets expiration)."""
mock_redis.incr.return_value = 1
new_count = rate_limiter.increment_usage("user_123")
assert new_count == 1
mock_redis.incr.assert_called_once()
mock_redis.expire.assert_called_once() # Should set expiration
def test_increment_usage_subsequent(self, rate_limiter, mock_redis):
"""Test incrementing usage after first time (no expiration set)."""
mock_redis.incr.return_value = 5
new_count = rate_limiter.increment_usage("user_123")
assert new_count == 5
mock_redis.incr.assert_called_once()
mock_redis.expire.assert_not_called() # Should NOT set expiration
def test_get_remaining_turns_full(self, rate_limiter, mock_redis):
"""Test remaining turns when no usage."""
mock_redis.get.return_value = None
remaining = rate_limiter.get_remaining_turns("user_123", UserTier.FREE)
assert remaining == 20
def test_get_remaining_turns_partial(self, rate_limiter, mock_redis):
"""Test remaining turns with partial usage."""
mock_redis.get.return_value = "12"
remaining = rate_limiter.get_remaining_turns("user_123", UserTier.FREE)
assert remaining == 8
def test_get_remaining_turns_exhausted(self, rate_limiter, mock_redis):
"""Test remaining turns when limit reached."""
mock_redis.get.return_value = "20"
remaining = rate_limiter.get_remaining_turns("user_123", UserTier.FREE)
assert remaining == 0
def test_get_remaining_turns_over_limit(self, rate_limiter, mock_redis):
"""Test remaining turns when over limit (should be 0, not negative)."""
mock_redis.get.return_value = "25"
remaining = rate_limiter.get_remaining_turns("user_123", UserTier.FREE)
assert remaining == 0
def test_get_usage_info(self, rate_limiter, mock_redis):
"""Test getting comprehensive usage info."""
mock_redis.get.return_value = "15"
info = rate_limiter.get_usage_info("user_123", UserTier.FREE)
assert info["user_id"] == "user_123"
assert info["user_tier"] == "free"
assert info["current_usage"] == 15
assert info["daily_limit"] == 20
assert info["remaining"] == 5
assert info["is_limited"] is False
assert "reset_time" in info
def test_get_usage_info_limited(self, rate_limiter, mock_redis):
"""Test usage info when limited."""
mock_redis.get.return_value = "20"
info = rate_limiter.get_usage_info("user_123", UserTier.FREE)
assert info["is_limited"] is True
assert info["remaining"] == 0
def test_reset_usage(self, rate_limiter, mock_redis):
"""Test resetting usage counter."""
mock_redis.delete.return_value = 1
result = rate_limiter.reset_usage("user_123")
assert result is True
mock_redis.delete.assert_called_once()
def test_reset_usage_no_key(self, rate_limiter, mock_redis):
"""Test resetting when no usage exists."""
mock_redis.delete.return_value = 0
result = rate_limiter.reset_usage("user_123")
assert result is False
class TestRateLimitExceeded:
"""Tests for RateLimitExceeded exception."""
def test_exception_attributes(self):
"""Test that exception has correct attributes."""
reset_time = datetime(2025, 1, 16, 0, 0, 0, tzinfo=timezone.utc)
exc = RateLimitExceeded(
user_id="user_123",
user_tier=UserTier.FREE,
limit=20,
current_usage=20,
reset_time=reset_time
)
assert exc.user_id == "user_123"
assert exc.user_tier == UserTier.FREE
assert exc.limit == 20
assert exc.current_usage == 20
assert exc.reset_time == reset_time
def test_exception_message(self):
"""Test that exception message is formatted correctly."""
reset_time = datetime(2025, 1, 16, 0, 0, 0, tzinfo=timezone.utc)
exc = RateLimitExceeded(
user_id="user_123",
user_tier=UserTier.FREE,
limit=20,
current_usage=20,
reset_time=reset_time
)
message = str(exc)
assert "user_123" in message
assert "free tier" in message
assert "20/20" in message
class TestRateLimiterIntegration:
"""Integration tests for rate limiter workflow."""
@pytest.fixture
def mock_redis(self):
"""Create a mock Redis that simulates real behavior."""
storage = {}
mock = MagicMock()
def mock_get(key):
return storage.get(key)
def mock_incr(key):
if key not in storage:
storage[key] = 0
storage[key] = int(storage[key]) + 1
return storage[key]
def mock_delete(key):
if key in storage:
del storage[key]
return 1
return 0
mock.get.side_effect = mock_get
mock.incr.side_effect = mock_incr
mock.delete.side_effect = mock_delete
mock.expire.return_value = True
return mock
@pytest.fixture
def rate_limiter(self, mock_redis):
"""Create rate limiter with simulated Redis."""
return RateLimiterService(redis_service=mock_redis)
def test_full_workflow(self, rate_limiter):
"""Test complete rate limiting workflow."""
user_id = "user_123"
tier = UserTier.FREE # 20 turns/day
# Initial state - should pass
rate_limiter.check_rate_limit(user_id, tier)
assert rate_limiter.get_remaining_turns(user_id, tier) == 20
# Use some turns
for i in range(15):
rate_limiter.check_rate_limit(user_id, tier)
rate_limiter.increment_usage(user_id)
# Check remaining
assert rate_limiter.get_remaining_turns(user_id, tier) == 5
# Use remaining turns
for i in range(5):
rate_limiter.check_rate_limit(user_id, tier)
rate_limiter.increment_usage(user_id)
# Now should be limited
assert rate_limiter.get_remaining_turns(user_id, tier) == 0
with pytest.raises(RateLimitExceeded):
rate_limiter.check_rate_limit(user_id, tier)
def test_different_tiers_same_usage(self, rate_limiter):
"""Test that same usage affects different tiers differently."""
user_id = "user_123"
# Use 30 turns
for _ in range(30):
rate_limiter.increment_usage(user_id)
# Free tier (20) should be limited
with pytest.raises(RateLimitExceeded):
rate_limiter.check_rate_limit(user_id, UserTier.FREE)
# Basic tier (50) should not be limited
rate_limiter.check_rate_limit(user_id, UserTier.BASIC)
# Premium tier (100) should not be limited
rate_limiter.check_rate_limit(user_id, UserTier.PREMIUM)
def test_reset_clears_usage(self, rate_limiter):
"""Test that reset allows new usage."""
user_id = "user_123"
tier = UserTier.FREE
# Use all turns
for _ in range(20):
rate_limiter.increment_usage(user_id)
# Should be limited
with pytest.raises(RateLimitExceeded):
rate_limiter.check_rate_limit(user_id, tier)
# Reset usage
rate_limiter.reset_usage(user_id)
# Should be able to use again
rate_limiter.check_rate_limit(user_id, tier)
assert rate_limiter.get_remaining_turns(user_id, tier) == 20

View File

@@ -0,0 +1,573 @@
"""
Unit tests for Redis Service.
These tests use mocking to test the RedisService without requiring
a real Redis connection.
"""
import pytest
from unittest.mock import Mock, patch, MagicMock
import json
from redis.exceptions import RedisError, ConnectionError as RedisConnectionError
from app.services.redis_service import (
RedisService,
RedisServiceError,
RedisConnectionFailed
)
class TestRedisServiceInit:
"""Test RedisService initialization."""
@patch('app.services.redis_service.redis.ConnectionPool.from_url')
@patch('app.services.redis_service.redis.Redis')
def test_init_success(self, mock_redis_class, mock_pool_from_url):
"""Test successful initialization."""
# Setup mocks
mock_pool = MagicMock()
mock_pool_from_url.return_value = mock_pool
mock_client = MagicMock()
mock_client.ping.return_value = True
mock_redis_class.return_value = mock_client
# Create service
service = RedisService(redis_url='redis://localhost:6379/0')
# Verify
mock_pool_from_url.assert_called_once()
mock_redis_class.assert_called_once_with(connection_pool=mock_pool)
mock_client.ping.assert_called_once()
assert service.client == mock_client
@patch('app.services.redis_service.redis.ConnectionPool.from_url')
def test_init_connection_failed(self, mock_pool_from_url):
"""Test initialization fails when Redis is unavailable."""
mock_pool_from_url.side_effect = RedisConnectionError("Connection refused")
with pytest.raises(RedisConnectionFailed) as exc_info:
RedisService(redis_url='redis://localhost:6379/0')
assert "Could not connect to Redis" in str(exc_info.value)
def test_init_missing_url(self):
"""Test initialization fails with missing URL."""
# Clear environment variable and pass empty string
with patch.dict('os.environ', {'REDIS_URL': ''}, clear=True):
with pytest.raises(ValueError) as exc_info:
RedisService(redis_url='')
assert "Redis URL not configured" in str(exc_info.value)
class TestRedisServiceOperations:
"""Test Redis operations (get, set, delete, exists)."""
@pytest.fixture
def redis_service(self):
"""Create a RedisService with mocked client."""
with patch('app.services.redis_service.redis.ConnectionPool.from_url') as mock_pool_from_url:
with patch('app.services.redis_service.redis.Redis') as mock_redis_class:
mock_pool = MagicMock()
mock_pool_from_url.return_value = mock_pool
mock_client = MagicMock()
mock_client.ping.return_value = True
mock_redis_class.return_value = mock_client
service = RedisService(redis_url='redis://localhost:6379/0')
yield service
def test_get_existing_key(self, redis_service):
"""Test getting an existing key."""
redis_service.client.get.return_value = "test_value"
result = redis_service.get("test_key")
redis_service.client.get.assert_called_once_with("test_key")
assert result == "test_value"
def test_get_nonexistent_key(self, redis_service):
"""Test getting a non-existent key returns None."""
redis_service.client.get.return_value = None
result = redis_service.get("nonexistent_key")
assert result is None
def test_get_error(self, redis_service):
"""Test get raises RedisServiceError on failure."""
redis_service.client.get.side_effect = RedisError("Connection lost")
with pytest.raises(RedisServiceError) as exc_info:
redis_service.get("test_key")
assert "Failed to get key" in str(exc_info.value)
def test_set_basic(self, redis_service):
"""Test basic set operation."""
redis_service.client.set.return_value = True
result = redis_service.set("test_key", "test_value")
redis_service.client.set.assert_called_once_with(
"test_key", "test_value", ex=None, nx=False, xx=False
)
assert result is True
def test_set_with_ttl(self, redis_service):
"""Test set with TTL."""
redis_service.client.set.return_value = True
result = redis_service.set("test_key", "test_value", ttl=3600)
redis_service.client.set.assert_called_once_with(
"test_key", "test_value", ex=3600, nx=False, xx=False
)
assert result is True
def test_set_nx_success(self, redis_service):
"""Test set with NX (only if not exists) - success."""
redis_service.client.set.return_value = True
result = redis_service.set("test_key", "test_value", nx=True)
redis_service.client.set.assert_called_once_with(
"test_key", "test_value", ex=None, nx=True, xx=False
)
assert result is True
def test_set_nx_failure(self, redis_service):
"""Test set with NX fails when key exists."""
redis_service.client.set.return_value = None # NX returns None if key exists
result = redis_service.set("test_key", "test_value", nx=True)
assert result is False
def test_set_error(self, redis_service):
"""Test set raises RedisServiceError on failure."""
redis_service.client.set.side_effect = RedisError("Connection lost")
with pytest.raises(RedisServiceError) as exc_info:
redis_service.set("test_key", "test_value")
assert "Failed to set key" in str(exc_info.value)
def test_delete_single_key(self, redis_service):
"""Test deleting a single key."""
redis_service.client.delete.return_value = 1
result = redis_service.delete("test_key")
redis_service.client.delete.assert_called_once_with("test_key")
assert result == 1
def test_delete_multiple_keys(self, redis_service):
"""Test deleting multiple keys."""
redis_service.client.delete.return_value = 3
result = redis_service.delete("key1", "key2", "key3")
redis_service.client.delete.assert_called_once_with("key1", "key2", "key3")
assert result == 3
def test_delete_no_keys(self, redis_service):
"""Test delete with no keys returns 0."""
result = redis_service.delete()
redis_service.client.delete.assert_not_called()
assert result == 0
def test_delete_error(self, redis_service):
"""Test delete raises RedisServiceError on failure."""
redis_service.client.delete.side_effect = RedisError("Connection lost")
with pytest.raises(RedisServiceError) as exc_info:
redis_service.delete("test_key")
assert "Failed to delete keys" in str(exc_info.value)
def test_exists_single_key(self, redis_service):
"""Test checking existence of a single key."""
redis_service.client.exists.return_value = 1
result = redis_service.exists("test_key")
redis_service.client.exists.assert_called_once_with("test_key")
assert result == 1
def test_exists_multiple_keys(self, redis_service):
"""Test checking existence of multiple keys."""
redis_service.client.exists.return_value = 2
result = redis_service.exists("key1", "key2", "key3")
redis_service.client.exists.assert_called_once_with("key1", "key2", "key3")
assert result == 2
def test_exists_no_keys(self, redis_service):
"""Test exists with no keys returns 0."""
result = redis_service.exists()
redis_service.client.exists.assert_not_called()
assert result == 0
def test_exists_error(self, redis_service):
"""Test exists raises RedisServiceError on failure."""
redis_service.client.exists.side_effect = RedisError("Connection lost")
with pytest.raises(RedisServiceError) as exc_info:
redis_service.exists("test_key")
assert "Failed to check existence" in str(exc_info.value)
class TestRedisServiceJSON:
"""Test JSON serialization methods."""
@pytest.fixture
def redis_service(self):
"""Create a RedisService with mocked client."""
with patch('app.services.redis_service.redis.ConnectionPool.from_url') as mock_pool_from_url:
with patch('app.services.redis_service.redis.Redis') as mock_redis_class:
mock_pool = MagicMock()
mock_pool_from_url.return_value = mock_pool
mock_client = MagicMock()
mock_client.ping.return_value = True
mock_redis_class.return_value = mock_client
service = RedisService(redis_url='redis://localhost:6379/0')
yield service
def test_get_json_success(self, redis_service):
"""Test getting and deserializing JSON."""
test_data = {"name": "test", "value": 123, "nested": {"key": "value"}}
redis_service.client.get.return_value = json.dumps(test_data)
result = redis_service.get_json("test_key")
assert result == test_data
def test_get_json_none(self, redis_service):
"""Test get_json returns None for non-existent key."""
redis_service.client.get.return_value = None
result = redis_service.get_json("nonexistent_key")
assert result is None
def test_get_json_invalid(self, redis_service):
"""Test get_json raises error for invalid JSON."""
redis_service.client.get.return_value = "not valid json {"
with pytest.raises(RedisServiceError) as exc_info:
redis_service.get_json("test_key")
assert "Failed to decode JSON" in str(exc_info.value)
def test_set_json_success(self, redis_service):
"""Test serializing and setting JSON."""
redis_service.client.set.return_value = True
test_data = {"name": "test", "value": 123}
result = redis_service.set_json("test_key", test_data, ttl=3600)
# Verify the value was serialized
call_args = redis_service.client.set.call_args
stored_value = call_args[0][1]
assert json.loads(stored_value) == test_data
assert result is True
def test_set_json_non_serializable(self, redis_service):
"""Test set_json raises error for non-serializable data."""
non_serializable = {"func": lambda x: x}
with pytest.raises(RedisServiceError) as exc_info:
redis_service.set_json("test_key", non_serializable)
assert "Failed to serialize value" in str(exc_info.value)
class TestRedisServiceTTL:
"""Test TTL-related operations."""
@pytest.fixture
def redis_service(self):
"""Create a RedisService with mocked client."""
with patch('app.services.redis_service.redis.ConnectionPool.from_url') as mock_pool_from_url:
with patch('app.services.redis_service.redis.Redis') as mock_redis_class:
mock_pool = MagicMock()
mock_pool_from_url.return_value = mock_pool
mock_client = MagicMock()
mock_client.ping.return_value = True
mock_redis_class.return_value = mock_client
service = RedisService(redis_url='redis://localhost:6379/0')
yield service
def test_expire_success(self, redis_service):
"""Test setting expiration on existing key."""
redis_service.client.expire.return_value = True
result = redis_service.expire("test_key", 3600)
redis_service.client.expire.assert_called_once_with("test_key", 3600)
assert result is True
def test_expire_nonexistent_key(self, redis_service):
"""Test expire returns False for non-existent key."""
redis_service.client.expire.return_value = False
result = redis_service.expire("nonexistent_key", 3600)
assert result is False
def test_ttl_existing_key(self, redis_service):
"""Test getting TTL of existing key."""
redis_service.client.ttl.return_value = 3500
result = redis_service.ttl("test_key")
redis_service.client.ttl.assert_called_once_with("test_key")
assert result == 3500
def test_ttl_no_expiry(self, redis_service):
"""Test TTL returns -1 for key without expiry."""
redis_service.client.ttl.return_value = -1
result = redis_service.ttl("test_key")
assert result == -1
def test_ttl_nonexistent_key(self, redis_service):
"""Test TTL returns -2 for non-existent key."""
redis_service.client.ttl.return_value = -2
result = redis_service.ttl("test_key")
assert result == -2
class TestRedisServiceIncrement:
"""Test increment/decrement operations."""
@pytest.fixture
def redis_service(self):
"""Create a RedisService with mocked client."""
with patch('app.services.redis_service.redis.ConnectionPool.from_url') as mock_pool_from_url:
with patch('app.services.redis_service.redis.Redis') as mock_redis_class:
mock_pool = MagicMock()
mock_pool_from_url.return_value = mock_pool
mock_client = MagicMock()
mock_client.ping.return_value = True
mock_redis_class.return_value = mock_client
service = RedisService(redis_url='redis://localhost:6379/0')
yield service
def test_incr_default(self, redis_service):
"""Test incrementing by default amount (1)."""
redis_service.client.incrby.return_value = 5
result = redis_service.incr("counter")
redis_service.client.incrby.assert_called_once_with("counter", 1)
assert result == 5
def test_incr_custom_amount(self, redis_service):
"""Test incrementing by custom amount."""
redis_service.client.incrby.return_value = 15
result = redis_service.incr("counter", 10)
redis_service.client.incrby.assert_called_once_with("counter", 10)
assert result == 15
def test_decr_default(self, redis_service):
"""Test decrementing by default amount (1)."""
redis_service.client.decrby.return_value = 4
result = redis_service.decr("counter")
redis_service.client.decrby.assert_called_once_with("counter", 1)
assert result == 4
def test_decr_custom_amount(self, redis_service):
"""Test decrementing by custom amount."""
redis_service.client.decrby.return_value = 0
result = redis_service.decr("counter", 5)
redis_service.client.decrby.assert_called_once_with("counter", 5)
assert result == 0
class TestRedisServiceHealth:
"""Test health check and info operations."""
@pytest.fixture
def redis_service(self):
"""Create a RedisService with mocked client."""
with patch('app.services.redis_service.redis.ConnectionPool.from_url') as mock_pool_from_url:
with patch('app.services.redis_service.redis.Redis') as mock_redis_class:
mock_pool = MagicMock()
mock_pool_from_url.return_value = mock_pool
mock_client = MagicMock()
mock_client.ping.return_value = True
mock_redis_class.return_value = mock_client
service = RedisService(redis_url='redis://localhost:6379/0')
yield service
def test_health_check_success(self, redis_service):
"""Test health check when Redis is healthy."""
redis_service.client.ping.return_value = True
result = redis_service.health_check()
assert result is True
def test_health_check_failure(self, redis_service):
"""Test health check when Redis is unhealthy."""
redis_service.client.ping.side_effect = RedisError("Connection lost")
result = redis_service.health_check()
assert result is False
def test_info_success(self, redis_service):
"""Test getting Redis info."""
mock_info = {
'redis_version': '7.0.0',
'used_memory': 1000000,
'connected_clients': 5
}
redis_service.client.info.return_value = mock_info
result = redis_service.info()
assert result == mock_info
def test_info_error(self, redis_service):
"""Test info raises error on failure."""
redis_service.client.info.side_effect = RedisError("Connection lost")
with pytest.raises(RedisServiceError) as exc_info:
redis_service.info()
assert "Failed to get Redis info" in str(exc_info.value)
class TestRedisServiceUtility:
"""Test utility methods."""
@pytest.fixture
def redis_service(self):
"""Create a RedisService with mocked client."""
with patch('app.services.redis_service.redis.ConnectionPool.from_url') as mock_pool_from_url:
with patch('app.services.redis_service.redis.Redis') as mock_redis_class:
mock_pool = MagicMock()
mock_pool_from_url.return_value = mock_pool
mock_client = MagicMock()
mock_client.ping.return_value = True
mock_redis_class.return_value = mock_client
service = RedisService(redis_url='redis://localhost:6379/0')
yield service
def test_flush_db(self, redis_service):
"""Test flushing database."""
redis_service.client.flushdb.return_value = True
result = redis_service.flush_db()
redis_service.client.flushdb.assert_called_once()
assert result is True
def test_close(self, redis_service):
"""Test closing connection pool."""
redis_service.close()
redis_service.pool.disconnect.assert_called_once()
def test_context_manager(self, redis_service):
"""Test using service as context manager."""
with redis_service as service:
assert service is not None
redis_service.pool.disconnect.assert_called_once()
def test_sanitize_url_with_password(self, redis_service):
"""Test URL sanitization masks password."""
url = "redis://user:secretpassword@localhost:6379/0"
result = redis_service._sanitize_url(url)
assert "secretpassword" not in result
assert "***" in result
assert "localhost:6379/0" in result
def test_sanitize_url_without_password(self, redis_service):
"""Test URL sanitization with no password."""
url = "redis://localhost:6379/0"
result = redis_service._sanitize_url(url)
assert result == url
class TestRedisServiceIntegration:
"""Integration-style tests that verify the flow of operations."""
@pytest.fixture
def redis_service(self):
"""Create a RedisService with mocked client."""
with patch('app.services.redis_service.redis.ConnectionPool.from_url') as mock_pool_from_url:
with patch('app.services.redis_service.redis.Redis') as mock_redis_class:
mock_pool = MagicMock()
mock_pool_from_url.return_value = mock_pool
mock_client = MagicMock()
mock_client.ping.return_value = True
mock_redis_class.return_value = mock_client
service = RedisService(redis_url='redis://localhost:6379/0')
yield service
def test_set_then_get(self, redis_service):
"""Test setting and then getting a value."""
# Set
redis_service.client.set.return_value = True
redis_service.set("test_key", "test_value")
# Get
redis_service.client.get.return_value = "test_value"
result = redis_service.get("test_key")
assert result == "test_value"
def test_json_roundtrip(self, redis_service):
"""Test JSON serialization roundtrip."""
test_data = {
"user_id": "user_123",
"tokens_used": 450,
"model": "claude-3-5-haiku"
}
# Set JSON
redis_service.client.set.return_value = True
redis_service.set_json("job_result", test_data, ttl=3600)
# Get JSON
redis_service.client.get.return_value = json.dumps(test_data)
result = redis_service.get_json("job_result")
assert result == test_data

View File

@@ -0,0 +1,462 @@
"""
Tests for Replicate API client.
Tests cover initialization, prompt formatting, generation,
retry logic, and error handling.
"""
import pytest
from unittest.mock import patch, MagicMock
from app.ai.replicate_client import (
ReplicateClient,
ReplicateResponse,
ReplicateClientError,
ReplicateAPIError,
ReplicateRateLimitError,
ReplicateTimeoutError,
ModelType,
)
class TestReplicateClientInit:
"""Tests for ReplicateClient initialization."""
@patch('app.ai.replicate_client.get_config')
def test_init_with_token(self, mock_config):
"""Test initialization with explicit API token."""
mock_config.return_value = MagicMock(
replicate_api_token=None,
REPLICATE_MODEL=None
)
client = ReplicateClient(api_token="test_token_123")
assert client.api_token == "test_token_123"
assert client.model == ReplicateClient.DEFAULT_MODEL.value
@patch('app.ai.replicate_client.get_config')
def test_init_from_config(self, mock_config):
"""Test initialization from config."""
mock_config.return_value = MagicMock(
replicate_api_token="config_token",
REPLICATE_MODEL="custom/model"
)
client = ReplicateClient()
assert client.api_token == "config_token"
assert client.model == "custom/model"
@patch('app.ai.replicate_client.get_config')
def test_init_missing_token(self, mock_config):
"""Test initialization fails without API token."""
mock_config.return_value = MagicMock(
replicate_api_token=None,
REPLICATE_MODEL=None
)
with pytest.raises(ReplicateClientError) as exc_info:
ReplicateClient()
assert "API token not configured" in str(exc_info.value)
@patch('app.ai.replicate_client.get_config')
def test_init_custom_model(self, mock_config):
"""Test initialization with custom model."""
mock_config.return_value = MagicMock(
replicate_api_token="token",
REPLICATE_MODEL=None
)
client = ReplicateClient(model="meta/llama-2-70b")
assert client.model == "meta/llama-2-70b"
class TestPromptFormatting:
"""Tests for Llama-3 prompt formatting."""
@patch('app.ai.replicate_client.get_config')
def test_format_prompt_user_only(self, mock_config):
"""Test formatting with only user prompt."""
mock_config.return_value = MagicMock(
replicate_api_token="token",
REPLICATE_MODEL=None
)
client = ReplicateClient()
formatted = client._format_llama_prompt("Hello world")
assert "<|begin_of_text|>" in formatted
assert "<|start_header_id|>user<|end_header_id|>" in formatted
assert "Hello world" in formatted
assert "<|start_header_id|>assistant<|end_header_id|>" in formatted
# No system header without system prompt
assert "system<|end_header_id|>" not in formatted
@patch('app.ai.replicate_client.get_config')
def test_format_prompt_with_system(self, mock_config):
"""Test formatting with system and user prompts."""
mock_config.return_value = MagicMock(
replicate_api_token="token",
REPLICATE_MODEL=None
)
client = ReplicateClient()
formatted = client._format_llama_prompt(
"What is 2+2?",
system_prompt="You are a helpful assistant."
)
assert "<|start_header_id|>system<|end_header_id|>" in formatted
assert "You are a helpful assistant." in formatted
assert "<|start_header_id|>user<|end_header_id|>" in formatted
assert "What is 2+2?" in formatted
class TestGenerate:
"""Tests for text generation."""
@patch('app.ai.replicate_client.get_config')
@patch('app.ai.replicate_client.replicate')
def test_generate_success(self, mock_replicate, mock_config):
"""Test successful text generation."""
mock_config.return_value = MagicMock(
replicate_api_token="token",
REPLICATE_MODEL=None
)
# Mock streaming response
mock_replicate.run.return_value = iter(["Hello ", "world ", "!"])
client = ReplicateClient()
response = client.generate("Say hello")
assert isinstance(response, ReplicateResponse)
assert response.text == "Hello world !"
assert response.tokens_used > 0
assert response.model == ReplicateClient.DEFAULT_MODEL.value
assert response.generation_time > 0
@patch('app.ai.replicate_client.get_config')
@patch('app.ai.replicate_client.replicate')
def test_generate_with_parameters(self, mock_replicate, mock_config):
"""Test generation with custom parameters."""
mock_config.return_value = MagicMock(
replicate_api_token="token",
REPLICATE_MODEL=None
)
mock_replicate.run.return_value = iter(["Response"])
client = ReplicateClient()
response = client.generate(
prompt="Test",
system_prompt="Be concise",
max_tokens=100,
temperature=0.5,
top_p=0.8,
timeout=60
)
# Verify parameters were passed
call_args = mock_replicate.run.call_args
assert call_args[1]['input']['max_tokens'] == 100
assert call_args[1]['input']['temperature'] == 0.5
assert call_args[1]['input']['top_p'] == 0.8
@patch('app.ai.replicate_client.get_config')
@patch('app.ai.replicate_client.replicate')
def test_generate_string_response(self, mock_replicate, mock_config):
"""Test handling string response (non-streaming)."""
mock_config.return_value = MagicMock(
replicate_api_token="token",
REPLICATE_MODEL=None
)
mock_replicate.run.return_value = "Direct string response"
client = ReplicateClient()
response = client.generate("Test")
assert response.text == "Direct string response"
class TestRetryLogic:
"""Tests for retry and error handling."""
@patch('app.ai.replicate_client.get_config')
@patch('app.ai.replicate_client.replicate')
@patch('app.ai.replicate_client.time.sleep')
def test_retry_on_rate_limit(self, mock_sleep, mock_replicate, mock_config):
"""Test retry logic on rate limit errors."""
mock_config.return_value = MagicMock(
replicate_api_token="token",
REPLICATE_MODEL=None
)
# First call raises rate limit, second succeeds
mock_replicate.exceptions.ReplicateError = Exception
mock_replicate.run.side_effect = [
Exception("Rate limit exceeded 429"),
iter(["Success"])
]
client = ReplicateClient()
response = client.generate("Test")
assert response.text == "Success"
assert mock_sleep.called # Verify backoff happened
@patch('app.ai.replicate_client.get_config')
@patch('app.ai.replicate_client.replicate')
@patch('app.ai.replicate_client.time.sleep')
def test_max_retries_exceeded(self, mock_sleep, mock_replicate, mock_config):
"""Test that max retries raises error."""
mock_config.return_value = MagicMock(
replicate_api_token="token",
REPLICATE_MODEL=None
)
# All calls fail with rate limit
mock_replicate.exceptions.ReplicateError = Exception
mock_replicate.run.side_effect = Exception("Rate limit exceeded 429")
client = ReplicateClient()
with pytest.raises(ReplicateRateLimitError):
client.generate("Test")
# Should have retried MAX_RETRIES times
assert mock_replicate.run.call_count == ReplicateClient.MAX_RETRIES
@patch('app.ai.replicate_client.get_config')
@patch('app.ai.replicate_client.replicate')
def test_timeout_error(self, mock_replicate, mock_config):
"""Test timeout error handling."""
mock_config.return_value = MagicMock(
replicate_api_token="token",
REPLICATE_MODEL=None
)
mock_replicate.exceptions.ReplicateError = Exception
mock_replicate.run.side_effect = Exception("Request timeout")
client = ReplicateClient()
with pytest.raises(ReplicateTimeoutError):
client.generate("Test")
@patch('app.ai.replicate_client.get_config')
@patch('app.ai.replicate_client.replicate')
def test_api_error(self, mock_replicate, mock_config):
"""Test generic API error handling."""
mock_config.return_value = MagicMock(
replicate_api_token="token",
REPLICATE_MODEL=None
)
mock_replicate.exceptions.ReplicateError = Exception
mock_replicate.run.side_effect = Exception("Invalid model")
client = ReplicateClient()
with pytest.raises(ReplicateAPIError):
client.generate("Test")
class TestValidation:
"""Tests for API key validation."""
@patch('app.ai.replicate_client.get_config')
@patch('app.ai.replicate_client.replicate')
def test_validate_api_key_success(self, mock_replicate, mock_config):
"""Test successful API key validation."""
mock_config.return_value = MagicMock(
replicate_api_token="valid_token",
REPLICATE_MODEL=None
)
mock_replicate.models.get.return_value = MagicMock()
client = ReplicateClient()
assert client.validate_api_key() is True
@patch('app.ai.replicate_client.get_config')
@patch('app.ai.replicate_client.replicate')
def test_validate_api_key_failure(self, mock_replicate, mock_config):
"""Test failed API key validation."""
mock_config.return_value = MagicMock(
replicate_api_token="invalid_token",
REPLICATE_MODEL=None
)
mock_replicate.models.get.side_effect = Exception("Invalid API token")
client = ReplicateClient()
assert client.validate_api_key() is False
class TestResponseDataclass:
"""Tests for ReplicateResponse dataclass."""
def test_response_creation(self):
"""Test creating ReplicateResponse."""
response = ReplicateResponse(
text="Hello world",
tokens_used=50,
model="meta/llama-3-8b",
generation_time=1.5
)
assert response.text == "Hello world"
assert response.tokens_used == 50
assert response.model == "meta/llama-3-8b"
assert response.generation_time == 1.5
def test_response_immutability(self):
"""Test that response fields are accessible."""
response = ReplicateResponse(
text="Test",
tokens_used=10,
model="test",
generation_time=0.5
)
# Dataclass should allow attribute access
assert hasattr(response, 'text')
assert hasattr(response, 'tokens_used')
class TestModelType:
"""Tests for ModelType enum and multi-model support."""
def test_model_type_values(self):
"""Test ModelType enum has expected values."""
assert ModelType.LLAMA_3_8B.value == "meta/meta-llama-3-8b-instruct"
assert ModelType.CLAUDE_HAIKU.value == "anthropic/claude-3.5-haiku"
assert ModelType.CLAUDE_SONNET.value == "anthropic/claude-3.5-sonnet"
assert ModelType.CLAUDE_SONNET_4.value == "anthropic/claude-sonnet-4"
@patch('app.ai.replicate_client.get_config')
def test_init_with_model_type_enum(self, mock_config):
"""Test initialization with ModelType enum."""
mock_config.return_value = MagicMock(
replicate_api_token="token",
REPLICATE_MODEL=None
)
client = ReplicateClient(model=ModelType.CLAUDE_HAIKU)
assert client.model == "anthropic/claude-3.5-haiku"
assert client.model_type == ModelType.CLAUDE_HAIKU
@patch('app.ai.replicate_client.get_config')
def test_is_claude_model(self, mock_config):
"""Test _is_claude_model helper."""
mock_config.return_value = MagicMock(
replicate_api_token="token",
REPLICATE_MODEL=None
)
# Llama model
client = ReplicateClient(model=ModelType.LLAMA_3_8B)
assert client._is_claude_model() is False
# Claude models
client = ReplicateClient(model=ModelType.CLAUDE_HAIKU)
assert client._is_claude_model() is True
client = ReplicateClient(model=ModelType.CLAUDE_SONNET)
assert client._is_claude_model() is True
class TestClaudeModels:
"""Tests for Claude model generation via Replicate."""
@patch('app.ai.replicate_client.get_config')
@patch('app.ai.replicate_client.replicate')
def test_generate_with_claude_haiku(self, mock_replicate, mock_config):
"""Test generation with Claude Haiku model."""
mock_config.return_value = MagicMock(
replicate_api_token="token",
REPLICATE_MODEL=None
)
mock_replicate.run.return_value = iter(["Claude ", "response"])
client = ReplicateClient(model=ModelType.CLAUDE_HAIKU)
response = client.generate("Test prompt")
assert response.text == "Claude response"
assert response.model == "anthropic/claude-3.5-haiku"
# Verify Claude-style params (not Llama formatted prompt)
call_args = mock_replicate.run.call_args
assert "prompt" in call_args[1]['input']
# Claude params don't include Llama special tokens
assert "<|begin_of_text|>" not in call_args[1]['input']['prompt']
@patch('app.ai.replicate_client.get_config')
@patch('app.ai.replicate_client.replicate')
def test_generate_with_claude_system_prompt(self, mock_replicate, mock_config):
"""Test Claude generation includes system_prompt parameter."""
mock_config.return_value = MagicMock(
replicate_api_token="token",
REPLICATE_MODEL=None
)
mock_replicate.run.return_value = iter(["Response"])
client = ReplicateClient(model=ModelType.CLAUDE_SONNET)
client.generate(
prompt="User message",
system_prompt="You are a DM"
)
call_args = mock_replicate.run.call_args
assert call_args[1]['input']['system_prompt'] == "You are a DM"
assert call_args[1]['input']['prompt'] == "User message"
@patch('app.ai.replicate_client.get_config')
@patch('app.ai.replicate_client.replicate')
def test_model_specific_defaults(self, mock_replicate, mock_config):
"""Test that model-specific defaults are applied."""
mock_config.return_value = MagicMock(
replicate_api_token="token",
REPLICATE_MODEL=None
)
mock_replicate.run.return_value = iter(["Response"])
# Claude Sonnet should use higher max_tokens by default
client = ReplicateClient(model=ModelType.CLAUDE_SONNET)
client.generate("Test")
call_args = mock_replicate.run.call_args
# Sonnet default is 1024 tokens
assert call_args[1]['input']['max_tokens'] == 1024
assert call_args[1]['input']['temperature'] == 0.9
@patch('app.ai.replicate_client.get_config')
@patch('app.ai.replicate_client.replicate')
def test_model_override_in_generate(self, mock_replicate, mock_config):
"""Test overriding model in generate() call."""
mock_config.return_value = MagicMock(
replicate_api_token="token",
REPLICATE_MODEL=None
)
mock_replicate.run.return_value = iter(["Response"])
# Init with Llama, but call with Claude
client = ReplicateClient(model=ModelType.LLAMA_3_8B)
response = client.generate("Test", model=ModelType.CLAUDE_HAIKU)
# Response should reflect the overridden model
assert response.model == "anthropic/claude-3.5-haiku"
# Verify correct model was called
call_args = mock_replicate.run.call_args
assert call_args[0][0] == "anthropic/claude-3.5-haiku"

View File

@@ -0,0 +1,390 @@
"""
Tests for the GameSession model with solo play support.
Tests cover:
- Solo session creation and serialization
- Multiplayer session creation and serialization
- ConversationEntry with timestamps
- GameState with location_type
- Session type detection
"""
import pytest
from datetime import datetime
from app.models.session import (
GameSession,
GameState,
ConversationEntry,
SessionConfig,
)
from app.models.enums import (
SessionStatus,
SessionType,
LocationType,
)
class TestGameState:
"""Tests for GameState dataclass."""
def test_default_values(self):
"""Test GameState has correct defaults."""
state = GameState()
assert state.current_location == "Crossroads Village"
assert state.location_type == LocationType.TOWN
assert state.discovered_locations == []
assert state.active_quests == []
assert state.world_events == []
def test_to_dict_serializes_location_type(self):
"""Test location_type is serialized as string."""
state = GameState(
current_location="The Rusty Anchor",
location_type=LocationType.TAVERN,
)
data = state.to_dict()
assert data["current_location"] == "The Rusty Anchor"
assert data["location_type"] == "tavern"
def test_from_dict_deserializes_location_type(self):
"""Test location_type is deserialized from string."""
data = {
"current_location": "Dark Forest",
"location_type": "wilderness",
"discovered_locations": ["Town A"],
"active_quests": ["quest_1"],
"world_events": [],
}
state = GameState.from_dict(data)
assert state.current_location == "Dark Forest"
assert state.location_type == LocationType.WILDERNESS
assert state.discovered_locations == ["Town A"]
def test_roundtrip_serialization(self):
"""Test GameState serializes and deserializes correctly."""
state = GameState(
current_location="Ancient Library",
location_type=LocationType.LIBRARY,
discovered_locations=["Town", "Forest"],
active_quests=["quest_1", "quest_2"],
world_events=[{"type": "festival"}],
)
data = state.to_dict()
restored = GameState.from_dict(data)
assert restored.current_location == state.current_location
assert restored.location_type == state.location_type
assert restored.discovered_locations == state.discovered_locations
assert restored.active_quests == state.active_quests
class TestConversationEntry:
"""Tests for ConversationEntry dataclass."""
def test_auto_timestamp(self):
"""Test timestamp is auto-generated."""
entry = ConversationEntry(
turn=1,
character_id="char_123",
character_name="Hero",
action="I explore",
dm_response="You find a chest",
)
assert entry.timestamp
assert entry.timestamp.endswith("Z")
def test_provided_timestamp_preserved(self):
"""Test provided timestamp is not overwritten."""
ts = "2025-11-21T10:30:00Z"
entry = ConversationEntry(
turn=1,
character_id="char_123",
character_name="Hero",
action="I explore",
dm_response="You find a chest",
timestamp=ts,
)
assert entry.timestamp == ts
def test_to_dict_with_quest_offered(self):
"""Test serialization includes quest_offered when present."""
entry = ConversationEntry(
turn=5,
character_id="char_123",
character_name="Hero",
action="Talk to elder",
dm_response="The elder offers you a quest",
quest_offered={
"quest_id": "quest_goblin_cave",
"quest_name": "Clear the Goblin Cave",
},
)
data = entry.to_dict()
assert "quest_offered" in data
assert data["quest_offered"]["quest_id"] == "quest_goblin_cave"
def test_to_dict_without_quest_offered(self):
"""Test serialization omits quest_offered when None."""
entry = ConversationEntry(
turn=1,
character_id="char_123",
character_name="Hero",
action="I explore",
dm_response="You find nothing",
)
data = entry.to_dict()
assert "quest_offered" not in data
def test_from_dict_roundtrip(self):
"""Test ConversationEntry roundtrip serialization."""
entry = ConversationEntry(
turn=3,
character_id="char_456",
character_name="Wizard",
action="Cast fireball",
dm_response="The spell illuminates the cave",
combat_log=[{"action": "attack", "damage": 15}],
quest_offered={"quest_id": "q1"},
)
data = entry.to_dict()
restored = ConversationEntry.from_dict(data)
assert restored.turn == entry.turn
assert restored.character_id == entry.character_id
assert restored.action == entry.action
assert restored.dm_response == entry.dm_response
assert restored.timestamp == entry.timestamp
assert restored.combat_log == entry.combat_log
assert restored.quest_offered == entry.quest_offered
class TestGameSessionSolo:
"""Tests for solo GameSession functionality."""
def test_create_solo_session(self):
"""Test creating a solo session."""
session = GameSession(
session_id="sess_123",
session_type=SessionType.SOLO,
solo_character_id="char_456",
user_id="user_789",
)
assert session.session_type == SessionType.SOLO
assert session.solo_character_id == "char_456"
assert session.user_id == "user_789"
assert session.is_solo() is True
def test_is_solo_method(self):
"""Test is_solo returns correct values."""
solo = GameSession(
session_id="s1",
session_type=SessionType.SOLO,
solo_character_id="c1",
)
multi = GameSession(
session_id="s2",
session_type=SessionType.MULTIPLAYER,
party_member_ids=["c1", "c2"],
)
assert solo.is_solo() is True
assert multi.is_solo() is False
def test_get_character_id_solo(self):
"""Test get_character_id returns solo_character_id for solo sessions."""
session = GameSession(
session_id="sess_123",
session_type=SessionType.SOLO,
solo_character_id="char_456",
)
assert session.get_character_id() == "char_456"
def test_get_character_id_multiplayer(self):
"""Test get_character_id returns current turn character for multiplayer."""
session = GameSession(
session_id="sess_123",
session_type=SessionType.MULTIPLAYER,
party_member_ids=["c1", "c2", "c3"],
turn_order=["c2", "c1", "c3"],
current_turn=1,
)
assert session.get_character_id() == "c1"
def test_to_dict_includes_new_fields(self):
"""Test to_dict includes session_type, solo_character_id, user_id."""
session = GameSession(
session_id="sess_123",
session_type=SessionType.SOLO,
solo_character_id="char_456",
user_id="user_789",
)
data = session.to_dict()
assert data["session_id"] == "sess_123"
assert data["session_type"] == "solo"
assert data["solo_character_id"] == "char_456"
assert data["user_id"] == "user_789"
def test_from_dict_solo_session(self):
"""Test from_dict correctly deserializes solo session."""
data = {
"session_id": "sess_123",
"session_type": "solo",
"solo_character_id": "char_456",
"user_id": "user_789",
"party_member_ids": [],
"turn_number": 5,
"game_state": {
"current_location": "Town",
"location_type": "town",
},
}
session = GameSession.from_dict(data)
assert session.session_id == "sess_123"
assert session.session_type == SessionType.SOLO
assert session.solo_character_id == "char_456"
assert session.user_id == "user_789"
assert session.turn_number == 5
assert session.game_state.location_type == LocationType.TOWN
def test_roundtrip_serialization(self):
"""Test complete roundtrip of solo session."""
session = GameSession(
session_id="sess_test",
session_type=SessionType.SOLO,
solo_character_id="char_hero",
user_id="user_player",
turn_number=10,
game_state=GameState(
current_location="Dark Dungeon",
location_type=LocationType.DUNGEON,
active_quests=["quest_1"],
),
conversation_history=[
ConversationEntry(
turn=1,
character_id="char_hero",
character_name="Hero",
action="Enter dungeon",
dm_response="The darkness swallows you...",
)
],
)
data = session.to_dict()
restored = GameSession.from_dict(data)
assert restored.session_id == session.session_id
assert restored.session_type == session.session_type
assert restored.solo_character_id == session.solo_character_id
assert restored.user_id == session.user_id
assert restored.turn_number == session.turn_number
assert restored.game_state.current_location == session.game_state.current_location
assert restored.game_state.location_type == session.game_state.location_type
assert len(restored.conversation_history) == 1
assert restored.conversation_history[0].action == "Enter dungeon"
def test_repr_solo(self):
"""Test __repr__ for solo session."""
session = GameSession(
session_id="sess_123",
session_type=SessionType.SOLO,
solo_character_id="char_456",
turn_number=5,
)
repr_str = repr(session)
assert "type=solo" in repr_str
assert "char=char_456" in repr_str
assert "turn=5" in repr_str
def test_repr_multiplayer(self):
"""Test __repr__ for multiplayer session."""
session = GameSession(
session_id="sess_123",
session_type=SessionType.MULTIPLAYER,
party_member_ids=["c1", "c2", "c3"],
turn_number=10,
)
repr_str = repr(session)
assert "type=multiplayer" in repr_str
assert "party=3" in repr_str
assert "turn=10" in repr_str
class TestGameSessionBackwardsCompatibility:
"""Tests for backwards compatibility with existing sessions."""
def test_default_session_type_is_solo(self):
"""Test new sessions default to solo type."""
session = GameSession(session_id="test")
assert session.session_type == SessionType.SOLO
def test_from_dict_without_session_type(self):
"""Test from_dict handles missing session_type (defaults to solo)."""
data = {
"session_id": "old_session",
"party_member_ids": ["c1"],
}
session = GameSession.from_dict(data)
assert session.session_type == SessionType.SOLO
def test_from_dict_without_location_type(self):
"""Test from_dict handles missing location_type in game_state."""
data = {
"session_id": "old_session",
"game_state": {
"current_location": "Old Town",
},
}
session = GameSession.from_dict(data)
assert session.game_state.location_type == LocationType.TOWN
def test_existing_methods_still_work(self):
"""Test existing session methods work with new fields."""
session = GameSession(
session_id="test",
session_type=SessionType.SOLO,
solo_character_id="char_1",
)
# Test existing methods
assert session.is_in_combat() is False
session.update_activity()
assert session.last_activity
# Add conversation entry
entry = ConversationEntry(
turn=1,
character_id="char_1",
character_name="Hero",
action="test",
dm_response="response",
)
session.add_conversation_entry(entry)
assert len(session.conversation_history) == 1
class TestLocationTypeEnum:
"""Tests for LocationType enum."""
def test_all_location_types_defined(self):
"""Test all expected location types exist."""
expected = ["town", "tavern", "wilderness", "dungeon", "ruins", "library", "safe_area"]
actual = [lt.value for lt in LocationType]
assert sorted(actual) == sorted(expected)
def test_location_type_from_string(self):
"""Test LocationType can be created from string."""
assert LocationType("town") == LocationType.TOWN
assert LocationType("wilderness") == LocationType.WILDERNESS
assert LocationType("dungeon") == LocationType.DUNGEON
class TestSessionTypeEnum:
"""Tests for SessionType enum."""
def test_session_types_defined(self):
"""Test session types are defined correctly."""
assert SessionType.SOLO.value == "solo"
assert SessionType.MULTIPLAYER.value == "multiplayer"

View File

@@ -0,0 +1,566 @@
"""
Tests for the SessionService.
Tests cover:
- Solo session creation
- Session retrieval and listing
- Conversation history management
- Game state tracking (location, quests, events)
- Session validation and limits
"""
import pytest
from unittest.mock import Mock, patch, MagicMock
from datetime import datetime
from app.services.session_service import (
SessionService,
SessionNotFound,
SessionLimitExceeded,
SessionValidationError,
MAX_ACTIVE_SESSIONS,
)
from app.models.session import GameSession, GameState, ConversationEntry
from app.models.enums import SessionStatus, SessionType, LocationType
from app.models.character import Character
from app.models.skills import PlayerClass
from app.models.origins import Origin
@pytest.fixture
def mock_db():
"""Create mock database service."""
with patch('app.services.session_service.get_database_service') as mock:
db = Mock()
mock.return_value = db
yield db
@pytest.fixture
def mock_appwrite():
"""Create mock Appwrite service."""
with patch('app.services.session_service.AppwriteService') as mock:
service = Mock()
mock.return_value = service
yield service
@pytest.fixture
def mock_character_service():
"""Create mock character service."""
with patch('app.services.session_service.get_character_service') as mock:
service = Mock()
mock.return_value = service
yield service
@pytest.fixture
def sample_character():
"""Create a sample character for testing."""
return Character(
character_id="char_123",
user_id="user_456",
name="Test Hero",
player_class=Mock(spec=PlayerClass),
origin=Mock(spec=Origin),
level=5,
experience=1000,
base_stats={"strength": 10},
unlocked_skills=[],
inventory=[],
equipped={},
gold=100,
active_quests=[],
discovered_locations=[],
current_location="Town"
)
class TestSessionServiceCreation:
"""Tests for session creation."""
def test_create_solo_session_success(self, mock_db, mock_appwrite, mock_character_service, sample_character):
"""Test successful solo session creation."""
mock_character_service.get_character.return_value = sample_character
mock_db.count_documents.return_value = 0
mock_db.create_document.return_value = None
service = SessionService()
session = service.create_solo_session(
user_id="user_456",
character_id="char_123"
)
assert session.session_type == SessionType.SOLO
assert session.solo_character_id == "char_123"
assert session.user_id == "user_456"
assert session.turn_number == 0
assert session.status == SessionStatus.ACTIVE
assert session.game_state.current_location == "Crossroads Village"
assert session.game_state.location_type == LocationType.TOWN
mock_db.create_document.assert_called_once()
def test_create_solo_session_character_not_found(self, mock_db, mock_appwrite, mock_character_service):
"""Test session creation fails when character not found."""
from app.services.character_service import CharacterNotFound
mock_character_service.get_character.side_effect = CharacterNotFound("Not found")
service = SessionService()
with pytest.raises(CharacterNotFound):
service.create_solo_session(
user_id="user_456",
character_id="char_invalid"
)
def test_create_solo_session_limit_exceeded(self, mock_db, mock_appwrite, mock_character_service, sample_character):
"""Test session creation fails when limit exceeded."""
mock_character_service.get_character.return_value = sample_character
mock_db.count_documents.return_value = MAX_ACTIVE_SESSIONS
service = SessionService()
with pytest.raises(SessionLimitExceeded):
service.create_solo_session(
user_id="user_456",
character_id="char_123"
)
def test_create_solo_session_custom_location(self, mock_db, mock_appwrite, mock_character_service, sample_character):
"""Test session creation with custom starting location."""
mock_character_service.get_character.return_value = sample_character
mock_db.count_documents.return_value = 0
service = SessionService()
session = service.create_solo_session(
user_id="user_456",
character_id="char_123",
starting_location="Dark Forest",
starting_location_type=LocationType.WILDERNESS
)
assert session.game_state.current_location == "Dark Forest"
assert session.game_state.location_type == LocationType.WILDERNESS
assert "Dark Forest" in session.game_state.discovered_locations
class TestSessionServiceRetrieval:
"""Tests for session retrieval."""
def test_get_session_success(self, mock_db, mock_appwrite, mock_character_service):
"""Test successful session retrieval."""
session_data = {
"session_id": "sess_123",
"session_type": "solo",
"solo_character_id": "char_456",
"user_id": "user_789",
"party_member_ids": [],
"turn_number": 5,
"status": "active",
"game_state": {
"current_location": "Town",
"location_type": "town"
},
"conversation_history": [],
"config": {},
"turn_order": [],
"current_turn": 0
}
mock_document = Mock()
mock_document.data = {
'userId': 'user_789',
'sessionData': __import__('json').dumps(session_data)
}
mock_db.get_document.return_value = mock_document
service = SessionService()
session = service.get_session("sess_123", "user_789")
assert session.session_id == "sess_123"
assert session.user_id == "user_789"
assert session.turn_number == 5
def test_get_session_not_found(self, mock_db, mock_appwrite, mock_character_service):
"""Test session retrieval when not found."""
mock_db.get_document.return_value = None
service = SessionService()
with pytest.raises(SessionNotFound):
service.get_session("sess_invalid")
def test_get_session_wrong_user(self, mock_db, mock_appwrite, mock_character_service):
"""Test session retrieval with wrong user ID."""
mock_document = Mock()
mock_document.data = {
'userId': 'user_other',
'sessionData': '{}'
}
mock_db.get_document.return_value = mock_document
service = SessionService()
with pytest.raises(SessionNotFound):
service.get_session("sess_123", "user_wrong")
def test_get_user_sessions(self, mock_db, mock_appwrite, mock_character_service):
"""Test getting all sessions for a user."""
session_data = {
"session_id": "sess_1",
"session_type": "solo",
"user_id": "user_123",
"status": "active",
"turn_number": 0,
"game_state": {},
"conversation_history": [],
"config": {},
"turn_order": [],
"current_turn": 0,
"party_member_ids": []
}
mock_doc = Mock()
mock_doc.data = {'sessionData': __import__('json').dumps(session_data)}
mock_doc.id = "sess_1"
mock_db.list_rows.return_value = [mock_doc]
service = SessionService()
sessions = service.get_user_sessions("user_123")
assert len(sessions) == 1
assert sessions[0].session_id == "sess_1"
class TestConversationHistory:
"""Tests for conversation history management."""
def test_add_conversation_entry(self, mock_db, mock_appwrite, mock_character_service):
"""Test adding a conversation entry."""
session_data = {
"session_id": "sess_123",
"session_type": "solo",
"user_id": "user_456",
"turn_number": 0,
"status": "active",
"game_state": {"current_location": "Town", "location_type": "town"},
"conversation_history": [],
"config": {},
"turn_order": [],
"current_turn": 0,
"party_member_ids": []
}
mock_document = Mock()
mock_document.data = {
'userId': 'user_456',
'sessionData': __import__('json').dumps(session_data)
}
mock_db.get_document.return_value = mock_document
service = SessionService()
updated = service.add_conversation_entry(
session_id="sess_123",
character_id="char_789",
character_name="Hero",
action="I explore the area",
dm_response="You find a hidden path..."
)
assert updated.turn_number == 1
assert len(updated.conversation_history) == 1
assert updated.conversation_history[0].action == "I explore the area"
assert updated.conversation_history[0].dm_response == "You find a hidden path..."
mock_db.update_document.assert_called_once()
def test_add_conversation_entry_with_quest(self, mock_db, mock_appwrite, mock_character_service):
"""Test adding conversation entry with quest offering."""
session_data = {
"session_id": "sess_123",
"session_type": "solo",
"user_id": "user_456",
"turn_number": 5,
"status": "active",
"game_state": {},
"conversation_history": [],
"config": {},
"turn_order": [],
"current_turn": 0,
"party_member_ids": []
}
mock_document = Mock()
mock_document.data = {
'userId': 'user_456',
'sessionData': __import__('json').dumps(session_data)
}
mock_db.get_document.return_value = mock_document
service = SessionService()
updated = service.add_conversation_entry(
session_id="sess_123",
character_id="char_789",
character_name="Hero",
action="Talk to elder",
dm_response="The elder offers you a quest...",
quest_offered={"quest_id": "quest_goblin", "name": "Clear Goblin Cave"}
)
assert updated.conversation_history[0].quest_offered is not None
assert updated.conversation_history[0].quest_offered["quest_id"] == "quest_goblin"
def test_get_recent_history(self, mock_db, mock_appwrite, mock_character_service):
"""Test getting recent conversation history."""
# Create session with 5 conversation entries
entries = []
for i in range(5):
entries.append({
"turn": i + 1,
"character_id": "char_123",
"character_name": "Hero",
"action": f"Action {i+1}",
"dm_response": f"Response {i+1}",
"timestamp": "2025-11-21T10:00:00Z",
"combat_log": []
})
session_data = {
"session_id": "sess_123",
"session_type": "solo",
"user_id": "user_456",
"turn_number": 5,
"status": "active",
"game_state": {},
"conversation_history": entries,
"config": {},
"turn_order": [],
"current_turn": 0,
"party_member_ids": []
}
mock_document = Mock()
mock_document.data = {
'userId': 'user_456',
'sessionData': __import__('json').dumps(session_data)
}
mock_db.get_document.return_value = mock_document
service = SessionService()
recent = service.get_recent_history("sess_123", num_turns=3)
assert len(recent) == 3
assert recent[0].turn == 3 # Last 3 entries
assert recent[2].turn == 5
class TestGameStateTracking:
"""Tests for game state tracking methods."""
def test_update_location(self, mock_db, mock_appwrite, mock_character_service):
"""Test updating session location."""
session_data = {
"session_id": "sess_123",
"session_type": "solo",
"user_id": "user_456",
"turn_number": 0,
"status": "active",
"game_state": {
"current_location": "Town",
"location_type": "town",
"discovered_locations": ["Town"]
},
"conversation_history": [],
"config": {},
"turn_order": [],
"current_turn": 0,
"party_member_ids": []
}
mock_document = Mock()
mock_document.data = {
'userId': 'user_456',
'sessionData': __import__('json').dumps(session_data)
}
mock_db.get_document.return_value = mock_document
service = SessionService()
updated = service.update_location(
session_id="sess_123",
new_location="Dark Forest",
location_type=LocationType.WILDERNESS
)
assert updated.game_state.current_location == "Dark Forest"
assert updated.game_state.location_type == LocationType.WILDERNESS
assert "Dark Forest" in updated.game_state.discovered_locations
def test_add_active_quest(self, mock_db, mock_appwrite, mock_character_service):
"""Test adding an active quest."""
session_data = {
"session_id": "sess_123",
"session_type": "solo",
"user_id": "user_456",
"turn_number": 0,
"status": "active",
"game_state": {
"active_quests": [],
"current_location": "Town",
"location_type": "town"
},
"conversation_history": [],
"config": {},
"turn_order": [],
"current_turn": 0,
"party_member_ids": []
}
mock_document = Mock()
mock_document.data = {
'userId': 'user_456',
'sessionData': __import__('json').dumps(session_data)
}
mock_db.get_document.return_value = mock_document
service = SessionService()
updated = service.add_active_quest("sess_123", "quest_goblin")
assert "quest_goblin" in updated.game_state.active_quests
def test_add_active_quest_limit(self, mock_db, mock_appwrite, mock_character_service):
"""Test adding quest fails when max reached."""
session_data = {
"session_id": "sess_123",
"session_type": "solo",
"user_id": "user_456",
"turn_number": 0,
"status": "active",
"game_state": {
"active_quests": ["quest_1", "quest_2"], # Already at max
"current_location": "Town",
"location_type": "town"
},
"conversation_history": [],
"config": {},
"turn_order": [],
"current_turn": 0,
"party_member_ids": []
}
mock_document = Mock()
mock_document.data = {
'userId': 'user_456',
'sessionData': __import__('json').dumps(session_data)
}
mock_db.get_document.return_value = mock_document
service = SessionService()
with pytest.raises(SessionValidationError):
service.add_active_quest("sess_123", "quest_3")
def test_remove_active_quest(self, mock_db, mock_appwrite, mock_character_service):
"""Test removing an active quest."""
session_data = {
"session_id": "sess_123",
"session_type": "solo",
"user_id": "user_456",
"turn_number": 0,
"status": "active",
"game_state": {
"active_quests": ["quest_1", "quest_2"],
"current_location": "Town",
"location_type": "town"
},
"conversation_history": [],
"config": {},
"turn_order": [],
"current_turn": 0,
"party_member_ids": []
}
mock_document = Mock()
mock_document.data = {
'userId': 'user_456',
'sessionData': __import__('json').dumps(session_data)
}
mock_db.get_document.return_value = mock_document
service = SessionService()
updated = service.remove_active_quest("sess_123", "quest_1")
assert "quest_1" not in updated.game_state.active_quests
assert "quest_2" in updated.game_state.active_quests
def test_add_world_event(self, mock_db, mock_appwrite, mock_character_service):
"""Test adding a world event."""
session_data = {
"session_id": "sess_123",
"session_type": "solo",
"user_id": "user_456",
"turn_number": 0,
"status": "active",
"game_state": {
"world_events": [],
"current_location": "Town",
"location_type": "town"
},
"conversation_history": [],
"config": {},
"turn_order": [],
"current_turn": 0,
"party_member_ids": []
}
mock_document = Mock()
mock_document.data = {
'userId': 'user_456',
'sessionData': __import__('json').dumps(session_data)
}
mock_db.get_document.return_value = mock_document
service = SessionService()
updated = service.add_world_event("sess_123", {"type": "festival", "description": "A festival begins"})
assert len(updated.game_state.world_events) == 1
assert updated.game_state.world_events[0]["type"] == "festival"
assert "timestamp" in updated.game_state.world_events[0]
class TestSessionLifecycle:
"""Tests for session lifecycle management."""
def test_end_session(self, mock_db, mock_appwrite, mock_character_service):
"""Test ending a session."""
session_data = {
"session_id": "sess_123",
"session_type": "solo",
"user_id": "user_456",
"turn_number": 10,
"status": "active",
"game_state": {},
"conversation_history": [],
"config": {},
"turn_order": [],
"current_turn": 0,
"party_member_ids": []
}
mock_document = Mock()
mock_document.data = {
'userId': 'user_456',
'sessionData': __import__('json').dumps(session_data)
}
mock_db.get_document.return_value = mock_document
service = SessionService()
updated = service.end_session("sess_123", "user_456")
assert updated.status == SessionStatus.COMPLETED
mock_db.update_document.assert_called_once()
def test_count_user_sessions(self, mock_db, mock_appwrite, mock_character_service):
"""Test counting user sessions."""
mock_db.count_documents.return_value = 3
service = SessionService()
count = service.count_user_sessions("user_123", active_only=True)
assert count == 3
mock_db.count_documents.assert_called_once()

198
api/tests/test_stats.py Normal file
View File

@@ -0,0 +1,198 @@
"""
Unit tests for Stats dataclass.
Tests computed properties, serialization, and basic operations.
"""
import pytest
from app.models.stats import Stats
def test_stats_default_values():
"""Test that Stats initializes with default values."""
stats = Stats()
assert stats.strength == 10
assert stats.dexterity == 10
assert stats.constitution == 10
assert stats.intelligence == 10
assert stats.wisdom == 10
assert stats.charisma == 10
def test_stats_custom_values():
"""Test creating Stats with custom values."""
stats = Stats(
strength=15,
dexterity=12,
constitution=14,
intelligence=8,
wisdom=10,
charisma=11,
)
assert stats.strength == 15
assert stats.dexterity == 12
assert stats.constitution == 14
assert stats.intelligence == 8
assert stats.wisdom == 10
assert stats.charisma == 11
def test_hit_points_calculation():
"""Test HP calculation: 10 + (constitution × 2)."""
stats = Stats(constitution=10)
assert stats.hit_points == 30 # 10 + (10 × 2)
stats = Stats(constitution=15)
assert stats.hit_points == 40 # 10 + (15 × 2)
stats = Stats(constitution=20)
assert stats.hit_points == 50 # 10 + (20 × 2)
def test_mana_points_calculation():
"""Test MP calculation: 10 + (intelligence × 2)."""
stats = Stats(intelligence=10)
assert stats.mana_points == 30 # 10 + (10 × 2)
stats = Stats(intelligence=15)
assert stats.mana_points == 40 # 10 + (15 × 2)
stats = Stats(intelligence=8)
assert stats.mana_points == 26 # 10 + (8 × 2)
def test_defense_calculation():
"""Test defense calculation: constitution // 2."""
stats = Stats(constitution=10)
assert stats.defense == 5 # 10 // 2
stats = Stats(constitution=15)
assert stats.defense == 7 # 15 // 2
stats = Stats(constitution=21)
assert stats.defense == 10 # 21 // 2
def test_resistance_calculation():
"""Test resistance calculation: wisdom // 2."""
stats = Stats(wisdom=10)
assert stats.resistance == 5 # 10 // 2
stats = Stats(wisdom=14)
assert stats.resistance == 7 # 14 // 2
stats = Stats(wisdom=9)
assert stats.resistance == 4 # 9 // 2
def test_stats_serialization():
"""Test to_dict() serialization."""
stats = Stats(
strength=15,
dexterity=12,
constitution=14,
intelligence=10,
wisdom=11,
charisma=8,
)
data = stats.to_dict()
assert data["strength"] == 15
assert data["dexterity"] == 12
assert data["constitution"] == 14
assert data["intelligence"] == 10
assert data["wisdom"] == 11
assert data["charisma"] == 8
def test_stats_deserialization():
"""Test from_dict() deserialization."""
data = {
"strength": 18,
"dexterity": 14,
"constitution": 16,
"intelligence": 12,
"wisdom": 10,
"charisma": 9,
}
stats = Stats.from_dict(data)
assert stats.strength == 18
assert stats.dexterity == 14
assert stats.constitution == 16
assert stats.intelligence == 12
assert stats.wisdom == 10
assert stats.charisma == 9
def test_stats_deserialization_with_missing_values():
"""Test from_dict() with missing values (should use defaults)."""
data = {
"strength": 15,
# Missing other stats
}
stats = Stats.from_dict(data)
assert stats.strength == 15
assert stats.dexterity == 10 # Default
assert stats.constitution == 10 # Default
assert stats.intelligence == 10 # Default
assert stats.wisdom == 10 # Default
assert stats.charisma == 10 # Default
def test_stats_round_trip_serialization():
"""Test that serialization and deserialization preserve data."""
original = Stats(
strength=20,
dexterity=15,
constitution=18,
intelligence=10,
wisdom=12,
charisma=14,
)
# Serialize then deserialize
data = original.to_dict()
restored = Stats.from_dict(data)
assert restored.strength == original.strength
assert restored.dexterity == original.dexterity
assert restored.constitution == original.constitution
assert restored.intelligence == original.intelligence
assert restored.wisdom == original.wisdom
assert restored.charisma == original.charisma
def test_stats_copy():
"""Test that copy() creates an independent copy."""
original = Stats(strength=15, dexterity=12, constitution=14)
copy = original.copy()
assert copy.strength == original.strength
assert copy.dexterity == original.dexterity
assert copy.constitution == original.constitution
# Modify copy
copy.strength = 20
# Original should be unchanged
assert original.strength == 15
assert copy.strength == 20
def test_stats_repr():
"""Test string representation."""
stats = Stats(strength=15, constitution=12, intelligence=10)
repr_str = repr(stats)
assert "STR=15" in repr_str
assert "CON=12" in repr_str
assert "INT=10" in repr_str
assert "HP=" in repr_str
assert "MP=" in repr_str

View File

@@ -0,0 +1,460 @@
"""
Tests for UsageTrackingService.
These tests verify:
- Cost calculation for different models
- Usage logging functionality
- Daily and monthly usage aggregation
- Static helper methods
"""
import pytest
from datetime import datetime, date, timezone, timedelta
from unittest.mock import Mock, MagicMock, patch
from uuid import uuid4
from app.services.usage_tracking_service import (
UsageTrackingService,
MODEL_COSTS,
DEFAULT_COST
)
from app.models.ai_usage import (
AIUsageLog,
DailyUsageSummary,
MonthlyUsageSummary,
TaskType
)
class TestAIUsageLogModel:
"""Tests for the AIUsageLog dataclass."""
def test_to_dict(self):
"""Test conversion to dictionary."""
log = AIUsageLog(
log_id="log_123",
user_id="user_456",
timestamp=datetime(2025, 11, 21, 10, 30, 0, tzinfo=timezone.utc),
model="anthropic/claude-3.5-sonnet",
tokens_input=100,
tokens_output=350,
tokens_total=450,
estimated_cost=0.00555,
task_type=TaskType.STORY_PROGRESSION,
session_id="sess_789",
character_id="char_abc",
request_duration_ms=1500,
success=True,
error_message=None
)
result = log.to_dict()
assert result["log_id"] == "log_123"
assert result["user_id"] == "user_456"
assert result["model"] == "anthropic/claude-3.5-sonnet"
assert result["tokens_total"] == 450
assert result["task_type"] == "story_progression"
assert result["success"] is True
def test_from_dict(self):
"""Test creation from dictionary."""
data = {
"log_id": "log_123",
"user_id": "user_456",
"timestamp": "2025-11-21T10:30:00+00:00",
"model": "anthropic/claude-3.5-sonnet",
"tokens_input": 100,
"tokens_output": 350,
"tokens_total": 450,
"estimated_cost": 0.00555,
"task_type": "story_progression",
"session_id": "sess_789",
"character_id": "char_abc",
"request_duration_ms": 1500,
"success": True,
"error_message": None
}
log = AIUsageLog.from_dict(data)
assert log.log_id == "log_123"
assert log.user_id == "user_456"
assert log.task_type == TaskType.STORY_PROGRESSION
assert log.tokens_total == 450
def test_from_dict_with_invalid_task_type(self):
"""Test handling of invalid task type."""
data = {
"log_id": "log_123",
"user_id": "user_456",
"timestamp": "2025-11-21T10:30:00+00:00",
"model": "test-model",
"tokens_input": 100,
"tokens_output": 200,
"tokens_total": 300,
"estimated_cost": 0.001,
"task_type": "invalid_type"
}
log = AIUsageLog.from_dict(data)
# Should default to GENERAL
assert log.task_type == TaskType.GENERAL
class TestDailyUsageSummary:
"""Tests for the DailyUsageSummary dataclass."""
def test_to_dict(self):
"""Test conversion to dictionary."""
summary = DailyUsageSummary(
date=date(2025, 11, 21),
user_id="user_123",
total_requests=15,
total_tokens=6750,
total_input_tokens=2000,
total_output_tokens=4750,
estimated_cost=0.45,
requests_by_task={"story_progression": 10, "combat_narration": 5}
)
result = summary.to_dict()
assert result["date"] == "2025-11-21"
assert result["total_requests"] == 15
assert result["estimated_cost"] == 0.45
assert result["requests_by_task"]["story_progression"] == 10
class TestCostCalculation:
"""Tests for cost calculation functionality."""
def test_calculate_cost_llama(self):
"""Test cost calculation for Llama model."""
cost = UsageTrackingService.estimate_cost_for_model(
model="meta/meta-llama-3-8b-instruct",
tokens_input=1000,
tokens_output=1000
)
# Llama: $0.0001 per 1K input + $0.0001 per 1K output
expected = 0.0001 + 0.0001
assert abs(cost - expected) < 0.000001
def test_calculate_cost_haiku(self):
"""Test cost calculation for Claude Haiku."""
cost = UsageTrackingService.estimate_cost_for_model(
model="anthropic/claude-3.5-haiku",
tokens_input=1000,
tokens_output=1000
)
# Haiku: $0.001 per 1K input + $0.005 per 1K output
expected = 0.001 + 0.005
assert abs(cost - expected) < 0.000001
def test_calculate_cost_sonnet(self):
"""Test cost calculation for Claude Sonnet."""
cost = UsageTrackingService.estimate_cost_for_model(
model="anthropic/claude-3.5-sonnet",
tokens_input=1000,
tokens_output=1000
)
# Sonnet: $0.003 per 1K input + $0.015 per 1K output
expected = 0.003 + 0.015
assert abs(cost - expected) < 0.000001
def test_calculate_cost_opus(self):
"""Test cost calculation for Claude Opus."""
cost = UsageTrackingService.estimate_cost_for_model(
model="anthropic/claude-3-opus",
tokens_input=1000,
tokens_output=1000
)
# Opus: $0.015 per 1K input + $0.075 per 1K output
expected = 0.015 + 0.075
assert abs(cost - expected) < 0.000001
def test_calculate_cost_unknown_model(self):
"""Test cost calculation for unknown model uses default."""
cost = UsageTrackingService.estimate_cost_for_model(
model="unknown/model",
tokens_input=1000,
tokens_output=1000
)
# Default: $0.001 per 1K input + $0.005 per 1K output
expected = DEFAULT_COST["input"] + DEFAULT_COST["output"]
assert abs(cost - expected) < 0.000001
def test_calculate_cost_fractional_tokens(self):
"""Test cost calculation with fractional token counts."""
cost = UsageTrackingService.estimate_cost_for_model(
model="anthropic/claude-3.5-sonnet",
tokens_input=500,
tokens_output=250
)
# Sonnet: (500/1000 * 0.003) + (250/1000 * 0.015)
expected = 0.0015 + 0.00375
assert abs(cost - expected) < 0.000001
def test_get_model_cost_info(self):
"""Test getting cost info for a model."""
cost_info = UsageTrackingService.get_model_cost_info(
"anthropic/claude-3.5-sonnet"
)
assert cost_info["input"] == 0.003
assert cost_info["output"] == 0.015
def test_get_model_cost_info_unknown(self):
"""Test getting cost info for unknown model."""
cost_info = UsageTrackingService.get_model_cost_info("unknown/model")
assert cost_info == DEFAULT_COST
class TestUsageTrackingService:
"""Tests for UsageTrackingService class."""
@pytest.fixture
def mock_env(self):
"""Set up mock environment variables."""
with patch.dict('os.environ', {
'APPWRITE_ENDPOINT': 'https://cloud.appwrite.io/v1',
'APPWRITE_PROJECT_ID': 'test_project',
'APPWRITE_API_KEY': 'test_api_key',
'APPWRITE_DATABASE_ID': 'test_db'
}):
yield
@pytest.fixture
def mock_databases(self):
"""Create mock Databases service."""
with patch('app.services.usage_tracking_service.Databases') as mock:
yield mock
@pytest.fixture
def service(self, mock_env, mock_databases):
"""Create UsageTrackingService instance with mocked dependencies."""
service = UsageTrackingService()
return service
def test_init_missing_env(self):
"""Test initialization fails with missing env vars."""
with patch.dict('os.environ', {}, clear=True):
with pytest.raises(ValueError, match="Appwrite configuration incomplete"):
UsageTrackingService()
def test_log_usage_success(self, service):
"""Test logging usage successfully."""
# Mock the create_document response
service.databases.create_document = Mock(return_value={
"$id": "doc_123"
})
result = service.log_usage(
user_id="user_123",
model="anthropic/claude-3.5-sonnet",
tokens_input=100,
tokens_output=350,
task_type=TaskType.STORY_PROGRESSION,
session_id="sess_789"
)
# Verify result
assert result.user_id == "user_123"
assert result.model == "anthropic/claude-3.5-sonnet"
assert result.tokens_input == 100
assert result.tokens_output == 350
assert result.tokens_total == 450
assert result.task_type == TaskType.STORY_PROGRESSION
assert result.estimated_cost > 0
# Verify Appwrite was called
service.databases.create_document.assert_called_once()
call_args = service.databases.create_document.call_args
assert call_args.kwargs["database_id"] == "test_db"
assert call_args.kwargs["collection_id"] == "ai_usage_logs"
def test_log_usage_with_error(self, service):
"""Test logging usage when request failed."""
service.databases.create_document = Mock(return_value={
"$id": "doc_123"
})
result = service.log_usage(
user_id="user_123",
model="anthropic/claude-3.5-sonnet",
tokens_input=100,
tokens_output=0,
task_type=TaskType.STORY_PROGRESSION,
success=False,
error_message="API timeout"
)
assert result.success is False
assert result.error_message == "API timeout"
assert result.tokens_output == 0
def test_get_daily_usage(self, service):
"""Test getting daily usage summary."""
# Mock list_rows response
service.tables_db.list_rows = Mock(return_value={
"rows": [
{
"tokens_input": 100,
"tokens_output": 300,
"tokens_total": 400,
"estimated_cost": 0.005,
"task_type": "story_progression"
},
{
"tokens_input": 150,
"tokens_output": 350,
"tokens_total": 500,
"estimated_cost": 0.006,
"task_type": "story_progression"
},
{
"tokens_input": 50,
"tokens_output": 200,
"tokens_total": 250,
"estimated_cost": 0.003,
"task_type": "combat_narration"
}
]
})
result = service.get_daily_usage("user_123", date(2025, 11, 21))
assert result.user_id == "user_123"
assert result.date == date(2025, 11, 21)
assert result.total_requests == 3
assert result.total_tokens == 1150
assert result.total_input_tokens == 300
assert result.total_output_tokens == 850
assert abs(result.estimated_cost - 0.014) < 0.0001
assert result.requests_by_task["story_progression"] == 2
assert result.requests_by_task["combat_narration"] == 1
def test_get_daily_usage_empty(self, service):
"""Test getting daily usage when no usage exists."""
service.tables_db.list_rows = Mock(return_value={
"rows": []
})
result = service.get_daily_usage("user_123", date(2025, 11, 21))
assert result.total_requests == 0
assert result.total_tokens == 0
assert result.estimated_cost == 0.0
assert result.requests_by_task == {}
def test_get_monthly_cost(self, service):
"""Test getting monthly cost summary."""
service.tables_db.list_rows = Mock(return_value={
"rows": [
{"tokens_total": 1000, "estimated_cost": 0.01},
{"tokens_total": 2000, "estimated_cost": 0.02},
{"tokens_total": 1500, "estimated_cost": 0.015}
]
})
result = service.get_monthly_cost("user_123", 2025, 11)
assert result.year == 2025
assert result.month == 11
assert result.user_id == "user_123"
assert result.total_requests == 3
assert result.total_tokens == 4500
assert abs(result.estimated_cost - 0.045) < 0.0001
def test_get_monthly_cost_invalid_month(self, service):
"""Test monthly cost with invalid month raises ValueError."""
with pytest.raises(ValueError, match="Invalid month"):
service.get_monthly_cost("user_123", 2025, 13)
with pytest.raises(ValueError, match="Invalid month"):
service.get_monthly_cost("user_123", 2025, 0)
def test_get_total_daily_cost(self, service):
"""Test getting total daily cost across all users."""
service.tables_db.list_rows = Mock(return_value={
"rows": [
{"estimated_cost": 0.10},
{"estimated_cost": 0.25},
{"estimated_cost": 0.15}
]
})
result = service.get_total_daily_cost(date(2025, 11, 21))
assert abs(result - 0.50) < 0.0001
def test_get_user_request_count_today(self, service):
"""Test getting user request count for today."""
service.tables_db.list_rows = Mock(return_value={
"rows": [
{"tokens_total": 100, "tokens_input": 30, "tokens_output": 70, "estimated_cost": 0.001, "task_type": "story_progression"},
{"tokens_total": 200, "tokens_input": 50, "tokens_output": 150, "estimated_cost": 0.002, "task_type": "story_progression"}
]
})
result = service.get_user_request_count_today("user_123")
assert result == 2
class TestCostEstimations:
"""Tests for realistic cost estimation scenarios."""
def test_free_tier_daily_cost(self):
"""Test estimated daily cost for free tier user with Llama."""
# 20 requests per day, average 500 total tokens each
total_input = 20 * 200
total_output = 20 * 300
cost = UsageTrackingService.estimate_cost_for_model(
model="meta/meta-llama-3-8b-instruct",
tokens_input=total_input,
tokens_output=total_output
)
# Should be very cheap (essentially free)
assert cost < 0.01
def test_premium_tier_daily_cost(self):
"""Test estimated daily cost for premium tier user with Sonnet."""
# 100 requests per day, average 1000 total tokens each
total_input = 100 * 300
total_output = 100 * 700
cost = UsageTrackingService.estimate_cost_for_model(
model="anthropic/claude-3.5-sonnet",
tokens_input=total_input,
tokens_output=total_output
)
# Should be under $2/day for heavy usage
assert cost < 2.0
def test_elite_tier_monthly_cost(self):
"""Test estimated monthly cost for elite tier user."""
# 200 requests per day * 30 days = 6000 requests
# Average 1500 tokens per request
total_input = 6000 * 500
total_output = 6000 * 1000
cost = UsageTrackingService.estimate_cost_for_model(
model="anthropic/claude-4.5-sonnet",
tokens_input=total_input,
tokens_output=total_output
)
# Elite tier should be under $100/month even with heavy usage
assert cost < 100.0