# Listing metadata: 572 lines, 20 KiB, Python.
"""
Unit tests for AI Task Jobs.

These tests verify job enqueueing, task processing, status tracking,
and result storage and retrieval functionality.
"""

import json
from unittest.mock import MagicMock, Mock, patch

import pytest

from app.tasks.ai_tasks import (
    JobStatus,
    TaskPriority,
    TaskType,
    _store_job_result,
    _store_job_status,
    _update_job_status,
    enqueue_ai_task,
    get_job_result,
    get_job_status,
    process_ai_task,
)


class TestEnqueueAITask:
    """Exercise enqueue_ai_task: queue routing, status bookkeeping, validation."""

    @staticmethod
    def _wire_queue(get_queue_mock, job_id=None):
        """Point the patched ``get_queue`` at a fresh queue mock and return it.

        The queue's ``enqueue`` returns a job mock; when ``job_id`` is given
        it becomes that job's ``id`` attribute.
        """
        job = MagicMock()
        if job_id is not None:
            job.id = job_id
        queue = MagicMock()
        queue.enqueue.return_value = job
        get_queue_mock.return_value = queue
        return queue

    @patch('app.tasks.ai_tasks.get_queue')
    @patch('app.tasks.ai_tasks._store_job_status')
    def test_enqueue_narrative_task(self, mock_store_status, mock_get_queue):
        """A normal-priority narrative task is queued and its status recorded."""
        queue = self._wire_queue(mock_get_queue, job_id="test_job_123")

        outcome = enqueue_ai_task(
            task_type="narrative",
            user_id="user_123",
            context={"action": "explore"},
            priority="normal",
        )

        assert "job_id" in outcome
        assert outcome["status"] == "queued"
        queue.enqueue.assert_called_once()
        mock_store_status.assert_called_once()

    @patch('app.tasks.ai_tasks.get_queue')
    @patch('app.tasks.ai_tasks._store_job_status')
    def test_enqueue_high_priority_task(self, mock_store_status, mock_get_queue):
        """High-priority work is pushed to the front of the queue."""
        queue = self._wire_queue(mock_get_queue)

        enqueue_ai_task(
            task_type="narrative",
            user_id="user_123",
            context={},
            priority="high",
        )

        # call_args unpacks into (positional args, keyword args).
        _, enqueue_kwargs = queue.enqueue.call_args
        assert enqueue_kwargs["at_front"] is True

    @patch('app.tasks.ai_tasks.get_queue')
    @patch('app.tasks.ai_tasks._store_job_status')
    def test_enqueue_with_session_and_character(self, mock_store_status, mock_get_queue):
        """Session and character IDs are forwarded to the queued job."""
        queue = self._wire_queue(mock_get_queue)

        outcome = enqueue_ai_task(
            task_type="combat",
            user_id="user_123",
            context={"enemy": "goblin"},
            session_id="sess_456",
            character_id="char_789",
        )

        assert "job_id" in outcome

        # The job function's own kwargs travel inside enqueue(kwargs=...).
        forwarded = queue.enqueue.call_args[1]["kwargs"]
        assert forwarded["session_id"] == "sess_456"
        assert forwarded["character_id"] == "char_789"

    def test_enqueue_invalid_task_type(self):
        """An unknown task type is rejected with ValueError."""
        with pytest.raises(ValueError, match="Invalid task_type"):
            enqueue_ai_task(
                task_type="invalid_type",
                user_id="user_123",
                context={},
            )

    def test_enqueue_invalid_priority(self):
        """An unknown priority is rejected with ValueError."""
        with pytest.raises(ValueError, match="Invalid priority"):
            enqueue_ai_task(
                task_type="narrative",
                user_id="user_123",
                context={},
                priority="invalid_priority",
            )


