first commit
This commit is contained in:
124
api/scripts/README.md
Normal file
124
api/scripts/README.md
Normal file
@@ -0,0 +1,124 @@
|
||||
# Scripts Directory
|
||||
|
||||
This directory contains utility scripts for database initialization, migrations, and other maintenance tasks.
|
||||
|
||||
## Database Initialization
|
||||
|
||||
### `init_database.py`
|
||||
|
||||
Initializes all database tables in Appwrite with the correct schema, columns, and indexes.
|
||||
|
||||
**Usage:**
|
||||
|
||||
```bash
|
||||
# Ensure virtual environment is activated
|
||||
source venv/bin/activate
|
||||
|
||||
# Run the initialization script
|
||||
python scripts/init_database.py
|
||||
```
|
||||
|
||||
**Prerequisites:**
|
||||
|
||||
1. Appwrite instance running and accessible
|
||||
2. `.env` file configured with Appwrite credentials:
|
||||
- `APPWRITE_ENDPOINT`
|
||||
- `APPWRITE_PROJECT_ID`
|
||||
- `APPWRITE_API_KEY`
|
||||
- `APPWRITE_DATABASE_ID`
|
||||
|
||||
**What it does:**
|
||||
|
||||
1. Validates environment configuration
|
||||
2. Creates the following tables:
|
||||
- **characters**: Player character data with userId indexing
|
||||
|
||||
3. Creates necessary columns and indexes for efficient querying
|
||||
4. Skips tables/columns/indexes that already exist (idempotent)
|
||||
|
||||
**Output:**
|
||||
|
||||
```
|
||||
============================================================
|
||||
Code of Conquest - Database Initialization
|
||||
============================================================
|
||||
|
||||
✓ Environment variables loaded
|
||||
Endpoint: https://your-appwrite-instance.com/v1
|
||||
Project: your-project-id
|
||||
Database: main
|
||||
|
||||
Initializing database tables...
|
||||
|
||||
============================================================
|
||||
Initialization Results
|
||||
============================================================
|
||||
|
||||
✓ characters: SUCCESS
|
||||
|
||||
Total: 1 succeeded, 0 failed
|
||||
|
||||
✓ All tables initialized successfully!
|
||||
|
||||
You can now start the application.
|
||||
```
|
||||
|
||||
## Adding New Tables
|
||||
|
||||
To add a new table to the initialization process:
|
||||
|
||||
1. Open `app/services/database_init.py`
|
||||
2. Create a new method following the pattern of `init_characters_table()`
|
||||
3. Add the table initialization to `init_all_tables()` method
|
||||
4. Run the initialization script
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
def init_sessions_table(self) -> bool:
|
||||
"""Initialize the sessions table."""
|
||||
table_id = 'sessions'
|
||||
|
||||
# Create table
|
||||
table = self.tables_db.create_table(
|
||||
database_id=self.database_id,
|
||||
table_id=table_id,
|
||||
name='Sessions'
|
||||
)
|
||||
|
||||
# Create columns
|
||||
self._create_column(
|
||||
table_id=table_id,
|
||||
column_id='userId',
|
||||
column_type='string',
|
||||
size=255,
|
||||
required=True
|
||||
)
|
||||
|
||||
# Create indexes
|
||||
self._create_index(
|
||||
table_id=table_id,
|
||||
index_id='idx_userId',
|
||||
index_type='key',
|
||||
attributes=['userId']
|
||||
)
|
||||
|
||||
return True
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Missing Environment Variables
|
||||
|
||||
If you see errors about missing environment variables, ensure your `.env` file contains all required Appwrite configuration.
|
||||
|
||||
### Connection Errors
|
||||
|
||||
If the script cannot connect to Appwrite:
|
||||
- Verify the `APPWRITE_ENDPOINT` is correct and accessible
|
||||
- Check that the API key has sufficient permissions
|
||||
- Ensure the database exists in your Appwrite project
|
||||
|
||||
### Column/Index Already Exists
|
||||
|
||||
The script is idempotent and will log warnings for existing columns/indexes without failing. This is normal if you run the script multiple times.
|
||||
2
api/scripts/clear_char_daily_limit.sh
Executable file
2
api/scripts/clear_char_daily_limit.sh
Executable file
@@ -0,0 +1,2 @@
|
||||
#!/usr/bin/env bash
# Reset the daily AI-usage counter for one character via RateLimiterService.
#
# Usage: clear_char_daily_limit.sh [character_id]
# Defaults to the development/test character ID when no argument is given.
#
# NOTE: shebang uses /usr/bin/env for portability; `env` is not guaranteed
# to live in /bin on all systems.
CHAR_ID="${1:-69180281baf6d52c772d}"
python -c "from app.services.rate_limiter_service import RateLimiterService; RateLimiterService().reset_usage('${CHAR_ID}')"
|
||||
2
api/scripts/clear_worker_queues.sh
Executable file
2
api/scripts/clear_worker_queues.sh
Executable file
@@ -0,0 +1,2 @@
|
||||
#!/usr/bin/env bash
# Flush the Redis instance backing the RQ worker queues.
#
# WARNING: FLUSHALL wipes EVERY key in EVERY Redis database inside the
# coc_redis container (queues, caches, rate-limit counters), not just the
# RQ queues. Intended for local development only.
docker exec -it coc_redis redis-cli FLUSHALL
|
||||
106
api/scripts/init_database.py
Executable file
106
api/scripts/init_database.py
Executable file
@@ -0,0 +1,106 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Database Initialization Script.
|
||||
|
||||
This script initializes all database tables in Appwrite.
|
||||
Run this script once to set up the database schema before running the application.
|
||||
|
||||
Usage:
|
||||
python scripts/init_database.py
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
# Add project root to path
|
||||
project_root = Path(__file__).parent.parent
|
||||
sys.path.insert(0, str(project_root))
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from app.services.database_init import init_database
|
||||
from app.utils.logging import get_logger
|
||||
|
||||
# Load environment variables
|
||||
load_dotenv()
|
||||
|
||||
logger = get_logger(__file__)
|
||||
|
||||
|
||||
def main():
    """Validate configuration and initialize all Appwrite database tables.

    Exits with status 1 when required environment variables are missing,
    when any table fails to initialize, or when an unexpected error occurs.
    """
    banner = "=" * 60
    print(banner)
    print("Code of Conquest - Database Initialization")
    print(banner)
    print()

    # Fail fast if the Appwrite connection settings are incomplete.
    required_vars = [
        'APPWRITE_ENDPOINT',
        'APPWRITE_PROJECT_ID',
        'APPWRITE_API_KEY',
        'APPWRITE_DATABASE_ID'
    ]
    missing_vars = [var for var in required_vars if not os.getenv(var)]
    if missing_vars:
        print("❌ ERROR: Missing required environment variables:")
        for var in missing_vars:
            print(f"  - {var}")
        print()
        print("Please ensure your .env file is configured correctly.")
        sys.exit(1)

    print("✓ Environment variables loaded")
    print(f"  Endpoint: {os.getenv('APPWRITE_ENDPOINT')}")
    print(f"  Project: {os.getenv('APPWRITE_PROJECT_ID')}")
    print(f"  Database: {os.getenv('APPWRITE_DATABASE_ID')}")
    print()

    print("Initializing database tables...")
    print()

    try:
        outcome = init_database()

        print()
        print(banner)
        print("Initialization Results")
        print(banner)
        print()

        successes = 0
        failures = 0
        # Report each table individually so a partial failure is visible.
        for table_name, created in outcome.items():
            if created:
                print(f"✓ {table_name}: SUCCESS")
                successes += 1
            else:
                print(f"✗ {table_name}: FAILED")
                failures += 1

        print()
        print(f"Total: {successes} succeeded, {failures} failed")
        print()

        if failures > 0:
            print("⚠️  Some tables failed to initialize. Check logs for details.")
            sys.exit(1)

        print("✓ All tables initialized successfully!")
        print()
        print("You can now start the application.")

    except Exception as e:
        logger.error("Database initialization failed", error=str(e))
        print()
        print(f"❌ ERROR: {str(e)}")
        print()
        print("Check logs for details.")
        sys.exit(1)


if __name__ == '__main__':
    main()
