584 lines
20 KiB
Python
584 lines
20 KiB
Python
"""
Tests for the NarrativeGenerator wrapper.

These tests use mocked AI clients to verify the generator's
behavior without making actual API calls.
"""
|
|
|
|
import pytest
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
from app.ai.narrative_generator import (
|
|
NarrativeGenerator,
|
|
NarrativeResponse,
|
|
NarrativeGeneratorError,
|
|
)
|
|
from app.ai.replicate_client import ReplicateResponse, ReplicateClientError
|
|
from app.ai.model_selector import UserTier, ContextType, ModelSelector, ModelConfig, ModelType
|
|
from app.ai.prompt_templates import PromptTemplates
|
|
|
|
|
|
# Test fixtures
|
|
@pytest.fixture
def mock_replicate_client():
    """Provide a MagicMock client whose ``generate`` returns a canned response."""
    # Canned reply handed back by every generate() call unless a test
    # overrides return_value or side_effect.
    canned_reply = ReplicateResponse(
        text="The tavern falls silent as you step through the doorway...",
        tokens_used=150,
        model="meta/meta-llama-3-8b-instruct",
        generation_time=2.5,
    )

    stub = MagicMock()
    stub.generate.return_value = canned_reply
    return stub
|
|
|
|
|
|
@pytest.fixture
def generator(mock_replicate_client):
    """Provide a NarrativeGenerator wired to the mocked Replicate client."""
    narrator = NarrativeGenerator(replicate_client=mock_replicate_client)
    return narrator
|
|
|
|
|
|
@pytest.fixture
def sample_character():
    """Provide a character dict (a level 3 Fighter named Aldric) for tests."""
    ability_scores = {
        "strength": 16,
        "dexterity": 12,
        "constitution": 14,
        "intelligence": 10,
        "wisdom": 11,
        "charisma": 8,
    }
    known_skills = [
        {"name": "Sword Mastery", "level": 2},
        {"name": "Shield Block", "level": 1},
    ]

    return {
        "name": "Aldric",
        "level": 3,
        "player_class": "Fighter",
        "current_hp": 25,
        "max_hp": 30,
        "stats": ability_scores,
        "skills": known_skills,
        "effects": [],
        "completed_quests": [],
    }
|
|
|
|
|
|
@pytest.fixture
def sample_game_state():
    """Provide a game-state dict set in a tavern during the evening."""
    return dict(
        current_location="The Rusty Anchor Tavern",
        location_type="TAVERN",
        discovered_locations=["Crossroads Village", "Dark Forest"],
        active_quests=[],
        time_of_day="Evening",
    )
|
|
|
|
|
|
@pytest.fixture
def sample_combat_state():
    """Provide a combat-state dict: round 2, player's turn, one wounded goblin."""
    goblin_scout = {
        "name": "Goblin Scout",
        "current_hp": 8,
        "max_hp": 12,
        "effects": [],
    }

    return {
        "round_number": 2,
        "current_turn": "player",
        "enemies": [goblin_scout],
    }
|
|
|
|
|
|
@pytest.fixture
def sample_npc():
    """Provide an NPC dict describing the tavern owner, Barkeep Magnus."""
    barkeep = dict(
        name="Barkeep Magnus",
        role="Tavern Owner",
        personality="Gruff but kind-hearted",
        speaking_style="Short sentences, occasional jokes",
        goals="Keep the tavern running, help adventurers",
    )
    return barkeep