class TestProcessAITask:
    """Test AI task processing with mocked NarrativeGenerator."""

    @staticmethod
    def _recorded_status(recorded_call):
        """Return the status argument from one recorded ``_update_job_status`` call.

        The status may arrive positionally (second argument) or by keyword.
        NOTE(review): the keyword name ``status`` is assumed to match the
        helper's signature in app.tasks.ai_tasks -- confirm.
        """
        args, kwargs = recorded_call
        if len(args) > 1:
            return args[1]
        return kwargs.get("status")

    @pytest.fixture
    def sample_character(self):
        """Sample character data for tests."""
        return {
            "name": "Aldric",
            "level": 3,
            "player_class": "Fighter",
            "current_hp": 25,
            "max_hp": 30,
            "stats": {"strength": 16, "dexterity": 12},
            "skills": [],
            "effects": [],
        }

    @pytest.fixture
    def sample_game_state(self):
        """Sample game state for tests."""
        return {
            "current_location": "Tavern",
            "location_type": "TAVERN",
            "discovered_locations": [],
            "active_quests": [],
        }

    @pytest.fixture
    def mock_narrative_response(self):
        """Canned NarrativeResponse returned by the mocked generator."""
        from app.ai.narrative_generator import NarrativeResponse

        return NarrativeResponse(
            narrative="You enter the tavern...",
            tokens_used=150,
            model="meta/meta-llama-3-8b-instruct",
            context_type="story_progression",
            generation_time=2.5,
        )

    @patch('app.tasks.ai_tasks._get_user_tier')
    @patch('app.tasks.ai_tasks._update_game_session')
    @patch('app.tasks.ai_tasks.NarrativeGenerator')
    @patch('app.tasks.ai_tasks._update_job_status')
    @patch('app.tasks.ai_tasks._store_job_result')
    def test_process_narrative_task(
        self, mock_store_result, mock_update_status, mock_generator_class,
        mock_update_session, mock_get_tier, sample_character, sample_game_state,
        mock_narrative_response,
    ):
        """Test processing a narrative task with NarrativeGenerator."""
        from app.ai.model_selector import UserTier

        # Setup mocks
        mock_get_tier.return_value = UserTier.FREE
        mock_generator = MagicMock()
        mock_generator.generate_story_response.return_value = mock_narrative_response
        mock_generator_class.return_value = mock_generator

        # Process task
        result = process_ai_task(
            task_type="narrative",
            user_id="user_123",
            context={
                "action": "I explore the tavern",
                "character": sample_character,
                "game_state": sample_game_state,
            },
            job_id="job_123",
            session_id="sess_456",
            character_id="char_789",
        )

        # Verify result structure
        assert "narrative" in result
        assert result["narrative"] == "You enter the tavern..."
        assert result["tokens_used"] == 150
        assert result["model"] == "meta/meta-llama-3-8b-instruct"

        # Verify generator was called
        mock_generator.generate_story_response.assert_called_once()

        # A session_id was supplied, so the game session should be updated.
        mock_update_session.assert_called_once()

    @patch('app.tasks.ai_tasks._get_user_tier')
    @patch('app.tasks.ai_tasks.NarrativeGenerator')
    @patch('app.tasks.ai_tasks._update_job_status')
    @patch('app.tasks.ai_tasks._store_job_result')
    def test_process_combat_task(
        self, mock_store_result, mock_update_status, mock_generator_class,
        mock_get_tier, sample_character, mock_narrative_response,
    ):
        """Test processing a combat task."""
        from app.ai.model_selector import UserTier

        mock_get_tier.return_value = UserTier.BASIC
        mock_generator = MagicMock()
        mock_generator.generate_combat_narration.return_value = mock_narrative_response
        mock_generator_class.return_value = mock_generator

        result = process_ai_task(
            task_type="combat",
            user_id="user_123",
            context={
                "character": sample_character,
                "combat_state": {"round_number": 1, "enemies": [], "current_turn": "player"},
                "action": "swings sword",
                "action_result": {"hit": True, "damage": 10},
            },
            job_id="job_123",
        )

        assert "combat_narrative" in result
        mock_generator.generate_combat_narration.assert_called_once()

    @patch('app.tasks.ai_tasks._get_user_tier')
    @patch('app.tasks.ai_tasks.NarrativeGenerator')
    @patch('app.tasks.ai_tasks._update_job_status')
    @patch('app.tasks.ai_tasks._store_job_result')
    def test_process_quest_selection_task(
        self, mock_store_result, mock_update_status, mock_generator_class,
        mock_get_tier, sample_character, sample_game_state,
    ):
        """Test processing a quest selection task."""
        from app.ai.model_selector import UserTier

        mock_get_tier.return_value = UserTier.FREE
        mock_generator = MagicMock()
        mock_generator.generate_quest_selection.return_value = "goblin_cave"
        mock_generator_class.return_value = mock_generator

        result = process_ai_task(
            task_type="quest_selection",
            user_id="user_123",
            context={
                "character": sample_character,
                "eligible_quests": [{"quest_id": "goblin_cave"}],
                "game_context": sample_game_state,
            },
            job_id="job_123",
        )

        assert result["selected_quest_id"] == "goblin_cave"
        mock_generator.generate_quest_selection.assert_called_once()

    @patch('app.tasks.ai_tasks._get_user_tier')
    @patch('app.tasks.ai_tasks.NarrativeGenerator')
    @patch('app.tasks.ai_tasks._update_job_status')
    @patch('app.tasks.ai_tasks._store_job_result')
    def test_process_npc_dialogue_task(
        self, mock_store_result, mock_update_status, mock_generator_class,
        mock_get_tier, sample_character, sample_game_state, mock_narrative_response,
    ):
        """Test processing an NPC dialogue task."""
        from app.ai.model_selector import UserTier

        mock_get_tier.return_value = UserTier.PREMIUM
        mock_generator = MagicMock()
        mock_generator.generate_npc_dialogue.return_value = mock_narrative_response
        mock_generator_class.return_value = mock_generator

        result = process_ai_task(
            task_type="npc_dialogue",
            user_id="user_123",
            context={
                "character": sample_character,
                "npc": {"name": "Innkeeper", "role": "Barkeeper"},
                "conversation_topic": "What's the news?",
                "game_state": sample_game_state,
            },
            job_id="job_123",
        )

        assert "dialogue" in result
        mock_generator.generate_npc_dialogue.assert_called_once()

    @patch('app.tasks.ai_tasks._get_user_tier')
    @patch('app.tasks.ai_tasks.NarrativeGenerator')
    @patch('app.tasks.ai_tasks._update_job_status')
    def test_process_task_failure_invalid_type(
        self, mock_update_status, mock_generator_class, mock_get_tier
    ):
        """Test task processing failure with invalid task type."""
        from app.ai.model_selector import UserTier

        # Patched like the sibling tests so no real tier lookup or
        # generator construction can happen before validation fails.
        mock_get_tier.return_value = UserTier.FREE

        with pytest.raises(ValueError):
            process_ai_task(
                task_type="invalid",
                user_id="user_123",
                context={},
                job_id="job_123",
            )

        # The failure must be recorded as FAILED in the job status store,
        # whether the status was passed positionally or by keyword.
        recorded = [
            self._recorded_status(c) for c in mock_update_status.call_args_list
        ]
        assert JobStatus.FAILED in recorded

    @patch('app.tasks.ai_tasks._get_user_tier')
    @patch('app.tasks.ai_tasks.NarrativeGenerator')
    @patch('app.tasks.ai_tasks._update_job_status')
    def test_process_task_missing_context_fields(
        self, mock_update_status, mock_generator_class, mock_get_tier
    ):
        """Test task processing failure with missing required context fields."""
        from app.ai.model_selector import UserTier

        mock_get_tier.return_value = UserTier.FREE

        with pytest.raises(ValueError) as exc_info:
            process_ai_task(
                task_type="narrative",
                user_id="user_123",
                context={"action": "test"},  # Missing character and game_state
                job_id="job_123",
            )

        assert "Missing required context field" in str(exc_info.value)