|
||||
146
api/scripts/queue_info.py
Executable file
146
api/scripts/queue_info.py
Executable file
@@ -0,0 +1,146 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
RQ Queue Monitoring Utility
|
||||
|
||||
Displays information about RQ queues and their jobs.
|
||||
|
||||
Usage:
|
||||
python scripts/queue_info.py # Show all queues
|
||||
python scripts/queue_info.py --failed # Show failed jobs
|
||||
python scripts/queue_info.py --workers # Show active workers
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import sys
|
||||
import os
|
||||
|
||||
# Add parent directory to path for imports
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from redis import Redis
|
||||
from rq import Queue, Worker
|
||||
from rq.job import Job
|
||||
from rq.registry import FailedJobRegistry, StartedJobRegistry
|
||||
|
||||
from app.tasks import ALL_QUEUES, get_redis_connection, get_all_queues_info
|
||||
|
||||
|
||||
def show_queue_info():
    """Print a status summary (size, timeout, result TTL) for every queue."""
    divider = "=" * 60
    print("\n" + divider)
    print("RQ Queue Status")
    print(divider)

    for queue_info in get_all_queues_info():
        summary = [
            f"\nQueue: {queue_info['name']}",
            f"  Description: {queue_info['description']}",
            f"  Jobs in queue: {queue_info['count']}",
            f"  Default timeout: {queue_info['default_timeout']}s",
            f"  Result TTL: {queue_info['default_result_ttl']}s",
        ]
        print("\n".join(summary))
|
||||
|
||||
|
||||
def show_failed_jobs():
    """Display failed jobs from all queues.

    Shows up to 10 failed jobs per queue with the function name, failure
    time, and the final line of the captured traceback.
    """
    print("\n" + "=" * 60)
    print("Failed Jobs")
    print("=" * 60)

    conn = get_redis_connection()

    for queue_name in ALL_QUEUES:
        queue = Queue(queue_name, connection=conn)
        registry = FailedJobRegistry(queue=queue)
        job_ids = registry.get_job_ids()

        if not job_ids:
            print(f"\nQueue: {queue_name} (no failed jobs)")
            continue

        print(f"\nQueue: {queue_name} ({len(job_ids)} failed)")
        for job_id in job_ids[:10]:  # Show first 10
            try:
                job = Job.fetch(job_id, connection=conn)
            except Exception as e:
                # Job data can expire between listing the registry and
                # fetching the job; report it rather than crashing the tool.
                print(f"  - {job_id} (unavailable: {e})")
                continue

            print(f"  - {job_id}")
            print(f"    Function: {job.func_name}")
            print(f"    Failed at: {job.ended_at}")
            if job.exc_info:
                # The last non-empty line of a traceback is the exception
                # message. The previous `split('\n')[-2]` raised IndexError
                # on single-line exc_info values.
                non_empty = [ln for ln in job.exc_info.splitlines() if ln.strip()]
                exc_line = non_empty[-1] if non_empty else 'Unknown'
                print(f"    Error: {exc_line[:80]}")
|
||||
|
||||
|
||||
def show_workers():
    """List every active RQ worker with its state, queues, and current job."""
    divider = "=" * 60
    print("\n" + divider)
    print("Active Workers")
    print(divider)

    redis_conn = get_redis_connection()
    active_workers = Worker.all(connection=redis_conn)

    if not active_workers:
        print("\nNo active workers found.")
        return

    for active in active_workers:
        queue_list = ', '.join(q.name for q in active.queues)
        print(f"\nWorker: {active.name}")
        print(f"  State: {active.get_state()}")
        print(f"  Queues: {queue_list}")
        print(f"  PID: {active.pid}")

        # Only present when the worker is mid-job.
        job = active.get_current_job()
        if job:
            print(f"  Current job: {job.id}")
            print(f"    Function: {job.func_name}")
|
||||
|
||||
|
||||
def show_started_jobs():
    """List jobs currently executing, grouped by queue."""
    divider = "=" * 60
    print("\n" + divider)
    print("Running Jobs")
    print(divider)

    redis_conn = get_redis_connection()

    for name in ALL_QUEUES:
        registry = StartedJobRegistry(queue=Queue(name, connection=redis_conn))
        running_ids = registry.get_job_ids()

        if not running_ids:
            print(f"\nQueue: {name} (no running jobs)")
            continue

        print(f"\nQueue: {name} ({len(running_ids)} running)")
        for running_id in running_ids:
            job = Job.fetch(running_id, connection=redis_conn)
            print(f"  - {running_id}")
            print(f"    Function: {job.func_name}")
            print(f"    Started at: {job.started_at}")
|
||||
|
||||
|
||||
def main():
    """Parse CLI flags and print the requested queue diagnostics."""
    parser = argparse.ArgumentParser(description='RQ Queue Monitoring Utility')
    parser.add_argument('--failed', action='store_true', help='Show failed jobs')
    parser.add_argument('--workers', action='store_true', help='Show active workers')
    parser.add_argument('--running', action='store_true', help='Show running jobs')
    parser.add_argument('--all', action='store_true', help='Show all information')
    opts = parser.parse_args()

    # The queue overview is always printed; the other sections are opt-in
    # (individually or all at once via --all).
    show_queue_info()

    sections = [
        (opts.workers, show_workers),
        (opts.running, show_started_jobs),
        (opts.failed, show_failed_jobs),
    ]
    for requested, render in sections:
        if opts.all or requested:
            render()

    print("\n" + "=" * 60)
    print("Done")
    print("=" * 60 + "\n")


if __name__ == '__main__':
    main()
|
||||
76
api/scripts/setup.sh
Executable file
76
api/scripts/setup.sh
Executable file
@@ -0,0 +1,76 @@
|
||||
#!/bin/bash

# Setup script for Code of Conquest
# Run this after cloning the repository
#
# Creates the virtualenv, installs Python dependencies, bootstraps .env and
# the logs/ directory, and sanity-checks the Docker toolchain used for Redis.

set -e

echo "========================================="
echo "Code of Conquest - Setup Script"
echo "========================================="
echo ""

# Check Python version
echo "Checking Python version..."
python3 --version

# Check if virtual environment exists
if [ ! -d "venv" ]; then
    echo "Creating virtual environment..."
    python3 -m venv venv
else
    echo "Virtual environment already exists."
fi

# Activate virtual environment
echo "Activating virtual environment..."
source venv/bin/activate

# Install dependencies
echo "Installing dependencies..."
pip install --upgrade pip
pip install -r requirements.txt

# Create .env if it doesn't exist
if [ ! -f ".env" ]; then
    echo "Creating .env file from template..."
    cp .env.example .env
    echo "⚠️  Please edit .env and add your API keys!"
else
    echo ".env file already exists."
fi

# Create logs directory
mkdir -p logs

# Check Docker
echo ""
echo "Checking Docker installation..."
if command -v docker &> /dev/null; then
    echo "✓ Docker is installed"
    docker --version
else
    echo "✗ Docker is not installed. Please install Docker to run Redis locally."
fi

# Check Docker Compose: either the legacy standalone v1 binary
# (docker-compose) or the v2 plugin shipped with modern Docker
# (docker compose). Without the plugin check, current Docker installs
# were incorrectly reported as missing Compose.
if command -v docker-compose &> /dev/null; then
    echo "✓ Docker Compose is installed"
    docker-compose --version
elif docker compose version &> /dev/null; then
    echo "✓ Docker Compose (plugin) is installed"
    docker compose version
else
    echo "✗ Docker Compose is not installed."
fi

echo ""
echo "========================================="
echo "Setup complete!"
echo "========================================="
echo ""
echo "Next steps:"
echo "1. Edit .env and add your API keys"
echo "2. Follow docs/APPWRITE_SETUP.md to configure Appwrite"
echo "3. Start Redis: docker-compose up -d"
echo "4. Run the app: python wsgi.py"
echo ""
echo "For more information, see README.md"
echo ""
|
||||
98
api/scripts/start_workers.sh
Executable file
98
api/scripts/start_workers.sh
Executable file
@@ -0,0 +1,98 @@
|
||||
#!/bin/bash
# RQ Worker Startup Script
#
# This script starts RQ workers for processing background jobs.
# Workers listen on configured queues in priority order.
#
# Usage:
#   ./scripts/start_workers.sh              # Start all-queue worker
#   ./scripts/start_workers.sh ai           # Start AI-only worker
#   ./scripts/start_workers.sh combat       # Start combat-only worker
#   ./scripts/start_workers.sh marketplace  # Start marketplace-only worker
#
# Environment Variables:
#   REDIS_URL    - Redis connection URL (default: redis://localhost:6379/0)
#   LOG_LEVEL    - Logging level (default: INFO)
#   WORKER_COUNT - Number of workers to start (default: 1)

set -e

# Change to API directory so relative paths (venv, .env) resolve
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
API_DIR="$(dirname "$SCRIPT_DIR")"
cd "$API_DIR"

# Load environment variables if .env exists.
# `set -a` marks every assignment in the sourced file for export. Unlike the
# common `export $(grep -v '^#' .env | xargs)` idiom, this correctly handles
# quoted values, values containing spaces or '=', and empty files.
if [ -f .env ]; then
    set -a
    # shellcheck disable=SC1091
    source .env
    set +a