|
|
|
|
|
|
@pytest.fixture
def sample_quests():
    """Provide two eligible quest dicts; the goblin cave quest comes first."""
    goblin_cave = {
        "quest_id": "goblin_cave",
        "name": "Clear the Goblin Cave",
        "difficulty": "EASY",
        "quest_giver": "Village Elder",
        "description": "A nearby cave has been overrun by goblins.",
        "narrative_hooks": [
            "Farmers complain about stolen livestock",
            "Goblin tracks spotted on the road",
        ],
    }
    missing_merchant = {
        "quest_id": "missing_merchant",
        "name": "Find the Missing Merchant",
        "difficulty": "MEDIUM",
        "quest_giver": "Guild Master",
        "description": "A merchant caravan has gone missing.",
        "narrative_hooks": [
            "The merchant was carrying valuable goods",
            "His family is worried sick",
        ],
    }

    # Order matters: the invalid-response fallback test expects the first
    # entry's quest_id ("goblin_cave") to be returned.
    return [goblin_cave, missing_merchant]
|
|
|
|
|
|
class TestNarrativeGeneratorInit:
    """Tests for NarrativeGenerator initialization."""

    def test_init_with_defaults(self):
        """Default construction provides a model selector and prompt templates."""
        # Patch the client class so construction does not build a real
        # Replicate client.
        with patch('app.ai.narrative_generator.ReplicateClient'):
            generator = NarrativeGenerator()

        assert generator.model_selector is not None
        assert generator.prompt_templates is not None

    def test_init_with_custom_selector(self):
        """A caller-supplied model selector is stored by identity."""
        custom_selector = ModelSelector()

        # No client is injected here, so patch the client class exactly as
        # test_init_with_defaults does; otherwise this test would depend on
        # whatever a real client needs at construction time.
        with patch('app.ai.narrative_generator.ReplicateClient'):
            generator = NarrativeGenerator(model_selector=custom_selector)

        assert generator.model_selector is custom_selector

    def test_init_with_custom_client(self, mock_replicate_client):
        """A caller-supplied Replicate client is stored by identity."""
        generator = NarrativeGenerator(replicate_client=mock_replicate_client)

        assert generator.replicate_client is mock_replicate_client
|
|
|
|
|
|
class TestGenerateStoryResponse:
    """Tests for generate_story_response method."""

    def test_basic_story_generation(
        self, generator, mock_replicate_client, sample_character, sample_game_state
    ):
        """A plain action yields a populated story-progression response."""
        response = generator.generate_story_response(
            character=sample_character,
            action="I search the room for hidden doors",
            game_state=sample_game_state,
            user_tier=UserTier.FREE
        )

        assert isinstance(response, NarrativeResponse)
        assert response.narrative is not None
        assert response.tokens_used > 0
        assert response.context_type == "story_progression"
        assert response.generation_time > 0

        # Verify the client was called exactly once
        mock_replicate_client.generate.assert_called_once()

    def test_story_with_conversation_history(
        self, generator, mock_replicate_client, sample_character, sample_game_state
    ):
        """Supplying prior turns still results in a single client call."""
        history = [
            {
                "turn": 1,
                "action": "I enter the tavern",
                "dm_response": "The tavern is warm and inviting..."
            },
            {
                "turn": 2,
                "action": "I approach the bar",
                "dm_response": "The barkeep nods in greeting..."
            }
        ]

        response = generator.generate_story_response(
            character=sample_character,
            action="I ask about local rumors",
            game_state=sample_game_state,
            user_tier=UserTier.PREMIUM,
            conversation_history=history
        )

        assert isinstance(response, NarrativeResponse)
        mock_replicate_client.generate.assert_called_once()

    def test_story_with_world_context(
        self, generator, mock_replicate_client, sample_character, sample_game_state
    ):
        """Supplying additional world context still yields a NarrativeResponse."""
        response = generator.generate_story_response(
            character=sample_character,
            action="I look around",
            game_state=sample_game_state,
            user_tier=UserTier.ELITE,
            world_context="A festival is being celebrated in the village"
        )

        assert isinstance(response, NarrativeResponse)

    def test_story_uses_correct_model_for_tier(
        self, generator, mock_replicate_client, sample_character, sample_game_state
    ):
        """Every user tier results in exactly one client call.

        NOTE(review): the model chosen per tier is not asserted here; the
        keyword used to pass the model to ``generate`` is not visible from
        this file, so only the call count is checked.
        """
        for tier in UserTier:
            # Start each tier from a clean call record.
            mock_replicate_client.generate.reset_mock()

            generator.generate_story_response(
                character=sample_character,
                action="Test action",
                game_state=sample_game_state,
                user_tier=tier
            )

            # Stricter than checking call_args for None: this also fails
            # if the client was invoked more than once for a tier.
            mock_replicate_client.generate.assert_called_once()

    def test_story_handles_client_error(
        self, generator, mock_replicate_client, sample_character, sample_game_state
    ):
        """A ReplicateClientError surfaces as NarrativeGeneratorError."""
        mock_replicate_client.generate.side_effect = ReplicateClientError("API error")

        with pytest.raises(NarrativeGeneratorError) as exc_info:
            generator.generate_story_response(
                character=sample_character,
                action="Test action",
                game_state=sample_game_state,
                user_tier=UserTier.FREE
            )

        assert "AI generation failed" in str(exc_info.value)