class TestJobStatus:
    """Cover get_job_status: Redis cache hit, RQ fallback, and unknown jobs."""

    @staticmethod
    def _redis_returning(redis_class_mock, payload):
        """Make each patched RedisService() instance answer get_json with ``payload``."""
        redis = MagicMock()
        redis.get_json.return_value = payload
        redis_class_mock.return_value = redis
        return redis

    @patch('app.tasks.ai_tasks.RedisService')
    def test_get_job_status_from_cache(self, mock_redis_class):
        """A cached status record is returned as-is."""
        cached = {
            "job_id": "job_123",
            "status": "completed",
            "created_at": "2025-11-21T10:00:00Z",
        }
        self._redis_returning(mock_redis_class, cached)

        status = get_job_status("job_123")

        assert status["job_id"] == "job_123"
        assert status["status"] == "completed"

    @patch('app.tasks.ai_tasks.RedisService')
    @patch('app.tasks.ai_tasks.get_redis_connection')
    @patch('app.tasks.ai_tasks.Job')
    def test_get_job_status_from_rq(self, mock_job_class, mock_get_conn, mock_redis_class):
        """With a cache miss, status is derived from the RQ job's flags."""
        self._redis_returning(mock_redis_class, None)

        rq_job = MagicMock(
            is_finished=True,
            is_failed=False,
            is_started=False,
            created_at=None,
            started_at=None,
            ended_at=None,
        )
        mock_job_class.fetch.return_value = rq_job

        assert get_job_status("job_123")["status"] == "completed"

    @patch('app.tasks.ai_tasks.RedisService')
    @patch('app.tasks.ai_tasks.get_redis_connection')
    @patch('app.tasks.ai_tasks.Job')
    def test_get_job_status_not_found(self, mock_job_class, mock_get_conn, mock_redis_class):
        """A job that can be neither found in cache nor fetched reports 'unknown'."""
        self._redis_returning(mock_redis_class, None)
        mock_job_class.fetch.side_effect = Exception("Job not found")

        status = get_job_status("nonexistent_job")

        assert status["status"] == "unknown"
        assert "error" in status


