239 lines
6.9 KiB
Python
239 lines
6.9 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
Verification script for Task 7.8: Verify all AI models respond correctly.
|
|
|
|
This script tests:
|
|
1. Replicate client with Llama-3 8B
|
|
2. Replicate client with Claude Haiku
|
|
3. Replicate client with Claude Sonnet
|
|
4. Model selector tier routing
|
|
5. Token counting accuracy
|
|
6. Response quality comparison
|
|
|
|
Usage:
|
|
python scripts/verify_ai_models.py [--all] [--llama] [--haiku] [--sonnet] [--opus]
|
|
"""
|
|
|
|
import argparse
|
|
import sys
|
|
import time
|
|
from pathlib import Path
|
|
|
|
# Add parent directory to path for imports
|
|
sys.path.insert(0, str(Path(__file__).parent.parent))
|
|
|
|
# Load .env before importing app modules
|
|
from dotenv import load_dotenv
|
|
load_dotenv()
|
|
|
|
import structlog
|
|
from app.ai import (
|
|
ReplicateClient,
|
|
ModelType,
|
|
ModelSelector,
|
|
UserTier,
|
|
ContextType,
|
|
ReplicateClientError,
|
|
)
|
|
|
|
logger = structlog.get_logger(__name__)
|
|
|
|
# Test prompt for narrative generation
|
|
TEST_PROMPT = """You are a dungeon master. The player enters a dimly lit tavern.
|
|
Describe the scene in 2-3 sentences. Include at least one interesting NPC."""
|
|
|
|
TEST_SYSTEM_PROMPT = "You are a creative fantasy storyteller. Keep responses concise but vivid."
|
|
|
|
|
|
def test_model(model_type: ModelType, client: ReplicateClient | None = None) -> dict:
|
|
"""
|
|
Test a specific model and return results.
|
|
|
|
Args:
|
|
model_type: The model to test.
|
|
client: Optional existing client, otherwise creates new one.
|
|
|
|
Returns:
|
|
Dictionary with test results.
|
|
"""
|
|
model_name = model_type.name
|
|
print(f"\n{'='*60}")
|
|
print(f"Testing: {model_name}")
|
|
print(f"Model ID: {model_type.value}")
|
|
print(f"{'='*60}")
|
|
|
|
try:
|
|
if client is None:
|
|
client = ReplicateClient(model=model_type)
|
|
|
|
start_time = time.time()
|
|
response = client.generate(
|
|
prompt=TEST_PROMPT,
|
|
system_prompt=TEST_SYSTEM_PROMPT,
|
|
model=model_type
|
|
)
|
|
elapsed = time.time() - start_time
|
|
|
|
print(f"\n✅ SUCCESS")
|
|
print(f"Response time: {elapsed:.2f}s")
|
|
print(f"Tokens used: {response.tokens_used}")
|
|
print(f"Response length: {len(response.text)} chars")
|
|
print(f"\nGenerated text:")
|
|
print("-" * 40)
|
|
print(response.text[:500] + ("..." if len(response.text) > 500 else ""))
|
|
print("-" * 40)
|
|
|
|
return {
|
|
"model": model_name,
|
|
"success": True,
|
|
"response_time": elapsed,
|
|
"tokens_used": response.tokens_used,
|
|
"text_length": len(response.text),
|
|
"text_preview": response.text[:200]
|
|
}
|
|
|
|
except ReplicateClientError as e:
|
|
print(f"\n❌ FAILED: {e}")
|
|
return {
|
|
"model": model_name,
|
|
"success": False,
|
|
"error": str(e)
|
|
}
|
|
except Exception as e:
|
|
print(f"\n❌ UNEXPECTED ERROR: {e}")
|
|
return {
|
|
"model": model_name,
|
|
"success": False,
|
|
"error": str(e)
|
|
}
|
|
|
|
|
|
def test_model_selector():
|
|
"""Test the model selector tier routing."""
|
|
print(f"\n{'='*60}")
|
|
print("Testing Model Selector")
|
|
print(f"{'='*60}")
|
|
|
|
selector = ModelSelector()
|
|
|
|
test_cases = [
|
|
(UserTier.FREE, ContextType.STORY_PROGRESSION),
|
|
(UserTier.BASIC, ContextType.STORY_PROGRESSION),
|
|
(UserTier.PREMIUM, ContextType.STORY_PROGRESSION),
|
|
(UserTier.ELITE, ContextType.STORY_PROGRESSION),
|
|
(UserTier.PREMIUM, ContextType.QUEST_SELECTION),
|
|
(UserTier.PREMIUM, ContextType.COMBAT_NARRATION),
|
|
]
|
|
|
|
print("\nTier → Model Routing:")
|
|
print("-" * 40)
|
|
|
|
for tier, context in test_cases:
|
|
config = selector.select_model(tier, context)
|
|
info = selector.get_tier_info(tier)
|
|
cost = selector.estimate_cost_per_request(tier)
|
|
|
|
print(f"{tier.value:10} + {context.value:20} → {config.model_type.name:15} "
|
|
f"(tokens={config.max_tokens}, temp={config.temperature}, cost=${cost:.4f})")
|
|
|
|
print("\n✅ Model selector routing verified")
|
|
|
|
|
|
def run_verification(models_to_test: list[ModelType]):
|
|
"""
|
|
Run full verification suite.
|
|
|
|
Args:
|
|
models_to_test: List of models to test with real API calls.
|
|
"""
|
|
print("\n" + "=" * 60)
|
|
print("Phase 4 Task 7.8: AI Model Verification")
|
|
print("=" * 60)
|
|
|
|
results = []
|
|
|
|
# Test model selector first (no API calls)
|
|
test_model_selector()
|
|
|
|
if not models_to_test:
|
|
print("\nNo models selected for API testing.")
|
|
print("Use --llama, --haiku, --sonnet, --opus, or --all")
|
|
return
|
|
|
|
# Create a single client for efficiency
|
|
try:
|
|
client = ReplicateClient()
|
|
except ReplicateClientError as e:
|
|
print(f"\n❌ Failed to initialize Replicate client: {e}")
|
|
print("Check REPLICATE_API_TOKEN in .env")
|
|
return
|
|
|
|
# Test each selected model
|
|
for model_type in models_to_test:
|
|
result = test_model(model_type, client)
|
|
results.append(result)
|
|
|
|
# Summary
|
|
print("\n" + "=" * 60)
|
|
print("VERIFICATION SUMMARY")
|
|
print("=" * 60)
|
|
|
|
passed = sum(1 for r in results if r.get("success"))
|
|
failed = len(results) - passed
|
|
|
|
for result in results:
|
|
status = "✅" if result.get("success") else "❌"
|
|
model = result.get("model")
|
|
if result.get("success"):
|
|
time_s = result.get("response_time", 0)
|
|
tokens = result.get("tokens_used", 0)
|
|
print(f"{status} {model}: {time_s:.2f}s, {tokens} tokens")
|
|
else:
|
|
error = result.get("error", "Unknown error")
|
|
print(f"{status} {model}: {error[:50]}")
|
|
|
|
print(f"\nTotal: {passed} passed, {failed} failed")
|
|
|
|
if failed == 0:
|
|
print("\n✅ All verification checks passed!")
|
|
else:
|
|
print(f"\n⚠️ {failed} model(s) failed verification")
|
|
|
|
|
|
def main():
|
|
parser = argparse.ArgumentParser(
|
|
description="Verify AI models respond correctly through Replicate"
|
|
)
|
|
parser.add_argument("--all", action="store_true", help="Test all models")
|
|
parser.add_argument("--llama", action="store_true", help="Test Llama-3 8B")
|
|
parser.add_argument("--haiku", action="store_true", help="Test Claude Haiku")
|
|
parser.add_argument("--sonnet", action="store_true", help="Test Claude Sonnet")
|
|
parser.add_argument("--opus", action="store_true", help="Test Claude Opus")
|
|
|
|
args = parser.parse_args()
|
|
|
|
models_to_test = []
|
|
|
|
if args.all:
|
|
models_to_test = [
|
|
ModelType.LLAMA_3_8B,
|
|
ModelType.CLAUDE_HAIKU,
|
|
ModelType.CLAUDE_SONNET,
|
|
ModelType.CLAUDE_SONNET_4,
|
|
]
|
|
else:
|
|
if args.llama:
|
|
models_to_test.append(ModelType.LLAMA_3_8B)
|
|
if args.haiku:
|
|
models_to_test.append(ModelType.CLAUDE_HAIKU)
|
|
if args.sonnet:
|
|
models_to_test.append(ModelType.CLAUDE_SONNET)
|
|
if args.opus:
|
|
models_to_test.append(ModelType.CLAUDE_SONNET_4)
|
|
|
|
run_verification(models_to_test)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
main()
|