first commit
This commit is contained in:
571
api/tests/test_ai_tasks.py
Normal file
571
api/tests/test_ai_tasks.py
Normal file
@@ -0,0 +1,571 @@
|
||||
"""
|
||||
Unit tests for AI Task Jobs.
|
||||
|
||||
These tests verify the job enqueueing, status tracking,
|
||||
and result storage functionality.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock, patch, MagicMock
|
||||
import json
|
||||
|
||||
from app.tasks.ai_tasks import (
|
||||
enqueue_ai_task,
|
||||
process_ai_task,
|
||||
get_job_status,
|
||||
get_job_result,
|
||||
JobStatus,
|
||||
TaskType,
|
||||
TaskPriority,
|
||||
_store_job_status,
|
||||
_update_job_status,
|
||||
_store_job_result,
|
||||
)
|
||||
|
||||
|
||||
class TestEnqueueAITask:
|
||||
"""Test AI task enqueueing."""
|
||||
|
||||
@patch('app.tasks.ai_tasks.get_queue')
|
||||
@patch('app.tasks.ai_tasks._store_job_status')
|
||||
def test_enqueue_narrative_task(self, mock_store_status, mock_get_queue):
|
||||
"""Test enqueueing a narrative task."""
|
||||
# Setup mock queue
|
||||
mock_queue = MagicMock()
|
||||
mock_job = MagicMock()
|
||||
mock_job.id = "test_job_123"
|
||||
mock_queue.enqueue.return_value = mock_job
|
||||
mock_get_queue.return_value = mock_queue
|
||||
|
||||
# Enqueue task
|
||||
result = enqueue_ai_task(
|
||||
task_type="narrative",
|
||||
user_id="user_123",
|
||||
context={"action": "explore"},
|
||||
priority="normal",
|
||||
)
|
||||
|
||||
# Verify
|
||||
assert "job_id" in result
|
||||
assert result["status"] == "queued"
|
||||
mock_queue.enqueue.assert_called_once()
|
||||
mock_store_status.assert_called_once()
|
||||
|
||||
@patch('app.tasks.ai_tasks.get_queue')
|
||||
@patch('app.tasks.ai_tasks._store_job_status')
|
||||
def test_enqueue_high_priority_task(self, mock_store_status, mock_get_queue):
|
||||
"""Test high priority task goes to front of queue."""
|
||||
mock_queue = MagicMock()
|
||||
mock_job = MagicMock()
|
||||
mock_queue.enqueue.return_value = mock_job
|
||||
mock_get_queue.return_value = mock_queue
|
||||
|
||||
enqueue_ai_task(
|
||||
task_type="narrative",
|
||||
user_id="user_123",
|
||||
context={},
|
||||
priority="high",
|
||||
)
|
||||
|
||||
# Verify at_front=True for high priority
|
||||
call_kwargs = mock_queue.enqueue.call_args[1]
|
||||
assert call_kwargs["at_front"] is True
|
||||
|
||||
@patch('app.tasks.ai_tasks.get_queue')
|
||||
@patch('app.tasks.ai_tasks._store_job_status')
|
||||
def test_enqueue_with_session_and_character(self, mock_store_status, mock_get_queue):
|
||||
"""Test enqueueing with session and character IDs."""
|
||||
mock_queue = MagicMock()
|
||||
mock_job = MagicMock()
|
||||
mock_queue.enqueue.return_value = mock_job
|
||||
mock_get_queue.return_value = mock_queue
|
||||
|
||||
result = enqueue_ai_task(
|
||||
task_type="combat",
|
||||
user_id="user_123",
|
||||
context={"enemy": "goblin"},
|
||||
session_id="sess_456",
|
||||
character_id="char_789",
|
||||
)
|
||||
|
||||
assert "job_id" in result
|
||||
|
||||
# Verify kwargs passed to enqueue include session and character
|
||||
call_kwargs = mock_queue.enqueue.call_args[1]
|
||||
job_kwargs = call_kwargs["kwargs"]
|
||||
assert job_kwargs["session_id"] == "sess_456"
|
||||
assert job_kwargs["character_id"] == "char_789"
|
||||
|
||||
def test_enqueue_invalid_task_type(self):
|
||||
"""Test enqueueing with invalid task type raises error."""
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
enqueue_ai_task(
|
||||
task_type="invalid_type",
|
||||
user_id="user_123",
|
||||
context={},
|
||||
)
|
||||
|
||||
assert "Invalid task_type" in str(exc_info.value)
|
||||
|
||||
def test_enqueue_invalid_priority(self):
|
||||
"""Test enqueueing with invalid priority raises error."""
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
enqueue_ai_task(
|
||||
task_type="narrative",
|
||||
user_id="user_123",
|
||||
context={},
|
||||
priority="invalid_priority",
|
||||
)
|
||||
|
||||
assert "Invalid priority" in str(exc_info.value)
|
||||
|
||||
|
||||
class TestProcessAITask:
|
||||
"""Test AI task processing with mocked NarrativeGenerator."""
|
||||
|
||||
@pytest.fixture
|
||||
def sample_character(self):
|
||||
"""Sample character data for tests."""
|
||||
return {
|
||||
"name": "Aldric",
|
||||
"level": 3,
|
||||
"player_class": "Fighter",
|
||||
"current_hp": 25,
|
||||
"max_hp": 30,
|
||||
"stats": {"strength": 16, "dexterity": 12},
|
||||
"skills": [],
|
||||
"effects": []
|
||||
}
|
||||
|
||||
@pytest.fixture
|
||||
def sample_game_state(self):
|
||||
"""Sample game state for tests."""
|
||||
return {
|
||||
"current_location": "Tavern",
|
||||
"location_type": "TAVERN",
|
||||
"discovered_locations": [],
|
||||
"active_quests": []
|
||||
}
|
||||
|
||||
@pytest.fixture
|
||||
def mock_narrative_response(self):
|
||||
"""Mock NarrativeResponse for tests."""
|
||||
from app.ai.narrative_generator import NarrativeResponse
|
||||
return NarrativeResponse(
|
||||
narrative="You enter the tavern...",
|
||||
tokens_used=150,
|
||||
model="meta/meta-llama-3-8b-instruct",
|
||||
context_type="story_progression",
|
||||
generation_time=2.5
|
||||
)
|
||||
|
||||
@patch('app.tasks.ai_tasks._get_user_tier')
|
||||
@patch('app.tasks.ai_tasks._update_game_session')
|
||||
@patch('app.tasks.ai_tasks.NarrativeGenerator')
|
||||
@patch('app.tasks.ai_tasks._update_job_status')
|
||||
@patch('app.tasks.ai_tasks._store_job_result')
|
||||
def test_process_narrative_task(
|
||||
self, mock_store_result, mock_update_status, mock_generator_class,
|
||||
mock_update_session, mock_get_tier, sample_character, sample_game_state, mock_narrative_response
|
||||
):
|
||||
"""Test processing a narrative task with NarrativeGenerator."""
|
||||
from app.ai.model_selector import UserTier
|
||||
|
||||
# Setup mocks
|
||||
mock_get_tier.return_value = UserTier.FREE
|
||||
mock_generator = MagicMock()
|
||||
mock_generator.generate_story_response.return_value = mock_narrative_response
|
||||
mock_generator_class.return_value = mock_generator
|
||||
|
||||
# Process task
|
||||
result = process_ai_task(
|
||||
task_type="narrative",
|
||||
user_id="user_123",
|
||||
context={
|
||||
"action": "I explore the tavern",
|
||||
"character": sample_character,
|
||||
"game_state": sample_game_state
|
||||
},
|
||||
job_id="job_123",
|
||||
session_id="sess_456",
|
||||
character_id="char_789"
|
||||
)
|
||||
|
||||
# Verify result structure
|
||||
assert "narrative" in result
|
||||
assert result["narrative"] == "You enter the tavern..."
|
||||
assert result["tokens_used"] == 150
|
||||
assert result["model"] == "meta/meta-llama-3-8b-instruct"
|
||||
|
||||
# Verify generator was called
|
||||
mock_generator.generate_story_response.assert_called_once()
|
||||
|
||||
# Verify session update was called
|
||||
mock_update_session.assert_called_once()
|
||||
|
||||
@patch('app.tasks.ai_tasks._get_user_tier')
|
||||
@patch('app.tasks.ai_tasks.NarrativeGenerator')
|
||||
@patch('app.tasks.ai_tasks._update_job_status')
|
||||
@patch('app.tasks.ai_tasks._store_job_result')
|
||||
def test_process_combat_task(
|
||||
self, mock_store_result, mock_update_status, mock_generator_class,
|
||||
mock_get_tier, sample_character, mock_narrative_response
|
||||
):
|
||||
"""Test processing a combat task."""
|
||||
from app.ai.model_selector import UserTier
|
||||
|
||||
mock_get_tier.return_value = UserTier.BASIC
|
||||
mock_generator = MagicMock()
|
||||
mock_generator.generate_combat_narration.return_value = mock_narrative_response
|
||||
mock_generator_class.return_value = mock_generator
|
||||
|
||||
result = process_ai_task(
|
||||
task_type="combat",
|
||||
user_id="user_123",
|
||||
context={
|
||||
"character": sample_character,
|
||||
"combat_state": {"round_number": 1, "enemies": [], "current_turn": "player"},
|
||||
"action": "swings sword",
|
||||
"action_result": {"hit": True, "damage": 10}
|
||||
},
|
||||
job_id="job_123",
|
||||
)
|
||||
|
||||
assert "combat_narrative" in result
|
||||
mock_generator.generate_combat_narration.assert_called_once()
|
||||
|
||||
@patch('app.tasks.ai_tasks._get_user_tier')
|
||||
@patch('app.tasks.ai_tasks.NarrativeGenerator')
|
||||
@patch('app.tasks.ai_tasks._update_job_status')
|
||||
@patch('app.tasks.ai_tasks._store_job_result')
|
||||
def test_process_quest_selection_task(
|
||||
self, mock_store_result, mock_update_status, mock_generator_class,
|
||||
mock_get_tier, sample_character, sample_game_state
|
||||
):
|
||||
"""Test processing a quest selection task."""
|
||||
from app.ai.model_selector import UserTier
|
||||
|
||||
mock_get_tier.return_value = UserTier.FREE
|
||||
mock_generator = MagicMock()
|
||||
mock_generator.generate_quest_selection.return_value = "goblin_cave"
|
||||
mock_generator_class.return_value = mock_generator
|
||||
|
||||
result = process_ai_task(
|
||||
task_type="quest_selection",
|
||||
user_id="user_123",
|
||||
context={
|
||||
"character": sample_character,
|
||||
"eligible_quests": [{"quest_id": "goblin_cave"}],
|
||||
"game_context": sample_game_state
|
||||
},
|
||||
job_id="job_123",
|
||||
)
|
||||
|
||||
assert result["selected_quest_id"] == "goblin_cave"
|
||||
mock_generator.generate_quest_selection.assert_called_once()
|
||||
|
||||
@patch('app.tasks.ai_tasks._get_user_tier')
|
||||
@patch('app.tasks.ai_tasks.NarrativeGenerator')
|
||||
@patch('app.tasks.ai_tasks._update_job_status')
|
||||
@patch('app.tasks.ai_tasks._store_job_result')
|
||||
def test_process_npc_dialogue_task(
|
||||
self, mock_store_result, mock_update_status, mock_generator_class,
|
||||
mock_get_tier, sample_character, sample_game_state, mock_narrative_response
|
||||
):
|
||||
"""Test processing an NPC dialogue task."""
|
||||
from app.ai.model_selector import UserTier
|
||||
|
||||
mock_get_tier.return_value = UserTier.PREMIUM
|
||||
mock_generator = MagicMock()
|
||||
mock_generator.generate_npc_dialogue.return_value = mock_narrative_response
|
||||
mock_generator_class.return_value = mock_generator
|
||||
|
||||
result = process_ai_task(
|
||||
task_type="npc_dialogue",
|
||||
user_id="user_123",
|
||||
context={
|
||||
"character": sample_character,
|
||||
"npc": {"name": "Innkeeper", "role": "Barkeeper"},
|
||||
"conversation_topic": "What's the news?",
|
||||
"game_state": sample_game_state
|
||||
},
|
||||
job_id="job_123",
|
||||
)
|
||||
|
||||
assert "dialogue" in result
|
||||
mock_generator.generate_npc_dialogue.assert_called_once()
|
||||
|
||||
@patch('app.tasks.ai_tasks._update_job_status')
|
||||
def test_process_task_failure_invalid_type(self, mock_update_status):
|
||||
"""Test task processing failure with invalid task type."""
|
||||
with pytest.raises(ValueError):
|
||||
process_ai_task(
|
||||
task_type="invalid",
|
||||
user_id="user_123",
|
||||
context={},
|
||||
job_id="job_123",
|
||||
)
|
||||
|
||||
# Verify status was updated to PROCESSING then FAILED
|
||||
calls = mock_update_status.call_args_list
|
||||
assert any(call[0][1] == JobStatus.FAILED for call in calls)
|
||||
|
||||
@patch('app.tasks.ai_tasks._get_user_tier')
|
||||
@patch('app.tasks.ai_tasks._update_job_status')
|
||||
def test_process_task_missing_context_fields(self, mock_update_status, mock_get_tier):
|
||||
"""Test task processing failure with missing required context fields."""
|
||||
from app.ai.model_selector import UserTier
|
||||
mock_get_tier.return_value = UserTier.FREE
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
process_ai_task(
|
||||
task_type="narrative",
|
||||
user_id="user_123",
|
||||
context={"action": "test"}, # Missing character and game_state
|
||||
job_id="job_123",
|
||||
)
|
||||
|
||||
assert "Missing required context field" in str(exc_info.value)
|
||||
|
||||
|
||||
class TestJobStatus:
|
||||
"""Test job status retrieval."""
|
||||
|
||||
@patch('app.tasks.ai_tasks.RedisService')
|
||||
def test_get_job_status_from_cache(self, mock_redis_class):
|
||||
"""Test getting job status from Redis cache."""
|
||||
mock_redis = MagicMock()
|
||||
mock_redis.get_json.return_value = {
|
||||
"job_id": "job_123",
|
||||
"status": "completed",
|
||||
"created_at": "2025-11-21T10:00:00Z",
|
||||
}
|
||||
mock_redis_class.return_value = mock_redis
|
||||
|
||||
result = get_job_status("job_123")
|
||||
|
||||
assert result["job_id"] == "job_123"
|
||||
assert result["status"] == "completed"
|
||||
|
||||
@patch('app.tasks.ai_tasks.RedisService')
|
||||
@patch('app.tasks.ai_tasks.get_redis_connection')
|
||||
@patch('app.tasks.ai_tasks.Job')
|
||||
def test_get_job_status_from_rq(self, mock_job_class, mock_get_conn, mock_redis_class):
|
||||
"""Test getting job status from RQ when not in cache."""
|
||||
# No cached status
|
||||
mock_redis = MagicMock()
|
||||
mock_redis.get_json.return_value = None
|
||||
mock_redis_class.return_value = mock_redis
|
||||
|
||||
# RQ job
|
||||
mock_job = MagicMock()
|
||||
mock_job.is_finished = True
|
||||
mock_job.is_failed = False
|
||||
mock_job.is_started = False
|
||||
mock_job.created_at = None
|
||||
mock_job.started_at = None
|
||||
mock_job.ended_at = None
|
||||
mock_job_class.fetch.return_value = mock_job
|
||||
|
||||
result = get_job_status("job_123")
|
||||
|
||||
assert result["status"] == "completed"
|
||||
|
||||
@patch('app.tasks.ai_tasks.RedisService')
|
||||
@patch('app.tasks.ai_tasks.get_redis_connection')
|
||||
@patch('app.tasks.ai_tasks.Job')
|
||||
def test_get_job_status_not_found(self, mock_job_class, mock_get_conn, mock_redis_class):
|
||||
"""Test getting status of non-existent job."""
|
||||
mock_redis = MagicMock()
|
||||
mock_redis.get_json.return_value = None
|
||||
mock_redis_class.return_value = mock_redis
|
||||
|
||||
mock_job_class.fetch.side_effect = Exception("Job not found")
|
||||
|
||||
result = get_job_status("nonexistent_job")
|
||||
|
||||
assert result["status"] == "unknown"
|
||||
assert "error" in result
|
||||
|
||||
|
||||
class TestJobResult:
|
||||
"""Test job result retrieval."""
|
||||
|
||||
@patch('app.tasks.ai_tasks.RedisService')
|
||||
def test_get_job_result_from_cache(self, mock_redis_class):
|
||||
"""Test getting job result from Redis cache."""
|
||||
expected_result = {
|
||||
"narrative": "You enter the tavern...",
|
||||
"tokens_used": 450,
|
||||
}
|
||||
|
||||
mock_redis = MagicMock()
|
||||
mock_redis.get_json.return_value = expected_result
|
||||
mock_redis_class.return_value = mock_redis
|
||||
|
||||
result = get_job_result("job_123")
|
||||
|
||||
assert result == expected_result
|
||||
|
||||
@patch('app.tasks.ai_tasks.RedisService')
|
||||
@patch('app.tasks.ai_tasks.get_redis_connection')
|
||||
@patch('app.tasks.ai_tasks.Job')
|
||||
def test_get_job_result_from_rq(self, mock_job_class, mock_get_conn, mock_redis_class):
|
||||
"""Test getting job result from RQ when not in cache."""
|
||||
expected_result = {"narrative": "You find a sword..."}
|
||||
|
||||
mock_redis = MagicMock()
|
||||
mock_redis.get_json.return_value = None
|
||||
mock_redis_class.return_value = mock_redis
|
||||
|
||||
mock_job = MagicMock()
|
||||
mock_job.is_finished = True
|
||||
mock_job.result = expected_result
|
||||
mock_job_class.fetch.return_value = mock_job
|
||||
|
||||
result = get_job_result("job_123")
|
||||
|
||||
assert result == expected_result
|
||||
|
||||
@patch('app.tasks.ai_tasks.RedisService')
|
||||
@patch('app.tasks.ai_tasks.get_redis_connection')
|
||||
@patch('app.tasks.ai_tasks.Job')
|
||||
def test_get_job_result_not_found(self, mock_job_class, mock_get_conn, mock_redis_class):
|
||||
"""Test getting result of non-existent job."""
|
||||
mock_redis = MagicMock()
|
||||
mock_redis.get_json.return_value = None
|
||||
mock_redis_class.return_value = mock_redis
|
||||
|
||||
mock_job_class.fetch.side_effect = Exception("Job not found")
|
||||
|
||||
result = get_job_result("nonexistent_job")
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestStatusStorage:
|
||||
"""Test job status storage functions."""
|
||||
|
||||
@patch('app.tasks.ai_tasks.RedisService')
|
||||
def test_store_job_status(self, mock_redis_class):
|
||||
"""Test storing initial job status."""
|
||||
mock_redis = MagicMock()
|
||||
mock_redis_class.return_value = mock_redis
|
||||
|
||||
_store_job_status(
|
||||
job_id="job_123",
|
||||
status=JobStatus.QUEUED,
|
||||
task_type="narrative",
|
||||
user_id="user_456",
|
||||
)
|
||||
|
||||
mock_redis.set_json.assert_called_once()
|
||||
call_args = mock_redis.set_json.call_args
|
||||
stored_data = call_args[0][1]
|
||||
|
||||
assert stored_data["job_id"] == "job_123"
|
||||
assert stored_data["status"] == "queued"
|
||||
assert stored_data["task_type"] == "narrative"
|
||||
assert stored_data["user_id"] == "user_456"
|
||||
|
||||
@patch('app.tasks.ai_tasks.RedisService')
|
||||
def test_update_job_status_to_processing(self, mock_redis_class):
|
||||
"""Test updating job status to processing."""
|
||||
mock_redis = MagicMock()
|
||||
mock_redis.get_json.return_value = {
|
||||
"job_id": "job_123",
|
||||
"status": "queued",
|
||||
"created_at": "2025-11-21T10:00:00Z",
|
||||
}
|
||||
mock_redis_class.return_value = mock_redis
|
||||
|
||||
_update_job_status("job_123", JobStatus.PROCESSING)
|
||||
|
||||
mock_redis.set_json.assert_called_once()
|
||||
call_args = mock_redis.set_json.call_args
|
||||
updated_data = call_args[0][1]
|
||||
|
||||
assert updated_data["status"] == "processing"
|
||||
assert updated_data["started_at"] is not None
|
||||
|
||||
@patch('app.tasks.ai_tasks.RedisService')
|
||||
def test_update_job_status_to_completed(self, mock_redis_class):
|
||||
"""Test updating job status to completed with result."""
|
||||
mock_redis = MagicMock()
|
||||
mock_redis.get_json.return_value = {
|
||||
"job_id": "job_123",
|
||||
"status": "processing",
|
||||
}
|
||||
mock_redis_class.return_value = mock_redis
|
||||
|
||||
result_data = {"narrative": "Test result"}
|
||||
_update_job_status("job_123", JobStatus.COMPLETED, result=result_data)
|
||||
|
||||
call_args = mock_redis.set_json.call_args
|
||||
updated_data = call_args[0][1]
|
||||
|
||||
assert updated_data["status"] == "completed"
|
||||
assert updated_data["completed_at"] is not None
|
||||
assert updated_data["result"] == result_data
|
||||
|
||||
@patch('app.tasks.ai_tasks.RedisService')
|
||||
def test_update_job_status_to_failed(self, mock_redis_class):
|
||||
"""Test updating job status to failed with error."""
|
||||
mock_redis = MagicMock()
|
||||
mock_redis.get_json.return_value = {
|
||||
"job_id": "job_123",
|
||||
"status": "processing",
|
||||
}
|
||||
mock_redis_class.return_value = mock_redis
|
||||
|
||||
_update_job_status("job_123", JobStatus.FAILED, error="Something went wrong")
|
||||
|
||||
call_args = mock_redis.set_json.call_args
|
||||
updated_data = call_args[0][1]
|
||||
|
||||
assert updated_data["status"] == "failed"
|
||||
assert updated_data["error"] == "Something went wrong"
|
||||
|
||||
@patch('app.tasks.ai_tasks.RedisService')
|
||||
def test_store_job_result(self, mock_redis_class):
|
||||
"""Test storing job result."""
|
||||
mock_redis = MagicMock()
|
||||
mock_redis_class.return_value = mock_redis
|
||||
|
||||
result_data = {
|
||||
"narrative": "You discover a hidden passage...",
|
||||
"tokens_used": 500,
|
||||
}
|
||||
|
||||
_store_job_result("job_123", result_data)
|
||||
|
||||
mock_redis.set_json.assert_called_once()
|
||||
call_args = mock_redis.set_json.call_args
|
||||
|
||||
assert call_args[0][0] == "ai_job_result:job_123"
|
||||
assert call_args[0][1] == result_data
|
||||
assert call_args[1]["ttl"] == 3600 # 1 hour
|
||||
|
||||
|
||||
class TestTaskTypes:
|
||||
"""Test task type and priority enums."""
|
||||
|
||||
def test_task_type_values(self):
|
||||
"""Test all task types are defined."""
|
||||
assert TaskType.NARRATIVE.value == "narrative"
|
||||
assert TaskType.COMBAT.value == "combat"
|
||||
assert TaskType.QUEST_SELECTION.value == "quest_selection"
|
||||
assert TaskType.NPC_DIALOGUE.value == "npc_dialogue"
|
||||
|
||||
def test_task_priority_values(self):
|
||||
"""Test all priorities are defined."""
|
||||
assert TaskPriority.LOW.value == "low"
|
||||
assert TaskPriority.NORMAL.value == "normal"
|
||||
assert TaskPriority.HIGH.value == "high"
|
||||
|
||||
def test_job_status_values(self):
|
||||
"""Test all job statuses are defined."""
|
||||
assert JobStatus.QUEUED.value == "queued"
|
||||
assert JobStatus.PROCESSING.value == "processing"
|
||||
assert JobStatus.COMPLETED.value == "completed"
|
||||
assert JobStatus.FAILED.value == "failed"
|
||||
Reference in New Issue
Block a user