first commit
api/scripts/verify_ai_models.py | 238 (new file)
@@ -0,0 +1,238 @@
#!/usr/bin/env python3
"""
Verification script for Task 7.8: Verify all AI models respond correctly.

This script tests:
1. Replicate client with Llama-3 8B
2. Replicate client with Claude Haiku
3. Replicate client with Claude Sonnet
4. Model selector tier routing
5. Token counting accuracy
6. Response quality comparison

Usage:
    python scripts/verify_ai_models.py [--all] [--llama] [--haiku] [--sonnet] [--opus]
"""

import argparse
import sys
import time
from pathlib import Path

# Add parent directory to path for imports
sys.path.insert(0, str(Path(__file__).parent.parent))

# Load .env before importing app modules
from dotenv import load_dotenv
load_dotenv()

import structlog
from app.ai import (
    ReplicateClient,
    ModelType,
    ModelSelector,
    UserTier,
    ContextType,
    ReplicateClientError,
)

logger = structlog.get_logger(__name__)

# Test prompt for narrative generation
TEST_PROMPT = """You are a dungeon master. The player enters a dimly lit tavern.
Describe the scene in 2-3 sentences. Include at least one interesting NPC."""

TEST_SYSTEM_PROMPT = "You are a creative fantasy storyteller. Keep responses concise but vivid."


def test_model(model_type: ModelType, client: ReplicateClient | None = None) -> dict:
    """
    Test a specific model and return results.

    Args:
        model_type: The model to test.
        client: Optional existing client, otherwise creates new one.

    Returns:
        Dictionary with test results.
    """
    model_name = model_type.name
    print(f"\n{'='*60}")
    print(f"Testing: {model_name}")
    print(f"Model ID: {model_type.value}")
    print(f"{'='*60}")

    try:
        if client is None:
            client = ReplicateClient(model=model_type)

        start_time = time.time()
        response = client.generate(
            prompt=TEST_PROMPT,
            system_prompt=TEST_SYSTEM_PROMPT,
            model=model_type
        )
        elapsed = time.time() - start_time

        print("\n✅ SUCCESS")
        print(f"Response time: {elapsed:.2f}s")
        print(f"Tokens used: {response.tokens_used}")
        print(f"Response length: {len(response.text)} chars")
        print("\nGenerated text:")
        print("-" * 40)
        print(response.text[:500] + ("..." if len(response.text) > 500 else ""))
        print("-" * 40)

        return {
            "model": model_name,
            "success": True,
            "response_time": elapsed,
            "tokens_used": response.tokens_used,
            "text_length": len(response.text),
            "text_preview": response.text[:200]
        }

    except ReplicateClientError as e:
        print(f"\n❌ FAILED: {e}")
        return {
            "model": model_name,
            "success": False,
            "error": str(e)
        }
    except Exception as e:
        print(f"\n❌ UNEXPECTED ERROR: {e}")
        return {
            "model": model_name,
            "success": False,
            "error": str(e)
        }


def test_model_selector():
    """Test the model selector tier routing."""
    print(f"\n{'='*60}")
    print("Testing Model Selector")
    print(f"{'='*60}")

    selector = ModelSelector()

    test_cases = [
        (UserTier.FREE, ContextType.STORY_PROGRESSION),
        (UserTier.BASIC, ContextType.STORY_PROGRESSION),
        (UserTier.PREMIUM, ContextType.STORY_PROGRESSION),
        (UserTier.ELITE, ContextType.STORY_PROGRESSION),
        (UserTier.PREMIUM, ContextType.QUEST_SELECTION),
        (UserTier.PREMIUM, ContextType.COMBAT_NARRATION),
    ]

    print("\nTier → Model Routing:")
    print("-" * 40)

    for tier, context in test_cases:
        config = selector.select_model(tier, context)
        info = selector.get_tier_info(tier)
        cost = selector.estimate_cost_per_request(tier)

        print(f"{tier.value:10} + {context.value:20} → {config.model_type.name:15} "
              f"(tokens={config.max_tokens}, temp={config.temperature}, cost=${cost:.4f})")

    print("\n✅ Model selector routing verified")


def run_verification(models_to_test: list[ModelType]):
    """
    Run full verification suite.

    Args:
        models_to_test: List of models to test with real API calls.
    """
    print("\n" + "=" * 60)
    print("Phase 4 Task 7.8: AI Model Verification")
    print("=" * 60)

    results = []

    # Test model selector first (no API calls)
    test_model_selector()

    if not models_to_test:
        print("\nNo models selected for API testing.")
        print("Use --llama, --haiku, --sonnet, --opus, or --all")
        return

    # Create a single client for efficiency
    try:
        client = ReplicateClient()
    except ReplicateClientError as e:
        print(f"\n❌ Failed to initialize Replicate client: {e}")
        print("Check REPLICATE_API_TOKEN in .env")
        return

    # Test each selected model
    for model_type in models_to_test:
        result = test_model(model_type, client)
        results.append(result)

    # Summary
    print("\n" + "=" * 60)
    print("VERIFICATION SUMMARY")
    print("=" * 60)

    passed = sum(1 for r in results if r.get("success"))
    failed = len(results) - passed

    for result in results:
        status = "✅" if result.get("success") else "❌"
        model = result.get("model")
        if result.get("success"):
            time_s = result.get("response_time", 0)
            tokens = result.get("tokens_used", 0)
            print(f"{status} {model}: {time_s:.2f}s, {tokens} tokens")
        else:
            error = result.get("error", "Unknown error")
            print(f"{status} {model}: {error[:50]}")

    print(f"\nTotal: {passed} passed, {failed} failed")

    if failed == 0:
        print("\n✅ All verification checks passed!")
    else:
        print(f"\n⚠️ {failed} model(s) failed verification")


def main():
    parser = argparse.ArgumentParser(
        description="Verify AI models respond correctly through Replicate"
    )
    parser.add_argument("--all", action="store_true", help="Test all models")
    parser.add_argument("--llama", action="store_true", help="Test Llama-3 8B")
    parser.add_argument("--haiku", action="store_true", help="Test Claude Haiku")
    parser.add_argument("--sonnet", action="store_true", help="Test Claude Sonnet")
    parser.add_argument("--opus", action="store_true", help="Test Claude Opus")

    args = parser.parse_args()

    models_to_test = []

    if args.all:
        models_to_test = [
            ModelType.LLAMA_3_8B,
            ModelType.CLAUDE_HAIKU,
            ModelType.CLAUDE_SONNET,
            ModelType.CLAUDE_SONNET_4,
        ]
    else:
        if args.llama:
            models_to_test.append(ModelType.LLAMA_3_8B)
        if args.haiku:
            models_to_test.append(ModelType.CLAUDE_HAIKU)
        if args.sonnet:
            models_to_test.append(ModelType.CLAUDE_SONNET)
        if args.opus:
            models_to_test.append(ModelType.CLAUDE_SONNET_4)

    run_verification(models_to_test)


if __name__ == "__main__":
    main()
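
# Example invocation (illustrative sketch, not part of the verification suite):
# assuming the script is run from the api/ directory as the Usage line suggests,
# and that REPLICATE_API_TOKEN is present in .env, the following runs the
# selector routing check plus live calls against two models:
#
#   python scripts/verify_ai_models.py --llama --haiku
#
# Note that the --opus flag is currently wired to ModelType.CLAUDE_SONNET_4 in main().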