class TestJobResult:
    """Cover get_job_result: Redis cache hit, RQ fallback, and missing jobs."""

    @staticmethod
    def _cache_holds(redis_class_mock, payload):
        """Make each patched RedisService() instance answer get_json with ``payload``."""
        redis = MagicMock()
        redis.get_json.return_value = payload
        redis_class_mock.return_value = redis
        return redis

    @patch('app.tasks.ai_tasks.RedisService')
    def test_get_job_result_from_cache(self, mock_redis_class):
        """A cached result is returned unchanged."""
        cached = {
            "narrative": "You enter the tavern...",
            "tokens_used": 450,
        }
        self._cache_holds(mock_redis_class, cached)

        assert get_job_result("job_123") == cached

    @patch('app.tasks.ai_tasks.RedisService')
    @patch('app.tasks.ai_tasks.get_redis_connection')
    @patch('app.tasks.ai_tasks.Job')
    def test_get_job_result_from_rq(self, mock_job_class, mock_get_conn, mock_redis_class):
        """With a cache miss, a finished RQ job's result is returned."""
        finished = {"narrative": "You find a sword..."}
        self._cache_holds(mock_redis_class, None)

        mock_job_class.fetch.return_value = MagicMock(is_finished=True, result=finished)

        assert get_job_result("job_123") == finished

    @patch('app.tasks.ai_tasks.RedisService')
    @patch('app.tasks.ai_tasks.get_redis_connection')
    @patch('app.tasks.ai_tasks.Job')
    def test_get_job_result_not_found(self, mock_job_class, mock_get_conn, mock_redis_class):
        """A job that cannot be located yields no result."""
        self._cache_holds(mock_redis_class, None)
        mock_job_class.fetch.side_effect = Exception("Job not found")

        assert get_job_result("nonexistent_job") is None