|
|
|
|
|
|
class TestGenerateCombatNarration:
    """Tests for generate_combat_narration method."""

    @staticmethod
    def _attack_outcome(hit, damage):
        """Build the action_result dict for an attack on the Goblin Scout."""
        return {
            "hit": hit,
            "damage": damage,
            "target": "Goblin Scout",
        }

    def test_basic_combat_narration(
        self, generator, mock_replicate_client, sample_character, sample_combat_state
    ):
        """A regular hit produces a combat-narration response."""
        outcome = self._attack_outcome(hit=True, damage=8)

        narration = generator.generate_combat_narration(
            character=sample_character,
            combat_state=sample_combat_state,
            action="swings their sword",
            action_result=outcome,
            user_tier=UserTier.BASIC,
        )

        assert isinstance(narration, NarrativeResponse)
        assert narration.context_type == "combat_narration"

    def test_critical_hit_narration(
        self, generator, mock_replicate_client, sample_character, sample_combat_state
    ):
        """Flagging the hit as critical still yields a NarrativeResponse."""
        outcome = self._attack_outcome(hit=True, damage=16)

        narration = generator.generate_combat_narration(
            character=sample_character,
            combat_state=sample_combat_state,
            action="strikes with precision",
            action_result=outcome,
            user_tier=UserTier.PREMIUM,
            is_critical=True,
        )

        assert isinstance(narration, NarrativeResponse)

    def test_finishing_blow_narration(
        self, generator, mock_replicate_client, sample_character, sample_combat_state
    ):
        """Flagging the hit as a finishing blow still yields a NarrativeResponse."""
        outcome = self._attack_outcome(hit=True, damage=10)

        narration = generator.generate_combat_narration(
            character=sample_character,
            combat_state=sample_combat_state,
            action="delivers the final blow",
            action_result=outcome,
            user_tier=UserTier.ELITE,
            is_finishing_blow=True,
        )

        assert isinstance(narration, NarrativeResponse)

    def test_miss_narration(
        self, generator, mock_replicate_client, sample_character, sample_combat_state
    ):
        """A missed attack (no damage) still yields a NarrativeResponse."""
        outcome = self._attack_outcome(hit=False, damage=0)

        narration = generator.generate_combat_narration(
            character=sample_character,
            combat_state=sample_combat_state,
            action="swings wildly",
            action_result=outcome,
            user_tier=UserTier.FREE,
        )

        assert isinstance(narration, NarrativeResponse)