fi

# Default configuration
REDIS_URL="${REDIS_URL:-redis://localhost:6379/0}"
LOG_LEVEL="${LOG_LEVEL:-INFO}"
WORKER_COUNT="${WORKER_COUNT:-1}"

# Determine which queues to listen on
WORKER_TYPE="${1:-all}"

case "$WORKER_TYPE" in
    ai)
        QUEUES="ai_tasks"
        WORKER_NAME="ai-worker"
        ;;
    combat)
        QUEUES="combat_tasks"
        WORKER_NAME="combat-worker"
        ;;
    marketplace)
        QUEUES="marketplace_tasks"
        WORKER_NAME="marketplace-worker"
        ;;
    all|*)
        QUEUES="ai_tasks,combat_tasks,marketplace_tasks"
        WORKER_NAME="all-queues-worker"
        ;;
esac

echo "=========================================="
echo "Starting RQ Worker"
echo "=========================================="
echo "Worker Type: $WORKER_TYPE"
echo "Worker Name: $WORKER_NAME"
echo "Queues: $QUEUES"
# Mask credentials embedded in the Redis URL (user:pass@host -> :***@host)
echo "Redis URL: ${REDIS_URL//:*@/:***@}"
echo "Log Level: $LOG_LEVEL"
echo "Worker Count: $WORKER_COUNT"
echo "=========================================="

# Activate virtual environment if it exists
if [ -d "venv" ]; then
    echo "Activating virtual environment..."
    source venv/bin/activate
fi

# Start workers
if [ "$WORKER_COUNT" -eq 1 ]; then
    # Single worker: exec replaces this shell so signals reach the worker
    # directly. Only the single-worker path runs the RQ scheduler — exactly
    # one scheduler should be active per deployment.
    echo "Starting single worker..."
    exec rq worker \
        --url "$REDIS_URL" \
        --name "$WORKER_NAME" \
        --logging_level "$LOG_LEVEL" \
        --with-scheduler \
        $QUEUES
else
    # Multiple workers (use supervisord or run in background).
    # Deliberately no --with-scheduler here to avoid duplicate schedulers.
    echo "Starting $WORKER_COUNT workers..."
    for i in $(seq 1 "$WORKER_COUNT"); do
        rq worker \
            --url "$REDIS_URL" \
            --name "${WORKER_NAME}-${i}" \
            --logging_level "$LOG_LEVEL" \
            $QUEUES &
        echo "Started worker ${WORKER_NAME}-${i} (PID: $!)"
    done

    echo "All workers started. Press Ctrl+C to stop."
    wait
fi
|
||||
238
api/scripts/verify_ai_models.py
Normal file
238
api/scripts/verify_ai_models.py
Normal file
@@ -0,0 +1,238 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Verification script for Task 7.8: Verify all AI models respond correctly.
|
||||
|
||||
This script tests:
|
||||
1. Replicate client with Llama-3 8B
|
||||
2. Replicate client with Claude Haiku
|
||||
3. Replicate client with Claude Sonnet
|
||||
4. Model selector tier routing
|
||||
5. Token counting accuracy
|
||||
6. Response quality comparison
|
||||
|
||||
Usage:
|
||||
python scripts/verify_ai_models.py [--all] [--llama] [--haiku] [--sonnet] [--opus]
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
# Add parent directory to path for imports
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
# Load .env before importing app modules
|
||||
from dotenv import load_dotenv
|
||||
load_dotenv()
|
||||
|
||||
import structlog
|
||||
from app.ai import (
|
||||
ReplicateClient,
|
||||
ModelType,
|
||||
ModelSelector,
|
||||
UserTier,
|
||||
ContextType,
|
||||
ReplicateClientError,
|
||||
)
|
||||
|
||||
logger = structlog.get_logger(__name__)
|
||||
|
||||
# Test prompt for narrative generation
|
||||
TEST_PROMPT = """You are a dungeon master. The player enters a dimly lit tavern.
|
||||
Describe the scene in 2-3 sentences. Include at least one interesting NPC."""
|
||||
|
||||
TEST_SYSTEM_PROMPT = "You are a creative fantasy storyteller. Keep responses concise but vivid."
|
||||
|
||||
|
||||
def test_model(model_type: ModelType, client: ReplicateClient | None = None) -> dict:
    """
    Run the shared test prompt against one model and report the outcome.

    Args:
        model_type: The model to test.
        client: Optional existing client, otherwise creates new one.

    Returns:
        Dictionary with test results (success flag plus timing/token stats
        on success, or the error message on failure).
    """
    model_name = model_type.name
    separator = "=" * 60
    print(f"\n{separator}")
    print(f"Testing: {model_name}")
    print(f"Model ID: {model_type.value}")
    print(separator)

    try:
        # Reuse the caller's client when provided to avoid re-initialization.
        active_client = client if client is not None else ReplicateClient(model=model_type)

        started = time.time()
        response = active_client.generate(
            prompt=TEST_PROMPT,
            system_prompt=TEST_SYSTEM_PROMPT,
            model=model_type
        )
        elapsed = time.time() - started

        print("\n✅ SUCCESS")
        print(f"Response time: {elapsed:.2f}s")
        print(f"Tokens used: {response.tokens_used}")
        print(f"Response length: {len(response.text)} chars")
        print("\nGenerated text:")
        rule = "-" * 40
        print(rule)
        # Truncate very long generations for readable console output.
        preview = response.text[:500] + ("..." if len(response.text) > 500 else "")
        print(preview)
        print(rule)

        return {
            "model": model_name,
            "success": True,
            "response_time": elapsed,
            "tokens_used": response.tokens_used,
            "text_length": len(response.text),
            "text_preview": response.text[:200]
        }

    except ReplicateClientError as e:
        print(f"\n❌ FAILED: {e}")
        return {"model": model_name, "success": False, "error": str(e)}
    except Exception as e:
        print(f"\n❌ UNEXPECTED ERROR: {e}")
        return {"model": model_name, "success": False, "error": str(e)}
|
||||
|
||||
|
||||
def test_model_selector():
    """Verify tier→context→model routing without making any API calls.

    Prints the chosen model configuration and estimated per-request cost
    for each (tier, context) pair.
    """
    print(f"\n{'='*60}")
    print("Testing Model Selector")
    print(f"{'='*60}")

    selector = ModelSelector()

    # Every tier at the default context, plus premium across extra contexts.
    test_cases = [
        (UserTier.FREE, ContextType.STORY_PROGRESSION),
        (UserTier.BASIC, ContextType.STORY_PROGRESSION),
        (UserTier.PREMIUM, ContextType.STORY_PROGRESSION),
        (UserTier.ELITE, ContextType.STORY_PROGRESSION),
        (UserTier.PREMIUM, ContextType.QUEST_SELECTION),
        (UserTier.PREMIUM, ContextType.COMBAT_NARRATION),
    ]

    print("\nTier → Model Routing:")
    print("-" * 40)

    for tier, context in test_cases:
        config = selector.select_model(tier, context)
        cost = selector.estimate_cost_per_request(tier)
        # (The previous version also called selector.get_tier_info(tier)
        # per iteration but never used the result — removed as dead code.)

        print(f"{tier.value:10} + {context.value:20} → {config.model_type.name:15} "
              f"(tokens={config.max_tokens}, temp={config.temperature}, cost=${cost:.4f})")

    print("\n✅ Model selector routing verified")
|
||||
|
||||
|
||||
def run_verification(models_to_test: list[ModelType]):
    """
    Run full verification suite.

    Args:
        models_to_test: List of models to test with real API calls.
    """
    banner = "=" * 60
    print("\n" + banner)
    print("Phase 4 Task 7.8: AI Model Verification")
    print(banner)

    # Routing checks are free (no API calls) — always run them first.
    test_model_selector()

    if not models_to_test:
        print("\nNo models selected for API testing.")
        print("Use --llama, --haiku, --sonnet, --opus, or --all")
        return

    # One shared client avoids re-initializing credentials per model.
    try:
        client = ReplicateClient()
    except ReplicateClientError as e:
        print(f"\n❌ Failed to initialize Replicate client: {e}")
        print("Check REPLICATE_API_TOKEN in .env")
        return

    results = [test_model(model_type, client) for model_type in models_to_test]

    print("\n" + banner)
    print("VERIFICATION SUMMARY")
    print(banner)

    passed = sum(1 for r in results if r.get("success"))
    failed = len(results) - passed

    for outcome in results:
        name = outcome.get("model")
        if outcome.get("success"):
            elapsed = outcome.get("response_time", 0)
            token_count = outcome.get("tokens_used", 0)
            print(f"✅ {name}: {elapsed:.2f}s, {token_count} tokens")
        else:
            reason = outcome.get("error", "Unknown error")
            print(f"❌ {name}: {reason[:50]}")

    print(f"\nTotal: {passed} passed, {failed} failed")

    if failed == 0:
        print("\n✅ All verification checks passed!")
    else:
        print(f"\n⚠️  {failed} model(s) failed verification")