class TestStatusStorage:
    """Test job status storage functions."""

    @patch('app.tasks.ai_tasks.RedisService')
    def test_store_job_status(self, mock_redis_class):
        """Test storing initial job status."""
        mock_redis = MagicMock()
        mock_redis_class.return_value = mock_redis

        _store_job_status(
            job_id="job_123",
            status=JobStatus.QUEUED,
            task_type="narrative",
            user_id="user_456",
        )

        mock_redis.set_json.assert_called_once()
        call_args = mock_redis.set_json.call_args
        stored_data = call_args[0][1]

        assert stored_data["job_id"] == "job_123"
        assert stored_data["status"] == "queued"
        assert stored_data["task_type"] == "narrative"
        assert stored_data["user_id"] == "user_456"

    @patch('app.tasks.ai_tasks.RedisService')
    def test_update_job_status_to_processing(self, mock_redis_class):
        """Test updating job status to processing."""
        mock_redis = MagicMock()
        mock_redis.get_json.return_value = {
            "job_id": "job_123",
            "status": "queued",
            "created_at": "2025-11-21T10:00:00Z",
        }
        mock_redis_class.return_value = mock_redis

        _update_job_status("job_123", JobStatus.PROCESSING)

        mock_redis.set_json.assert_called_once()
        call_args = mock_redis.set_json.call_args
        updated_data = call_args[0][1]

        assert updated_data["status"] == "processing"
        assert updated_data["started_at"] is not None

    @patch('app.tasks.ai_tasks.RedisService')
    def test_update_job_status_to_completed(self, mock_redis_class):
        """Test updating job status to completed with result."""
        mock_redis = MagicMock()
        mock_redis.get_json.return_value = {
            "job_id": "job_123",
            "status": "processing",
        }
        mock_redis_class.return_value = mock_redis

        result_data = {"narrative": "Test result"}
        _update_job_status("job_123", JobStatus.COMPLETED, result=result_data)

        # Fail with a clear message (not a TypeError on a None call_args)
        # if the status was never written.
        mock_redis.set_json.assert_called()
        call_args = mock_redis.set_json.call_args
        updated_data = call_args[0][1]

        assert updated_data["status"] == "completed"
        assert updated_data["completed_at"] is not None
        assert updated_data["result"] == result_data

    @patch('app.tasks.ai_tasks.RedisService')
    def test_update_job_status_to_failed(self, mock_redis_class):
        """Test updating job status to failed with error."""
        mock_redis = MagicMock()
        mock_redis.get_json.return_value = {
            "job_id": "job_123",
            "status": "processing",
        }
        mock_redis_class.return_value = mock_redis

        _update_job_status("job_123", JobStatus.FAILED, error="Something went wrong")

        # Fail with a clear message (not a TypeError on a None call_args)
        # if the status was never written.
        mock_redis.set_json.assert_called()
        call_args = mock_redis.set_json.call_args
        updated_data = call_args[0][1]

        assert updated_data["status"] == "failed"
        assert updated_data["error"] == "Something went wrong"

    @patch('app.tasks.ai_tasks.RedisService')
    def test_store_job_result(self, mock_redis_class):
        """Test storing job result."""
        mock_redis = MagicMock()
        mock_redis_class.return_value = mock_redis

        result_data = {
            "narrative": "You discover a hidden passage...",
            "tokens_used": 500,
        }

        _store_job_result("job_123", result_data)

        mock_redis.set_json.assert_called_once()
        call_args = mock_redis.set_json.call_args

        assert call_args[0][0] == "ai_job_result:job_123"
        assert call_args[0][1] == result_data
        assert call_args[1]["ttl"] == 3600  # 1 hour


class TestTaskTypes:
    """Pin the string values of the task-type, priority, and job-status enums."""

    def test_task_type_values(self):
        """Every TaskType member serializes to its expected string."""
        expected = (
            (TaskType.NARRATIVE, "narrative"),
            (TaskType.COMBAT, "combat"),
            (TaskType.QUEST_SELECTION, "quest_selection"),
            (TaskType.NPC_DIALOGUE, "npc_dialogue"),
        )
        for member, text in expected:
            assert member.value == text

    def test_task_priority_values(self):
        """Every TaskPriority member serializes to its expected string."""
        expected = (
            (TaskPriority.LOW, "low"),
            (TaskPriority.NORMAL, "normal"),
            (TaskPriority.HIGH, "high"),
        )
        for member, text in expected:
            assert member.value == text

    def test_job_status_values(self):
        """Every JobStatus member serializes to its expected string."""
        expected = (
            (JobStatus.QUEUED, "queued"),
            (JobStatus.PROCESSING, "processing"),
            (JobStatus.COMPLETED, "completed"),
            (JobStatus.FAILED, "failed"),
        )
        for member, text in expected:
            assert member.value == text