|
|
|
|
|
|
class TestGenerateQuestSelection:
    """Tests for generate_quest_selection method."""

    @staticmethod
    def _reply_with(text):
        """Build the canned client response whose text is the given quest id."""
        return ReplicateResponse(
            text=text,
            tokens_used=50,
            model="meta/meta-llama-3-8b-instruct",
            generation_time=1.0,
        )

    def test_basic_quest_selection(
        self, generator, mock_replicate_client, sample_character, sample_game_state, sample_quests
    ):
        """A reply naming an eligible quest is returned as the selection."""
        # Mock response to return a valid quest_id
        mock_replicate_client.generate.return_value = self._reply_with("goblin_cave")

        chosen = generator.generate_quest_selection(
            character=sample_character,
            eligible_quests=sample_quests,
            game_context=sample_game_state,
            user_tier=UserTier.FREE,
        )

        assert chosen == "goblin_cave"

    def test_quest_selection_with_recent_actions(
        self, generator, mock_replicate_client, sample_character, sample_game_state, sample_quests
    ):
        """Passing recent actions still returns the quest id the client named."""
        mock_replicate_client.generate.return_value = self._reply_with("missing_merchant")

        chosen = generator.generate_quest_selection(
            character=sample_character,
            eligible_quests=sample_quests,
            game_context=sample_game_state,
            user_tier=UserTier.PREMIUM,
            recent_actions=[
                "Asked about missing traders",
                "Investigated the market square",
            ],
        )

        assert chosen == "missing_merchant"

    def test_quest_selection_invalid_response_fallback(
        self, generator, mock_replicate_client, sample_character, sample_game_state, sample_quests
    ):
        """An unknown quest id in the reply falls back to the first eligible quest."""
        # Mock response with a quest_id that is not among the eligible quests
        mock_replicate_client.generate.return_value = self._reply_with("invalid_quest_id")

        chosen = generator.generate_quest_selection(
            character=sample_character,
            eligible_quests=sample_quests,
            game_context=sample_game_state,
            user_tier=UserTier.FREE,
        )

        # Should fall back to first eligible quest
        assert chosen == "goblin_cave"

    def test_quest_selection_no_quests_error(
        self, generator, sample_character, sample_game_state
    ):
        """An empty eligible-quest list raises NarrativeGeneratorError."""
        with pytest.raises(NarrativeGeneratorError) as exc_info:
            generator.generate_quest_selection(
                character=sample_character,
                eligible_quests=[],
                game_context=sample_game_state,
                user_tier=UserTier.FREE,
            )

        failure_text = str(exc_info.value)
        assert "No eligible quests" in failure_text
|
|
|
|
|
|
class TestGenerateNPCDialogue:
    """Tests for generate_npc_dialogue method."""

    def test_basic_npc_dialogue(
        self, generator, mock_replicate_client, sample_character, sample_npc, sample_game_state
    ):
        """A simple question to the NPC produces an npc_dialogue response."""
        barkeep_line = ReplicateResponse(
            text='*wipes down the bar* "Aye, rumors aplenty around here..."',
            tokens_used=100,
            model="anthropic/claude-3.5-haiku",
            generation_time=1.5,
        )
        mock_replicate_client.generate.return_value = barkeep_line

        reply = generator.generate_npc_dialogue(
            character=sample_character,
            npc=sample_npc,
            conversation_topic="What rumors have you heard?",
            game_state=sample_game_state,
            user_tier=UserTier.BASIC,
        )

        assert isinstance(reply, NarrativeResponse)
        assert reply.context_type == "npc_dialogue"

    def test_npc_dialogue_with_relationship(
        self, generator, mock_replicate_client, sample_character, sample_npc, sample_game_state
    ):
        """Passing an npc_relationship description still yields a NarrativeResponse."""
        reply = generator.generate_npc_dialogue(
            character=sample_character,
            npc=sample_npc,
            conversation_topic="Hello old friend",
            game_state=sample_game_state,
            user_tier=UserTier.PREMIUM,
            npc_relationship="Friendly - helped defend the tavern last month",
        )

        assert isinstance(reply, NarrativeResponse)

    def test_npc_dialogue_with_previous_conversation(
        self, generator, mock_replicate_client, sample_character, sample_npc, sample_game_state
    ):
        """Passing earlier exchanges as previous_dialogue still yields a NarrativeResponse."""
        earlier_exchange = {
            "player_line": "What's on tap tonight?",
            "npc_response": "Got some fine ale from the southern vineyards.",
        }

        reply = generator.generate_npc_dialogue(
            character=sample_character,
            npc=sample_npc,
            conversation_topic="I'll take one",
            game_state=sample_game_state,
            user_tier=UserTier.FREE,
            previous_dialogue=[earlier_exchange],
        )

        assert isinstance(reply, NarrativeResponse)

    def test_npc_dialogue_with_special_knowledge(
        self, generator, mock_replicate_client, sample_character, sample_npc, sample_game_state
    ):
        """Passing npc_knowledge entries still yields a NarrativeResponse."""
        secrets_known = ["Secret passage in the cellar", "Hidden treasure map"]

        reply = generator.generate_npc_dialogue(
            character=sample_character,
            npc=sample_npc,
            conversation_topic="Have you seen anything strange?",
            game_state=sample_game_state,
            user_tier=UserTier.ELITE,
            npc_knowledge=secrets_known,
        )

        assert isinstance(reply, NarrativeResponse)