|
||||
|
||||
|
||||
def main():
    """CLI entry point: translate flags into a model list and verify."""
    parser = argparse.ArgumentParser(
        description="Verify AI models respond correctly through Replicate"
    )
    parser.add_argument("--all", action="store_true", help="Test all models")
    parser.add_argument("--llama", action="store_true", help="Test Llama-3 8B")
    parser.add_argument("--haiku", action="store_true", help="Test Claude Haiku")
    parser.add_argument("--sonnet", action="store_true", help="Test Claude Sonnet")
    parser.add_argument("--opus", action="store_true", help="Test Claude Opus")
    args = parser.parse_args()

    # Flag → model mapping, in the order models should be tested.
    flag_models = [
        (args.llama, ModelType.LLAMA_3_8B),
        (args.haiku, ModelType.CLAUDE_HAIKU),
        (args.sonnet, ModelType.CLAUDE_SONNET),
        (args.opus, ModelType.CLAUDE_SONNET_4),
    ]

    if args.all:
        models_to_test = [model for _, model in flag_models]
    else:
        models_to_test = [model for selected, model in flag_models if selected]

    run_verification(models_to_test)


if __name__ == "__main__":
    main()
|
||||
757
api/scripts/verify_e2e_ai_generation.py
Executable file
757
api/scripts/verify_e2e_ai_generation.py
Executable file
@@ -0,0 +1,757 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Task 7.12: CHECKPOINT - Verify end-to-end AI generation flow
|
||||
|
||||
This script verifies the complete AI generation pipeline:
|
||||
1. Queue a story action job via RQ
|
||||
2. Verify job processes and calls AI client
|
||||
3. Check AI response is coherent and appropriate
|
||||
4. Verify GameSession updated in Appwrite
|
||||
5. Confirm Realtime notification sent (via document update)
|
||||
6. Test job failure and retry logic
|
||||
7. Verify response stored in Redis cache
|
||||
8. Test with all 3 user tiers (Free, Premium, Elite)
|
||||
|
||||
Usage:
|
||||
# Run without real AI calls (mock mode)
|
||||
python scripts/verify_e2e_ai_generation.py
|
||||
|
||||
# Run with real AI calls (requires REPLICATE_API_TOKEN)
|
||||
python scripts/verify_e2e_ai_generation.py --real
|
||||
|
||||
# Test specific user tier
|
||||
python scripts/verify_e2e_ai_generation.py --tier free
|
||||
python scripts/verify_e2e_ai_generation.py --tier premium
|
||||
python scripts/verify_e2e_ai_generation.py --tier elite
|
||||
|
||||
# Run full integration test (requires Redis, worker, Appwrite)
|
||||
python scripts/verify_e2e_ai_generation.py --integration
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import MagicMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
# Add project root to path
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
# Load environment variables from .env file
|
||||
from dotenv import load_dotenv
|
||||
load_dotenv()
|
||||
|
||||
from app.ai.model_selector import UserTier, ContextType, ModelSelector
|
||||
from app.ai.narrative_generator import NarrativeGenerator, NarrativeResponse
|
||||
from app.ai.replicate_client import ReplicateClient, ReplicateResponse, ModelType
|
||||
from app.tasks.ai_tasks import (
|
||||
enqueue_ai_task,
|
||||
get_job_status,
|
||||
get_job_result,
|
||||
process_ai_task,
|
||||
TaskType,
|
||||
JobStatus,
|
||||
)
|
||||
from app.services.redis_service import RedisService
|
||||
|
||||
|
||||
class Colors:
    """ANSI escape sequences for coloring terminal output."""
    # Foreground colors
    GREEN = '\033[92m'
    YELLOW = '\033[93m'
    RED = '\033[91m'
    BLUE = '\033[94m'
    # Text attributes
    BOLD = '\033[1m'
    END = '\033[0m'  # Resets all styling back to the terminal default
|
||||
|
||||
|
||||
def log_pass(message: str) -> None:
    """Print *message* prefixed with a green check mark."""
    prefix = Colors.GREEN + "✓" + Colors.END
    print(f"{prefix} {message}")
|
||||
|
||||
|
||||
def log_fail(message: str) -> None:
    """Print *message* prefixed with a red cross mark."""
    prefix = Colors.RED + "✗" + Colors.END
    print(f"{prefix} {message}")
|
||||
|
||||
|
||||
def log_info(message: str) -> None:
    """Print *message* prefixed with a blue info marker."""
    prefix = Colors.BLUE + "ℹ" + Colors.END
    print(f"{prefix} {message}")
|
||||
|
||||
|
||||
def log_section(title: str) -> None:
    """Print *title* between bold yellow rule lines as a section header."""
    styled_rule = f"{Colors.BOLD}{Colors.YELLOW}{'='*60}{Colors.END}"
    print(f"\n{styled_rule}")
    print(f"{Colors.BOLD}{Colors.YELLOW}{title}{Colors.END}")
    print(f"{styled_rule}\n")
|
||||
|
||||
|
||||
# Sample test data
|
||||
SAMPLE_CHARACTER = {
|
||||
"character_id": "char_test_123",
|
||||
"name": "Aldric the Bold",
|
||||
"level": 3,
|
||||
"player_class": "Fighter",
|
||||
"race": "Human",
|
||||
"stats": {
|
||||
"strength": 16,
|
||||
"dexterity": 12,
|
||||
"constitution": 14,
|
||||
"intelligence": 10,
|
||||
"wisdom": 11,
|
||||
"charisma": 13
|
||||
},
|
||||
"current_hp": 28,
|
||||
"max_hp": 28,
|
||||
"gold": 50,
|
||||
"inventory": [
|
||||
{"name": "Longsword", "type": "weapon", "quantity": 1},
|
||||
{"name": "Shield", "type": "armor", "quantity": 1},
|
||||
{"name": "Healing Potion", "type": "consumable", "quantity": 2}
|
||||
],
|
||||
"skills": [
|
||||
{"name": "Athletics", "level": 5},
|
||||
{"name": "Intimidation", "level": 3},
|
||||
{"name": "Perception", "level": 4}
|
||||
],
|
||||
"effects": []
|
||||
}
|
||||
|
||||
SAMPLE_GAME_STATE = {
|
||||
"current_location": "The Rusty Anchor Tavern",
|
||||
"location_type": "TAVERN",
|
||||
"discovered_locations": ["Crossroads Village", "The Rusty Anchor Tavern"],
|
||||
"active_quests": [],
|
||||
"world_events": [],
|
||||
"time_of_day": "evening",
|
||||
"weather": "clear"
|
||||
}
|
||||
|
||||
SAMPLE_CONVERSATION_HISTORY = [
|
||||
{
|
||||
"turn": 1,
|
||||
"action": "I enter the tavern",
|
||||
"dm_response": "You push open the heavy wooden door and step inside. The warmth hits you immediately...",
|
||||
"timestamp": "2025-11-21T10:00:00Z"
|
||||
},
|
||||
{
|
||||
"turn": 2,
|
||||
"action": "I approach the bar",
|
||||
"dm_response": "The barkeep, a stout dwarf with a magnificent braided beard, looks up...",
|
||||
"timestamp": "2025-11-21T10:05:00Z"
|
||||
}
|
||||
]
|
||||
|
||||
SAMPLE_COMBAT_STATE = {
|
||||
"round_number": 2,
|
||||
"enemies": [
|
||||
{"name": "Goblin", "current_hp": 5, "max_hp": 7, "armor_class": 13}
|
||||
],
|
||||
"is_player_turn": True,
|
||||
"combat_log": []
|
||||
}
|
||||
|
||||
SAMPLE_NPC = {
|
||||
"name": "Old Barkeep",
|
||||
"role": "Tavern Owner",
|
||||
"personality": "Gruff but kind-hearted",
|
||||
"description": "A stout dwarf with a magnificent braided beard and keen eyes"
|
||||
}
|
||||
|
||||
SAMPLE_ELIGIBLE_QUESTS = [
|
||||
{
|
||||
"quest_id": "quest_goblin_cave",
|
||||
"name": "Clear the Goblin Cave",
|
||||
"description": "A nearby cave has been overrun by goblins raiding farms",
|
||||
"quest_giver": "Village Elder",
|
||||
"difficulty": "EASY",
|
||||
"narrative_hooks": [
|
||||
"The village elder looks worried about recent goblin attacks",
|
||||
"You hear farmers complaining about lost livestock"
|
||||
]
|
||||
},
|
||||
{
|
||||
"quest_id": "quest_lost_merchant",
|
||||
"name": "Find the Lost Merchant",
|
||||
"description": "A merchant went missing on the forest road",
|
||||
"quest_giver": "Merchant Guild",
|
||||
"difficulty": "EASY",
|
||||
"narrative_hooks": [
|
||||
"Posters about a missing merchant are everywhere",
|
||||
"The merchant guild is offering a reward"
|
||||
]
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
def verify_model_selector_routing() -> bool:
    """Verify model selector routes correctly for all tiers.

    Checks that each user tier is routed to its expected model for the
    STORY_PROGRESSION context. Returns True only if every tier matches.
    """
    log_section("1. Model Selector Routing")

    routing = ModelSelector()
    ok = True

    # (tier, expected model, human-readable label) triples to verify.
    expectations = (
        (UserTier.FREE, ModelType.LLAMA_3_8B, "Llama-3 8B"),
        (UserTier.BASIC, ModelType.CLAUDE_HAIKU, "Claude Haiku"),
        (UserTier.PREMIUM, ModelType.CLAUDE_SONNET, "Claude Sonnet"),
        (UserTier.ELITE, ModelType.CLAUDE_SONNET_4, "Claude Sonnet 4.5"),
    )

    for tier, want, label in expectations:
        chosen = routing.select_model(tier, ContextType.STORY_PROGRESSION)
        if chosen.model_type != want:
            log_fail(f"{tier.value} tier: Expected {label}, got {chosen.model_type}")
            ok = False
        else:
            log_pass(f"{tier.value} tier → {label}")

    return ok
|
||||
|
||||
|
||||
def verify_narrative_generator_mocked() -> bool:
    """Verify NarrativeGenerator works with mocked AI client.

    Exercises all four generation paths — story response, combat narration,
    NPC dialogue, and quest selection — against a MagicMock ReplicateClient,
    so no real network calls are made. Returns True only if every path
    produces non-empty output and raises no exceptions.
    """
    log_section("2. Narrative Generator (Mocked)")

    all_passed = True

    # Mock the Replicate client: every generate() call returns this canned
    # response until it is re-stubbed for the quest-selection test below.
    mock_response = ReplicateResponse(
        text="You scan the tavern carefully, your trained eyes taking in every detail...",
        tokens_used=150,
        model="meta/meta-llama-3-8b-instruct",
        generation_time=1.5
    )

    mock_client = MagicMock(spec=ReplicateClient)
    mock_client.generate.return_value = mock_response

    generator = NarrativeGenerator(replicate_client=mock_client)

    # Test story response
    try:
        response = generator.generate_story_response(
            character=SAMPLE_CHARACTER,
            action="I search the room for hidden doors",
            game_state=SAMPLE_GAME_STATE,
            user_tier=UserTier.FREE,
            conversation_history=SAMPLE_CONVERSATION_HISTORY
        )

        if response.narrative and len(response.narrative) > 0:
            log_pass(f"Story response generated ({response.tokens_used} tokens)")
        else:
            log_fail("Story response is empty")
            all_passed = False

    except Exception as e:
        log_fail(f"Story generation failed: {e}")
        all_passed = False

    # Test combat narration
    try:
        action_result = {"hit": True, "damage": 8, "effects": []}
        response = generator.generate_combat_narration(
            character=SAMPLE_CHARACTER,
            combat_state=SAMPLE_COMBAT_STATE,
            action="swings sword at goblin",
            action_result=action_result,
            user_tier=UserTier.BASIC,
            is_critical=False,
            is_finishing_blow=True
        )

        if response.narrative:
            log_pass(f"Combat narration generated ({response.tokens_used} tokens)")
        else:
            log_fail("Combat narration is empty")
            all_passed = False

    except Exception as e:
        log_fail(f"Combat narration failed: {e}")
        all_passed = False

    # Test NPC dialogue
    try:
        response = generator.generate_npc_dialogue(
            character=SAMPLE_CHARACTER,
            npc=SAMPLE_NPC,
            conversation_topic="What rumors have you heard lately?",
            game_state=SAMPLE_GAME_STATE,
            user_tier=UserTier.PREMIUM
        )

        if response.narrative:
            log_pass(f"NPC dialogue generated ({response.tokens_used} tokens)")
        else:
            log_fail("NPC dialogue is empty")
            all_passed = False

    except Exception as e:
        log_fail(f"NPC dialogue failed: {e}")
        all_passed = False

    # Test quest selection: re-stub the client so the "model" answers with a
    # bare quest_id, which is the format generate_quest_selection expects.
    mock_client.generate.return_value = ReplicateResponse(
        text="quest_goblin_cave",
        tokens_used=50,
        model="meta/meta-llama-3-8b-instruct",
        generation_time=0.5
    )

    try:
        quest_id = generator.generate_quest_selection(
            character=SAMPLE_CHARACTER,
            eligible_quests=SAMPLE_ELIGIBLE_QUESTS,
            game_context=SAMPLE_GAME_STATE,
            user_tier=UserTier.FREE
        )

        if quest_id == "quest_goblin_cave":
            log_pass(f"Quest selection returned: {quest_id}")
        else:
            log_fail(f"Unexpected quest_id: {quest_id}")
            all_passed = False

    except Exception as e:
        log_fail(f"Quest selection failed: {e}")
        all_passed = False

    return all_passed
|
||||
|
||||
|
||||
def verify_ai_task_processing_mocked() -> bool:
    """Verify AI task processing with mocked components.

    Runs process_ai_task() synchronously with the generator, tier lookup,
    and session-update helper all patched out, then checks that a narrative
    came back and that the GameSession update hook was invoked.
    """
    log_section("3. AI Task Processing (Mocked)")

    all_passed = True

    # Mock dependencies
    mock_response = ReplicateResponse(
        text="The tavern grows quiet as you make your proclamation...",
        tokens_used=200,
        model="meta/meta-llama-3-8b-instruct",
        generation_time=2.0
    )

    # Patch at the ai_tasks module level so process_ai_task never touches
    # the real generator, the user-tier lookup, or the database.
    with patch('app.tasks.ai_tasks.NarrativeGenerator') as MockGenerator, \
         patch('app.tasks.ai_tasks._get_user_tier') as mock_get_tier, \
         patch('app.tasks.ai_tasks._update_game_session') as mock_update_session:

        # Setup mocks
        mock_get_tier.return_value = UserTier.FREE

        mock_gen_instance = MagicMock()
        mock_gen_instance.generate_story_response.return_value = NarrativeResponse(
            narrative=mock_response.text,
            tokens_used=mock_response.tokens_used,
            model=mock_response.model,
            context_type="story_progression",
            generation_time=mock_response.generation_time
        )
        MockGenerator.return_value = mock_gen_instance

        # Test narrative task processing
        context = {
            "action": "I stand on a table and announce myself",
            "character": SAMPLE_CHARACTER,
            "game_state": SAMPLE_GAME_STATE,
            "conversation_history": SAMPLE_CONVERSATION_HISTORY
        }

        job_id = f"test_{uuid4().hex[:8]}"

        try:
            result = process_ai_task(
                task_type="narrative",
                user_id="test_user_123",
                context=context,
                job_id=job_id,
                session_id="sess_test_123",
                character_id="char_test_123"
            )

            if result.get("narrative"):
                log_pass(f"Narrative task processed successfully")
                log_info(f" Tokens: {result.get('tokens_used')}, Model: {result.get('model')}")
            else:
                log_fail("Narrative task returned no narrative")
                all_passed = False

            # Verify session update was called (session_id was supplied, so
            # the task should persist the turn to the GameSession).
            if mock_update_session.called:
                log_pass("GameSession update called")
            else:
                log_fail("GameSession update NOT called")
                all_passed = False

        except Exception as e:
            log_fail(f"Narrative task processing failed: {e}")
            all_passed = False

    return all_passed
|
||||
|
||||
|
||||
def verify_job_lifecycle_mocked() -> bool:
    """Verify job queueing, status tracking, and result storage (mocked).

    Patches the RQ queue and the Redis status writer, then checks that
    enqueue_ai_task() returns a queued job, honors high priority via
    at_front=True, and records the initial job status.
    """
    log_section("4. Job Lifecycle (Mocked)")

    all_passed = True

    # Test with mocked Redis and queue
    with patch('app.tasks.ai_tasks.get_queue') as mock_get_queue, \
         patch('app.tasks.ai_tasks._store_job_status') as mock_store_status:

        mock_queue = MagicMock()
        mock_job = MagicMock()
        mock_job.id = "test_job_123"
        mock_queue.enqueue.return_value = mock_job
        mock_get_queue.return_value = mock_queue

        # Test job enqueueing
        try:
            result = enqueue_ai_task(
                task_type="narrative",
                user_id="test_user",
                context={"action": "test", "character": {}, "game_state": {}},
                priority="high"
            )

            if result.get("job_id") and result.get("status") == "queued":
                log_pass(f"Job enqueued: {result.get('job_id')}")
            else:
                log_fail(f"Unexpected enqueue result: {result}")
                all_passed = False

            # Verify queue was called with correct priority: high priority
            # should map to RQ's at_front=True keyword.
            if mock_queue.enqueue.called:
                call_kwargs = mock_queue.enqueue.call_args
                if call_kwargs.kwargs.get('at_front') == True:
                    log_pass("High priority job placed at front of queue")
                else:
                    log_fail("High priority not placed at front")
                    all_passed = False

            # Verify status was stored
            if mock_store_status.called:
                log_pass("Job status stored in Redis")
            else:
                log_fail("Job status NOT stored")
                all_passed = False

        except Exception as e:
            log_fail(f"Job enqueueing failed: {e}")
            all_passed = False

    return all_passed
|
||||
|
||||
|
||||
def verify_error_handling() -> bool:
    """Verify error handling and validation.

    Confirms that enqueue_ai_task rejects unknown task types and priorities
    with ValueError, and that process_ai_task rejects a context missing
    required fields.
    """
    log_section("5. Error Handling")

    all_passed = True

    # Test invalid task type
    try:
        enqueue_ai_task(
            task_type="invalid_type",
            user_id="test",
            context={}
        )
        log_fail("Should have raised ValueError for invalid task_type")
        all_passed = False
    except ValueError as e:
        if "Invalid task_type" in str(e):
            log_pass("Invalid task_type raises ValueError")
        else:
            log_fail(f"Unexpected error: {e}")
            all_passed = False

    # Test invalid priority
    try:
        enqueue_ai_task(
            task_type="narrative",
            user_id="test",
            context={},
            priority="super_urgent"
        )
        log_fail("Should have raised ValueError for invalid priority")
        all_passed = False
    except ValueError as e:
        if "Invalid priority" in str(e):
            log_pass("Invalid priority raises ValueError")
        else:
            log_fail(f"Unexpected error: {e}")
            all_passed = False

    # Test missing context fields
    # NOTE(review): this patches _update_job_status while the lifecycle test
    # above patches _store_job_status — confirm both helpers exist in
    # app.tasks.ai_tasks and the right one is stubbed here.
    with patch('app.tasks.ai_tasks._get_user_tier') as mock_tier, \
         patch('app.tasks.ai_tasks._update_job_status'):
        mock_tier.return_value = UserTier.FREE

        try:
            process_ai_task(
                task_type="narrative",
                user_id="test",
                context={"action": "test"},  # Missing character and game_state
                job_id="test_job"
            )
            log_fail("Should have raised error for missing context fields")
            all_passed = False
        except ValueError as e:
            if "Missing required context field" in str(e):
                log_pass("Missing context fields raises ValueError")
            else:
                log_fail(f"Unexpected error: {e}")
                all_passed = False

    return all_passed
|
||||
|
||||
|
||||
def verify_real_ai_generation(tier: str = "free") -> bool:
    """Test with real AI calls (requires REPLICATE_API_TOKEN).

    Args:
        tier: Tier name ("free"/"basic"/"premium"/"elite"); unknown values
            fall back to FREE.

    Returns:
        True on success — including when the test is skipped because no
        REPLICATE_API_TOKEN is configured.
    """
    log_section(f"6. Real AI Generation ({tier.upper()} tier)")

    # Check for API token; skipping is treated as a pass so CI without
    # credentials does not fail.
    if not os.environ.get("REPLICATE_API_TOKEN"):
        log_info("REPLICATE_API_TOKEN not set - skipping real AI test")
        return True

    tier_map = {
        "free": UserTier.FREE,
        "basic": UserTier.BASIC,
        "premium": UserTier.PREMIUM,
        "elite": UserTier.ELITE
    }

    user_tier = tier_map.get(tier.lower(), UserTier.FREE)

    generator = NarrativeGenerator()

    try:
        log_info("Calling Replicate API...")
        response = generator.generate_story_response(
            character=SAMPLE_CHARACTER,
            action="I look around the tavern and ask the barkeep about any interesting rumors",
            game_state=SAMPLE_GAME_STATE,
            user_tier=user_tier,
            conversation_history=SAMPLE_CONVERSATION_HISTORY
        )

        log_pass(f"AI response generated successfully")
        log_info(f" Model: {response.model}")
        log_info(f" Tokens: {response.tokens_used}")
        log_info(f" Time: {response.generation_time:.2f}s")
        log_info(f" Response preview: {response.narrative[:200]}...")

        # Check response quality: a trivially short reply indicates a
        # truncated or failed generation.
        if len(response.narrative) > 50:
            log_pass("Response has substantial content")
        else:
            log_fail("Response seems too short")
            return False

        # Heuristic relevance check — keyword match only, so a miss is
        # logged for manual review rather than failing the run.
        if any(word in response.narrative.lower() for word in ["tavern", "barkeep", "rumor", "hear"]):
            log_pass("Response is contextually relevant")
        else:
            log_info("Response may not be fully contextual (check manually)")

        return True

    except Exception as e:
        log_fail(f"Real AI generation failed: {e}")
        return False
|
||||
|
||||
|
||||
def verify_integration(tier: str = "free") -> bool:
    """Full integration test with Redis, RQ, and real job processing.

    Requires a reachable Redis. With REPLICATE_API_TOKEN set, a real job is
    enqueued and polled for up to 60s (an external RQ worker must be
    running); without it, process_ai_task is invoked directly with the
    generator mocked out.
    """
    log_section("7. Full Integration Test")

    # Check Redis connection with a short-TTL round-trip.
    try:
        redis = RedisService()
        redis.set("integration_test", "ok", ttl=60)
        if redis.get("integration_test") == "ok":
            log_pass("Redis connection working")
        else:
            log_fail("Redis read/write failed")
            return False
    except Exception as e:
        log_fail(f"Redis connection failed: {e}")
        log_info("Make sure Redis is running: docker-compose up -d redis")
        return False

    # Check if we have Replicate token
    has_api_token = bool(os.environ.get("REPLICATE_API_TOKEN"))
    if not has_api_token:
        log_info("REPLICATE_API_TOKEN not set - will test with mocked AI")

    tier_map = {
        "free": UserTier.FREE,
        "basic": UserTier.BASIC,
        "premium": UserTier.PREMIUM,
        "elite": UserTier.ELITE
    }
    user_tier = tier_map.get(tier.lower(), UserTier.FREE)

    # Create context for test
    context = {
        "action": "I search the tavern for any suspicious characters",
        "character": SAMPLE_CHARACTER,
        "game_state": SAMPLE_GAME_STATE,
        "conversation_history": SAMPLE_CONVERSATION_HISTORY
    }

    if has_api_token:
        # Real integration test - queue job and let worker process it
        log_info("To run full integration, start a worker in another terminal:")
        log_info(" cd api && source venv/bin/activate")
        log_info(" rq worker ai_tasks --url redis://localhost:6379")

        try:
            result = enqueue_ai_task(
                task_type="narrative",
                user_id="integration_test_user",
                context=context,
                priority="high"
            )

            job_id = result.get("job_id")
            log_pass(f"Job enqueued: {job_id}")

            # Poll for completion every 2s up to max_wait seconds.
            log_info("Waiting for worker to process job...")
            max_wait = 60  # seconds
            waited = 0

            while waited < max_wait:
                status = get_job_status(job_id)
                current_status = status.get("status", "unknown")

                if current_status == "completed":
                    log_pass(f"Job completed after {waited}s")

                    # Get result
                    job_result = get_job_result(job_id)
                    if job_result:
                        log_pass("Job result retrieved from Redis")
                        log_info(f" Tokens: {job_result.get('tokens_used')}")
                        log_info(f" Model: {job_result.get('model')}")
                    else:
                        log_fail("Could not retrieve job result")
                        return False

                    return True

                elif current_status == "failed":
                    log_fail(f"Job failed: {status.get('error')}")
                    return False

                time.sleep(2)
                waited += 2

            log_fail(f"Job did not complete within {max_wait}s")
            log_info("Make sure RQ worker is running")
            return False

        except Exception as e:
            log_fail(f"Integration test failed: {e}")
            return False
    else:
        # Mocked integration test - process directly (no worker needed).
        log_info("Running mocked integration (no worker needed)")

        with patch('app.tasks.ai_tasks.NarrativeGenerator') as MockGenerator, \
             patch('app.tasks.ai_tasks._get_user_tier') as mock_get_tier, \
             patch('app.tasks.ai_tasks._update_game_session') as mock_update:

            mock_get_tier.return_value = user_tier

            mock_gen = MagicMock()
            mock_gen.generate_story_response.return_value = NarrativeResponse(
                narrative="The tavern is filled with a motley crew of adventurers...",
                tokens_used=180,
                model="meta/meta-llama-3-8b-instruct",
                context_type="story_progression",
                generation_time=1.8
            )
            MockGenerator.return_value = mock_gen

            job_id = f"integration_test_{uuid4().hex[:8]}"

            try:
                result = process_ai_task(
                    task_type="narrative",
                    user_id="integration_test_user",
                    context=context,
                    job_id=job_id,
                    session_id="sess_integration_test"
                )

                log_pass("Mocked job processed successfully")
                log_info(f" Result: {result.get('narrative', '')[:100]}...")

                return True

            except Exception as e:
                log_fail(f"Mocked integration failed: {e}")
                return False
|
||||
|
||||
|
||||
def main() -> int:
    """Run all verification tests.

    Parses --real / --tier / --integration flags, runs the five core mocked
    checks plus any optional ones, prints a summary, and returns a process
    exit code (0 = all passed, 1 = at least one failure).
    """
    parser = argparse.ArgumentParser(description="Verify end-to-end AI generation flow")
    parser.add_argument("--real", action="store_true", help="Run with real AI API calls")
    parser.add_argument("--tier", type=str, default="free",
                        choices=["free", "basic", "premium", "elite"],
                        help="User tier to test")
    parser.add_argument("--integration", action="store_true",
                        help="Run full integration test with Redis/RQ")
    args = parser.parse_args()

    print(f"\n{Colors.BOLD}Task 7.12: End-to-End AI Generation Verification{Colors.END}")
    print(f"Started at: {datetime.now(timezone.utc).isoformat()}\n")

    # (name, passed) pairs accumulated for the summary table.
    results = []

    # Core tests (always run)
    results.append(("Model Selector Routing", verify_model_selector_routing()))
    results.append(("Narrative Generator (Mocked)", verify_narrative_generator_mocked()))
    results.append(("AI Task Processing (Mocked)", verify_ai_task_processing_mocked()))
    results.append(("Job Lifecycle (Mocked)", verify_job_lifecycle_mocked()))
    results.append(("Error Handling", verify_error_handling()))

    # Optional tests
    if args.real:
        results.append(("Real AI Generation", verify_real_ai_generation(args.tier)))

    if args.integration:
        results.append(("Full Integration", verify_integration(args.tier)))

    # Summary
    log_section("VERIFICATION SUMMARY")

    passed = sum(1 for _, result in results if result)
    total = len(results)

    for name, result in results:
        status = f"{Colors.GREEN}PASS{Colors.END}" if result else f"{Colors.RED}FAIL{Colors.END}"
        print(f" {name}: {status}")

    print(f"\n{Colors.BOLD}Total: {passed}/{total} tests passed{Colors.END}")

    if passed == total:
        print(f"\n{Colors.GREEN}✓ Task 7.12 CHECKPOINT VERIFIED{Colors.END}")
        return 0
    else:
        print(f"\n{Colors.RED}✗ Some tests failed - review issues above{Colors.END}")
        return 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Exit code doubles as the CI signal: 0 = all checks passed, 1 = failure.
    sys.exit(main())
|
||||
428
api/scripts/verify_session_persistence.py
Executable file
428
api/scripts/verify_session_persistence.py
Executable file
@@ -0,0 +1,428 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Verification script for Task 8.24: Session Persistence Checkpoint.
|
||||
|
||||
Tests the SessionService against real Appwrite database to verify:
|
||||
1. Solo session creation and storage
|
||||
2. Session retrieval and ownership validation
|
||||
3. Conversation history persistence
|
||||
4. Game state tracking (location, quests, events)
|
||||
5. Session lifecycle (create, update, end)
|
||||
|
||||
Usage:
|
||||
python scripts/verify_session_persistence.py
|
||||
|
||||
Requirements:
|
||||
- Appwrite configured in .env
|
||||
- game_sessions collection created in Appwrite
|
||||
- At least one character created for testing
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
|
||||
# Add project root to path
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from dotenv import load_dotenv
|
||||
load_dotenv()
|
||||
|
||||
from app.services.session_service import (
|
||||
SessionService,
|
||||
SessionNotFound,
|
||||
SessionLimitExceeded,
|
||||
SessionValidationError,
|
||||
get_session_service,
|
||||
)
|
||||
from app.services.character_service import get_character_service, CharacterNotFound
|
||||
from app.models.enums import SessionStatus, SessionType, LocationType
|
||||
from app.utils.logging import get_logger
|
||||
|
||||
logger = get_logger(__file__)
|
||||
|
||||
|
||||
def print_header(text: str):
    """Print a section header framed by 60-character rules."""
    bar = "=" * 60
    print(f"\n{bar}")
    print(f" {text}")
    print(f"{bar}\n")
|
||||
|
||||
|
||||
def print_result(test_name: str, passed: bool, details: str = ""):
    """Print one PASS/FAIL result line, with an optional detail line."""
    if passed:
        print(f"✅ PASS: {test_name}")
    else:
        print(f"❌ FAIL: {test_name}")
    if details:
        print(f" {details}")
|
||||
|
||||
|
||||
def verify_session_creation(service: SessionService, user_id: str, character_id: str) -> "str | None":
    """
    Test 1: Create a solo session and verify it's stored correctly.

    Creates a solo session in "Test Town" and asserts the initial field
    values (type, ownership, turn counter, status, starting game state).

    Returns session_id if successful, None if any check or the creation
    itself fails.
    """
    print_header("Test 1: Solo Session Creation")

    try:
        # Create session
        session = service.create_solo_session(
            user_id=user_id,
            character_id=character_id,
            starting_location="Test Town",
            starting_location_type=LocationType.TOWN
        )

        # Verify fields: (condition, description) pairs evaluated below.
        checks = [
            (session.session_type == SessionType.SOLO, "session_type is SOLO"),
            (session.solo_character_id == character_id, "solo_character_id matches"),
            (session.user_id == user_id, "user_id matches"),
            (session.turn_number == 0, "turn_number is 0"),
            (session.status == SessionStatus.ACTIVE, "status is ACTIVE"),
            (session.game_state.current_location == "Test Town", "current_location set"),
            (session.game_state.location_type == LocationType.TOWN, "location_type set"),
            ("Test Town" in session.game_state.discovered_locations, "location in discovered"),
        ]

        all_passed = True
        for passed, desc in checks:
            print_result(desc, passed)
            if not passed:
                all_passed = False

        if all_passed:
            print(f"\n Session ID: {session.session_id}")
            return session.session_id
        else:
            return None

    except Exception as e:
        print_result("Session creation", False, str(e))
        return None
|
||||
|
||||
|
||||
def verify_session_retrieval(service: SessionService, session_id: str, user_id: str) -> bool:
    """
    Test 2: Load session from database and verify data integrity.

    Also verifies ownership enforcement: fetching with a wrong user_id must
    raise SessionNotFound rather than leak another user's session.
    """
    print_header("Test 2: Session Retrieval")

    try:
        # Load session
        session = service.get_session(session_id, user_id)

        checks = [
            (session.session_id == session_id, "session_id matches"),
            (session.user_id == user_id, "user_id matches"),
            (session.session_type == SessionType.SOLO, "session_type preserved"),
            (session.status == SessionStatus.ACTIVE, "status preserved"),
        ]

        all_passed = True
        for passed, desc in checks:
            print_result(desc, passed)
            if not passed:
                all_passed = False

        # Test ownership validation
        try:
            service.get_session(session_id, "wrong_user_id")
            print_result("Ownership validation", False, "Should have raised SessionNotFound")
            all_passed = False
        except SessionNotFound:
            print_result("Ownership validation (wrong user rejected)", True)

        return all_passed

    except Exception as e:
        print_result("Session retrieval", False, str(e))
        return False
|
||||
|
||||
|
||||
def verify_conversation_history(service: SessionService, session_id: str) -> bool:
    """
    Test 3: Add conversation entries and verify persistence.

    Writes two turns (the second carrying a quest_offered payload), reloads
    the session, and checks turn counting, entry content, timestamps, and
    get_recent_history().
    """
    print_header("Test 3: Conversation History")

    try:
        # Add first entry
        service.add_conversation_entry(
            session_id=session_id,
            character_id="char_test",
            character_name="Test Hero",
            action="I explore the town",
            dm_response="You find a bustling marketplace..."
        )

        # Add second entry with quest
        service.add_conversation_entry(
            session_id=session_id,
            character_id="char_test",
            character_name="Test Hero",
            action="Talk to the merchant",
            dm_response="The merchant offers you a quest...",
            quest_offered={"quest_id": "test_quest", "name": "Test Quest"}
        )

        # Retrieve and verify
        session = service.get_session(session_id)

        checks = [
            (session.turn_number == 2, f"turn_number is 2 (got {session.turn_number})"),
            (len(session.conversation_history) == 2, f"2 entries in history (got {len(session.conversation_history)})"),
            (session.conversation_history[0].action == "I explore the town", "first action preserved"),
            (session.conversation_history[1].quest_offered is not None, "quest_offered preserved"),
            (session.conversation_history[0].timestamp != "", "timestamp auto-generated"),
        ]

        all_passed = True
        for passed, desc in checks:
            print_result(desc, passed)
            if not passed:
                all_passed = False

        # Test get_recent_history: asking for 1 turn must return only the
        # latest entry (turn 2).
        recent = service.get_recent_history(session_id, num_turns=1)
        check = len(recent) == 1 and recent[0].turn == 2
        print_result("get_recent_history returns last entry", check)
        if not check:
            all_passed = False

        return all_passed

    except Exception as e:
        print_result("Conversation history", False, str(e))
        return False
|
||||
|
||||
|
||||
def verify_game_state_tracking(service: SessionService, session_id: str) -> bool:
    """
    Test 4: Test location, quest, and event tracking.

    Exercises update_location, add/remove_active_quest (including the
    two-active-quest cap enforced via SessionValidationError), and
    add_world_event with its auto-added timestamp.
    """
    print_header("Test 4: Game State Tracking")

    try:
        all_passed = True

        # Update location
        service.update_location(
            session_id=session_id,
            new_location="Dark Forest",
            location_type=LocationType.WILDERNESS
        )

        session = service.get_session(session_id)
        check = session.game_state.current_location == "Dark Forest"
        print_result("Location updated", check, f"Got: {session.game_state.current_location}")
        if not check:
            all_passed = False

        check = session.game_state.location_type == LocationType.WILDERNESS
        print_result("Location type updated", check)
        if not check:
            all_passed = False

        check = "Dark Forest" in session.game_state.discovered_locations
        print_result("New location added to discovered", check)
        if not check:
            all_passed = False

        # Add quest
        service.add_active_quest(session_id, "quest_1")
        session = service.get_session(session_id)
        check = "quest_1" in session.game_state.active_quests
        print_result("Quest added to active_quests", check)
        if not check:
            all_passed = False

        # Add second quest
        service.add_active_quest(session_id, "quest_2")
        session = service.get_session(session_id)
        check = len(session.game_state.active_quests) == 2
        print_result("Second quest added", check)
        if not check:
            all_passed = False

        # Try to add third quest (should fail — the service caps active
        # quests at two per session).
        try:
            service.add_active_quest(session_id, "quest_3")
            print_result("Max quest limit enforced", False, "Should have raised error")
            all_passed = False
        except SessionValidationError:
            print_result("Max quest limit enforced (2/2)", True)

        # Remove quest
        service.remove_active_quest(session_id, "quest_1")
        session = service.get_session(session_id)
        check = "quest_1" not in session.game_state.active_quests
        print_result("Quest removed", check)
        if not check:
            all_passed = False

        # Add world event
        service.add_world_event(session_id, {"type": "storm", "description": "A storm approaches"})
        session = service.get_session(session_id)
        check = len(session.game_state.world_events) == 1
        print_result("World event added", check)
        if not check:
            all_passed = False

        check = "timestamp" in session.game_state.world_events[0]
        print_result("Event timestamp auto-added", check)
        if not check:
            all_passed = False

        return all_passed

    except Exception as e:
        print_result("Game state tracking", False, str(e))
        return False
|
||||
|
||||
|
||||
def verify_session_lifecycle(service: SessionService, session_id: str, user_id: str) -> bool:
    """
    Test 5: Test session ending and status changes.

    Ends the session, confirms the COMPLETED status both in the returned
    object and after a fresh reload, then smoke-tests count_user_sessions.
    """
    print_header("Test 5: Session Lifecycle")

    try:
        all_passed = True

        # End session
        session = service.end_session(session_id, user_id)
        check = session.status == SessionStatus.COMPLETED
        print_result("Session status set to COMPLETED", check, f"Got: {session.status}")
        if not check:
            all_passed = False

        # Verify persisted
        session = service.get_session(session_id)
        check = session.status == SessionStatus.COMPLETED
        print_result("Completed status persisted", check)
        if not check:
            all_passed = False

        # Verify it's not counted as active
        count = service.count_user_sessions(user_id, active_only=True)
        # Note: This might include other sessions, so just check it works
        # rather than asserting an exact count.
        print_result(f"count_user_sessions works (active: {count})", True)

        return all_passed

    except Exception as e:
        print_result("Session lifecycle", False, str(e))
        return False
|
||||
|
||||
|
||||
def verify_error_handling(service: SessionService) -> bool:
    """
    Test 6: Test error handling for invalid operations.

    Looking up a nonexistent session id must raise SessionNotFound; any
    other outcome (no exception, or a different one) is a failure.
    """
    print_header("Test 6: Error Handling")

    ok = True

    # Invalid session ID
    try:
        service.get_session("invalid_session_id_12345")
    except SessionNotFound:
        print_result("Invalid session ID raises SessionNotFound", True)
    except Exception as e:
        print_result("Invalid session ID raises error", False, str(e))
        ok = False
    else:
        print_result("Invalid session ID raises error", False)
        ok = False

    return ok
|
||||
|
||||
|
||||
def main() -> bool:
    """Run all verification tests.

    Requires Appwrite configuration in .env. With TEST_USER_ID and
    TEST_CHARACTER_ID set, runs the full integration suite against the real
    database; otherwise falls back to running the pytest unit tests via
    subprocess. Returns True when everything passed.
    """
    print("\n" + "="*60)
    print(" Task 8.24: Session Persistence Verification")
    print("="*60)

    # Check environment
    if not os.getenv('APPWRITE_ENDPOINT'):
        print("\n❌ ERROR: Appwrite not configured. Set APPWRITE_* in .env")
        return False

    # Get test user and character
    print("\nSetup: Finding test character...")

    char_service = get_character_service()

    # Try to find an existing character for testing
    # You may need to adjust this based on your test data
    test_user_id = os.getenv('TEST_USER_ID', '')
    test_character_id = os.getenv('TEST_CHARACTER_ID', '')

    if not test_user_id or not test_character_id:
        print("\n⚠️ No TEST_USER_ID or TEST_CHARACTER_ID in .env")
        print(" Will attempt to use mock IDs for basic testing")
        print(" For full integration test, set these environment variables")

        # Use mock approach - only tests that don't need real DB
        print("\n" + "="*60)
        print(" Running Unit Test Verification Only")
        print("="*60)

        # Run the pytest tests instead; the suite's exit code is the result.
        import subprocess
        result = subprocess.run(
            ["python", "-m", "pytest", "tests/test_session_service.py", "-v", "--tb=short"],
            cwd=os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        )
        return result.returncode == 0

    # Initialize service
    service = get_session_service()

    # Run tests; results is a list of (name, passed) pairs.
    results = []

    # Test 1: Create session
    session_id = verify_session_creation(service, test_user_id, test_character_id)
    results.append(("Session Creation", session_id is not None))

    if session_id:
        # Test 2: Retrieve session
        results.append(("Session Retrieval", verify_session_retrieval(service, session_id, test_user_id)))

        # Test 3: Conversation history
        results.append(("Conversation History", verify_conversation_history(service, session_id)))

        # Test 4: Game state tracking
        results.append(("Game State Tracking", verify_game_state_tracking(service, session_id)))

        # Test 5: Session lifecycle
        results.append(("Session Lifecycle", verify_session_lifecycle(service, session_id, test_user_id)))

    # Test 6: Error handling
    results.append(("Error Handling", verify_error_handling(service)))

    # Summary
    print_header("Verification Summary")

    passed = sum(1 for _, p in results if p)
    total = len(results)

    for name, result in results:
        status = "✅" if result else "❌"
        print(f" {status} {name}")

    print(f"\n Total: {passed}/{total} tests passed")

    if passed == total:
        print("\n✅ All session persistence tests PASSED!")
        return True
    else:
        print("\n❌ Some tests FAILED. Check output above for details.")
        return False
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Translate the boolean result into a conventional process exit code.
    success = main()
    sys.exit(0 if success else 1)
|
||||
Reference in New Issue
Block a user