|
|
|
|
|
|
class TestModelSelection:
    """Tests for model selection behavior."""

    def test_free_tier_uses_llama(
        self, generator, mock_replicate_client, sample_character, sample_game_state
    ):
        """A free-tier story request passes a Llama model to the client."""
        generator.generate_story_response(
            character=sample_character,
            action="Test",
            game_state=sample_game_state,
            user_tier=UserTier.FREE,
        )

        recorded_call = mock_replicate_client.generate.call_args
        assert recorded_call is not None

        # Model type should be Llama for free tier.
        # NOTE(review): the check below is skipped when no truthy 'model'
        # keyword was passed, so this test can pass without verifying the
        # model — confirm how the generator hands the model to generate().
        model_arg = recorded_call.kwargs.get('model')
        if model_arg:
            looks_like_llama = "llama" in str(model_arg).lower()
            assert looks_like_llama or model_arg == ModelType.LLAMA_3_8B

    def test_different_contexts_use_different_temperatures(self, generator):
        """Story progression is configured with a higher temperature than quest selection."""
        selector = generator.model_selector

        # Get model configs for the two contexts on the same tier
        story_config = selector.select_model(
            UserTier.FREE, ContextType.STORY_PROGRESSION
        )
        quest_config = selector.select_model(
            UserTier.FREE, ContextType.QUEST_SELECTION
        )

        # Story should have higher temperature (more creative)
        assert story_config.temperature > quest_config.temperature
|
|
|
|
|
|
class TestErrorHandling:
    """Tests for error handling behavior."""

    def test_template_error_handling(self, mock_replicate_client):
        """A PromptTemplateError from rendering surfaces as NarrativeGeneratorError."""
        from app.ai.prompt_templates import PromptTemplateError

        # Make PromptTemplates.render raise for the duration of the call.
        render_failure = PromptTemplateError("Template not found")
        with patch.object(PromptTemplates, 'render', side_effect=render_failure):
            narrator = NarrativeGenerator(replicate_client=mock_replicate_client)

            with pytest.raises(NarrativeGeneratorError) as exc_info:
                narrator.generate_story_response(
                    character={"name": "Test"},
                    action="Test",
                    game_state={"current_location": "Test"},
                    user_tier=UserTier.FREE,
                )

        failure_text = str(exc_info.value)
        assert "Prompt template error" in failure_text

    def test_api_error_handling(
        self, generator, mock_replicate_client, sample_character, sample_game_state
    ):
        """A ReplicateClientError from the client surfaces as NarrativeGeneratorError."""
        connection_failure = ReplicateClientError("Connection failed")
        mock_replicate_client.generate.side_effect = connection_failure

        with pytest.raises(NarrativeGeneratorError) as exc_info:
            generator.generate_story_response(
                character=sample_character,
                action="Test",
                game_state=sample_game_state,
                user_tier=UserTier.FREE,
            )

        failure_text = str(exc_info.value)
        assert "AI generation failed" in failure_text
|