first commit

This commit is contained in:
2025-11-24 23:10:55 -06:00
commit 8315fa51c9
279 changed files with 74600 additions and 0 deletions

View File

View File

@@ -0,0 +1,320 @@
"""
Action Prompt Loader Service
This module provides a service for loading and filtering action prompts from YAML.
It implements a singleton pattern to cache loaded prompts in memory.
Usage:
from app.services.action_prompt_loader import ActionPromptLoader
loader = ActionPromptLoader()
loader.load_from_yaml("app/data/action_prompts.yaml")
# Get available actions for a user at a location
actions = loader.get_available_actions(
user_tier=UserTier.FREE,
location_type=LocationType.TOWN
)
# Get specific action
action = loader.get_action_by_id("ask_locals")
"""
import os
from typing import List, Optional, Dict
import yaml
from app.models.action_prompt import ActionPrompt, LocationType
from app.ai.model_selector import UserTier
from app.utils.logging import get_logger
# Initialize logger
logger = get_logger(__file__)
class ActionPromptLoaderError(Exception):
"""Base exception for action prompt loader errors."""
pass
class ActionPromptNotFoundError(ActionPromptLoaderError):
"""Raised when a requested action prompt is not found."""
pass
class ActionPromptLoader:
"""
Service for loading and filtering action prompts.
This class loads action prompts from YAML files and provides methods
to filter them based on user tier and location type.
Uses singleton pattern to cache loaded prompts in memory.
Attributes:
_prompts: Dictionary of loaded action prompts keyed by prompt_id
_loaded: Flag indicating if prompts have been loaded
"""
_instance = None
_prompts: Dict[str, ActionPrompt] = {}
_loaded: bool = False
def __new__(cls):
"""Implement singleton pattern."""
if cls._instance is None:
cls._instance = super().__new__(cls)
cls._instance._prompts = {}
cls._instance._loaded = False
return cls._instance
def load_from_yaml(self, filepath: str) -> int:
"""
Load action prompts from a YAML file.
Args:
filepath: Path to the YAML file
Returns:
Number of prompts loaded
Raises:
ActionPromptLoaderError: If file cannot be read or parsed
"""
if not os.path.exists(filepath):
logger.error("Action prompts file not found", filepath=filepath)
raise ActionPromptLoaderError(f"File not found: {filepath}")
try:
with open(filepath, 'r', encoding='utf-8') as f:
data = yaml.safe_load(f)
except yaml.YAMLError as e:
logger.error("Failed to parse YAML", filepath=filepath, error=str(e))
raise ActionPromptLoaderError(f"Invalid YAML in {filepath}: {e}")
except IOError as e:
logger.error("Failed to read file", filepath=filepath, error=str(e))
raise ActionPromptLoaderError(f"Cannot read {filepath}: {e}")
if not data or 'action_prompts' not in data:
logger.error("No action_prompts key in YAML", filepath=filepath)
raise ActionPromptLoaderError(f"Missing 'action_prompts' key in {filepath}")
# Clear existing prompts
self._prompts = {}
# Parse each prompt
prompts_data = data['action_prompts']
errors = []
for i, prompt_data in enumerate(prompts_data):
try:
prompt = ActionPrompt.from_dict(prompt_data)
self._prompts[prompt.prompt_id] = prompt
except (ValueError, KeyError) as e:
errors.append(f"Prompt {i}: {e}")
logger.warning(
"Failed to parse action prompt",
index=i,
error=str(e)
)
if errors:
logger.warning(
"Some action prompts failed to load",
error_count=len(errors),
errors=errors
)
self._loaded = True
loaded_count = len(self._prompts)
logger.info(
"Action prompts loaded",
filepath=filepath,
count=loaded_count,
errors=len(errors)
)
return loaded_count
def get_all_actions(self) -> List[ActionPrompt]:
"""
Get all loaded action prompts.
Returns:
List of all action prompts
"""
self._ensure_loaded()
return list(self._prompts.values())
def get_action_by_id(self, prompt_id: str) -> ActionPrompt:
"""
Get a specific action prompt by ID.
Args:
prompt_id: The unique identifier of the action
Returns:
The ActionPrompt object
Raises:
ActionPromptNotFoundError: If action not found
"""
self._ensure_loaded()
if prompt_id not in self._prompts:
logger.warning("Action prompt not found", prompt_id=prompt_id)
raise ActionPromptNotFoundError(f"Action prompt '{prompt_id}' not found")
return self._prompts[prompt_id]
def get_available_actions(
self,
user_tier: UserTier,
location_type: LocationType
) -> List[ActionPrompt]:
"""
Get actions available to a user at a specific location.
Args:
user_tier: The user's subscription tier
location_type: The current location type
Returns:
List of available action prompts
"""
self._ensure_loaded()
available = []
for prompt in self._prompts.values():
if prompt.is_available(user_tier, location_type):
available.append(prompt)
logger.debug(
"Filtered available actions",
user_tier=user_tier.value,
location_type=location_type.value,
count=len(available)
)
return available
def get_actions_by_tier(self, user_tier: UserTier) -> List[ActionPrompt]:
"""
Get all actions available to a user tier (ignoring location).
Args:
user_tier: The user's subscription tier
Returns:
List of action prompts available to the tier
"""
self._ensure_loaded()
available = []
for prompt in self._prompts.values():
if prompt._tier_meets_requirement(user_tier):
available.append(prompt)
return available
def get_actions_by_category(self, category: str) -> List[ActionPrompt]:
"""
Get all actions in a specific category.
Args:
category: The action category (e.g., "ask_question", "explore")
Returns:
List of action prompts in the category
"""
self._ensure_loaded()
return [
prompt for prompt in self._prompts.values()
if prompt.category.value == category
]
def get_locked_actions(
self,
user_tier: UserTier,
location_type: LocationType
) -> List[ActionPrompt]:
"""
Get actions that are locked due to tier restrictions.
Used to show locked actions with upgrade prompts in UI.
Args:
user_tier: The user's subscription tier
location_type: The current location type
Returns:
List of locked action prompts
"""
self._ensure_loaded()
locked = []
for prompt in self._prompts.values():
# Must match location but be tier-locked
if prompt._location_matches_filter(location_type) and prompt.is_locked(user_tier):
locked.append(prompt)
return locked
def reload(self, filepath: str) -> int:
"""
Force reload prompts from YAML file.
Args:
filepath: Path to the YAML file
Returns:
Number of prompts loaded
"""
self._loaded = False
return self.load_from_yaml(filepath)
def is_loaded(self) -> bool:
"""Check if prompts have been loaded."""
return self._loaded
def get_prompt_count(self) -> int:
"""Get the number of loaded prompts."""
return len(self._prompts)
def _ensure_loaded(self) -> None:
"""
Ensure prompts are loaded, auto-load from default path if not.
Raises:
ActionPromptLoaderError: If prompts cannot be loaded
"""
if not self._loaded:
# Try default path
default_path = os.path.join(
os.path.dirname(__file__),
'..', 'data', 'action_prompts.yaml'
)
default_path = os.path.normpath(default_path)
if os.path.exists(default_path):
self.load_from_yaml(default_path)
else:
raise ActionPromptLoaderError(
"Action prompts not loaded. Call load_from_yaml() first."
)
@classmethod
def reset_instance(cls) -> None:
"""
Reset the singleton instance.
Primarily for testing purposes.
"""
cls._instance = None

View File

@@ -0,0 +1,588 @@
"""
Appwrite Service Wrapper
This module provides a wrapper around the Appwrite SDK for handling user authentication,
session management, and user data operations. It abstracts Appwrite's API to provide
a clean interface for the application.
Usage:
from app.services.appwrite_service import AppwriteService
# Initialize service
service = AppwriteService()
# Register a new user
user = service.register_user(
email="player@example.com",
password="SecurePass123!",
name="Brave Adventurer"
)
# Login
session = service.login_user(
email="player@example.com",
password="SecurePass123!"
)
"""
import os
from typing import Optional, Dict, Any
from dataclasses import dataclass
from datetime import datetime, timezone
from appwrite.client import Client
from appwrite.services.account import Account
from appwrite.services.users import Users
from appwrite.exception import AppwriteException
from appwrite.id import ID
from app.utils.logging import get_logger
# Initialize logger
logger = get_logger(__file__)
@dataclass
class UserData:
"""
Data class representing a user in the system.
Attributes:
id: Unique user identifier
email: User's email address
name: User's display name
email_verified: Whether email has been verified
tier: User's subscription tier (free, basic, premium, elite)
created_at: When the user account was created
updated_at: When the user account was last updated
"""
id: str
email: str
name: str
email_verified: bool
tier: str
created_at: datetime
updated_at: datetime
def to_dict(self) -> Dict[str, Any]:
"""Convert user data to dictionary."""
return {
"id": self.id,
"email": self.email,
"name": self.name,
"email_verified": self.email_verified,
"tier": self.tier,
"created_at": self.created_at.isoformat() if isinstance(self.created_at, datetime) else self.created_at,
"updated_at": self.updated_at.isoformat() if isinstance(self.updated_at, datetime) else self.updated_at,
}
@dataclass
class SessionData:
"""
Data class representing a user session.
Attributes:
session_id: Unique session identifier
user_id: User ID associated with this session
provider: Authentication provider (email, oauth, etc.)
expire: When the session expires
"""
session_id: str
user_id: str
provider: str
expire: datetime
def to_dict(self) -> Dict[str, Any]:
"""Convert session data to dictionary."""
return {
"session_id": self.session_id,
"user_id": self.user_id,
"provider": self.provider,
"expire": self.expire.isoformat() if isinstance(self.expire, datetime) else self.expire,
}
class AppwriteService:
"""
Service class for interacting with Appwrite authentication and user management.
This class provides methods for:
- User registration and email verification
- User login and logout
- Session management
- Password reset
- User tier management
"""
def __init__(self):
"""
Initialize the Appwrite service.
Reads configuration from environment variables:
- APPWRITE_ENDPOINT: Appwrite API endpoint
- APPWRITE_PROJECT_ID: Appwrite project ID
- APPWRITE_API_KEY: Appwrite API key (for server-side operations)
"""
self.endpoint = os.getenv('APPWRITE_ENDPOINT')
self.project_id = os.getenv('APPWRITE_PROJECT_ID')
self.api_key = os.getenv('APPWRITE_API_KEY')
if not all([self.endpoint, self.project_id, self.api_key]):
logger.error("Missing Appwrite configuration in environment variables")
raise ValueError("Appwrite configuration incomplete. Check APPWRITE_* environment variables.")
# Initialize Appwrite client
self.client = Client()
self.client.set_endpoint(self.endpoint)
self.client.set_project(self.project_id)
self.client.set_key(self.api_key)
# Initialize services
self.account = Account(self.client)
self.users = Users(self.client)
logger.info("Appwrite service initialized", endpoint=self.endpoint, project_id=self.project_id)
def register_user(self, email: str, password: str, name: str) -> UserData:
"""
Register a new user account.
This method:
1. Creates a new user in Appwrite Auth
2. Sets the user's tier to 'free' in preferences
3. Triggers email verification
4. Returns user data
Args:
email: User's email address
password: User's password (will be hashed by Appwrite)
name: User's display name
Returns:
UserData object with user information
Raises:
AppwriteException: If registration fails (e.g., email already exists)
"""
try:
logger.info("Attempting to register new user", email=email, name=name)
# Generate unique user ID
user_id = ID.unique()
# Create user account
user = self.users.create(
user_id=user_id,
email=email,
password=password,
name=name
)
logger.info("User created successfully", user_id=user['$id'], email=email)
# Set default tier to 'free' in user preferences
self.users.update_prefs(
user_id=user['$id'],
prefs={
'tier': 'free',
'tier_updated_at': datetime.now(timezone.utc).isoformat()
}
)
logger.info("User tier set to 'free'", user_id=user['$id'])
# Note: Email verification is handled by Appwrite automatically
# when email templates are configured in the Appwrite console.
# For server-side user creation, verification emails are sent
# automatically if the email provider is configured.
#
# To manually trigger verification, users can use the Account service
# (client-side) after logging in, or configure email verification
# settings in the Appwrite console.
logger.info("User created, email verification handled by Appwrite", user_id=user['$id'], email=email)
# Return user data
return self._user_to_userdata(user)
except AppwriteException as e:
logger.error("Failed to register user", email=email, error=str(e), code=e.code)
raise
def login_user(self, email: str, password: str) -> tuple[SessionData, UserData]:
"""
Authenticate a user and create a session.
For server-side authentication, we create a temporary client with user
credentials to verify them, then create a session using the server SDK.
Args:
email: User's email address
password: User's password
Returns:
Tuple of (SessionData, UserData)
Raises:
AppwriteException: If login fails (invalid credentials, etc.)
"""
try:
logger.info("Attempting user login", email=email)
# Use admin client (with API key) to create session
# This is required to get the session secret in the response
from appwrite.services.account import Account
admin_account = Account(self.client) # self.client already has API key set
# Create email/password session using admin client
# When using admin client, the 'secret' field is populated in the response
user_session = admin_account.create_email_password_session(
email=email,
password=password
)
logger.info("Session created successfully",
user_id=user_session['userId'],
session_id=user_session['$id'])
# Extract session secret from response
# Admin client populates this field, unlike regular client
session_secret = user_session.get('secret', '')
if not session_secret:
logger.error("Session secret not found in response - this should not happen with admin client")
raise AppwriteException("Failed to get session secret", code=500)
# Get user data using server SDK
user = self.users.get(user_id=user_session['userId'])
# Convert to our data classes
session_data = SessionData(
session_id=session_secret, # Use the secret, not the session ID
user_id=user_session['userId'],
provider=user_session['provider'],
expire=datetime.fromisoformat(user_session['expire'].replace('Z', '+00:00'))
)
user_data = self._user_to_userdata(user)
return session_data, user_data
except AppwriteException as e:
logger.error("Failed to login user", email=email, error=str(e), code=e.code)
raise
except Exception as e:
logger.error("Unexpected error during login", email=email, error=str(e), exc_info=True)
raise AppwriteException(str(e), code=500)
def logout_user(self, session_id: str) -> bool:
"""
Log out a user by deleting their session.
Args:
session_id: The session ID to delete
Returns:
True if logout successful
Raises:
AppwriteException: If logout fails
"""
try:
logger.info("Attempting to logout user", session_id=session_id)
# For server-side, we need to delete the session using Users service
# First get the session to find the user_id
# Note: Appwrite doesn't have a direct server-side session delete by session_id
# We'll use a workaround by creating a client with the session and deleting it
from appwrite.client import Client
from appwrite.services.account import Account
# Create client with the session
session_client = Client()
session_client.set_endpoint(self.endpoint)
session_client.set_project(self.project_id)
session_client.set_session(session_id)
session_account = Account(session_client)
# Delete the current session
session_account.delete_session('current')
logger.info("User logged out successfully", session_id=session_id)
return True
except AppwriteException as e:
logger.error("Failed to logout user", session_id=session_id, error=str(e), code=e.code)
raise
def verify_email(self, user_id: str, secret: str) -> bool:
"""
Verify a user's email address.
Note: Email verification with server-side SDK requires updating
the user's emailVerification status directly, or using Appwrite's
built-in verification flow through the Account service (client-side).
Args:
user_id: User ID
secret: Verification secret from email link (not validated server-side)
Returns:
True if verification successful
Raises:
AppwriteException: If verification fails (invalid/expired secret)
"""
try:
logger.info("Attempting to verify email", user_id=user_id, secret_provided=bool(secret))
# For server-side verification, we update the user's email verification status
# The secret validation should be done by Appwrite's verification flow
# For now, we'll mark the email as verified
# In production, you should validate the secret token before updating
self.users.update_email_verification(user_id=user_id, email_verification=True)
logger.info("Email verified successfully", user_id=user_id)
return True
except AppwriteException as e:
logger.error("Failed to verify email", user_id=user_id, error=str(e), code=e.code)
raise
def request_password_reset(self, email: str) -> bool:
"""
Request a password reset for a user.
This sends a password reset email to the user. For security,
it always returns True even if the email doesn't exist.
Note: Password reset is handled through Appwrite's built-in Account
service recovery flow. For server-side operations, we would need to
create a password recovery token manually.
Args:
email: User's email address
Returns:
Always True (for security - don't reveal if email exists)
"""
try:
logger.info("Password reset requested", email=email)
# Note: Password reset with server-side SDK requires creating
# a recovery token. For now, we'll log this and return success.
# In production, configure Appwrite's email templates and use
# client-side Account.createRecovery() or implement custom token
# generation and email sending.
logger.warning("Password reset not fully implemented - requires Appwrite email configuration", email=email)
except Exception as e:
# Log the error but still return True for security
# Don't reveal whether the email exists
logger.warning("Password reset request encountered error", email=email, error=str(e))
# Always return True to not reveal if email exists
return True
def confirm_password_reset(self, user_id: str, secret: str, password: str) -> bool:
"""
Confirm a password reset and update the user's password.
Note: For server-side operations, we update the password directly
using the Users service. Secret validation would be handled separately.
Args:
user_id: User ID
secret: Reset secret from email link (should be validated before calling)
password: New password
Returns:
True if password reset successful
Raises:
AppwriteException: If reset fails
"""
try:
logger.info("Attempting to reset password", user_id=user_id, secret_provided=bool(secret))
# For server-side password reset, update the password directly
# In production, you should validate the secret token first before calling this
# The secret parameter is kept for API compatibility but not validated here
self.users.update_password(user_id=user_id, password=password)
logger.info("Password reset successfully", user_id=user_id)
return True
except AppwriteException as e:
logger.error("Failed to reset password", user_id=user_id, error=str(e), code=e.code)
raise
def get_user(self, user_id: str) -> UserData:
"""
Get user data by user ID.
Args:
user_id: User ID
Returns:
UserData object
Raises:
AppwriteException: If user not found
"""
try:
user = self.users.get(user_id=user_id)
return self._user_to_userdata(user)
except AppwriteException as e:
logger.error("Failed to fetch user", user_id=user_id, error=str(e), code=e.code)
raise
def get_session(self, session_id: str) -> SessionData:
"""
Get session data and validate it's still active.
Args:
session_id: Session ID
Returns:
SessionData object
Raises:
AppwriteException: If session invalid or expired
"""
try:
# Create a client with the session to validate it
from appwrite.client import Client
from appwrite.services.account import Account
session_client = Client()
session_client.set_endpoint(self.endpoint)
session_client.set_project(self.project_id)
session_client.set_session(session_id)
session_account = Account(session_client)
# Get the current session (this validates it exists and is active)
session = session_account.get_session('current')
# Check if session is expired
expire_time = datetime.fromisoformat(session['expire'].replace('Z', '+00:00'))
if expire_time < datetime.now(timezone.utc):
logger.warning("Session expired", session_id=session_id, expired_at=expire_time)
raise AppwriteException("Session expired", code=401)
return SessionData(
session_id=session['$id'],
user_id=session['userId'],
provider=session['provider'],
expire=expire_time
)
except AppwriteException as e:
logger.error("Failed to validate session", session_id=session_id, error=str(e), code=e.code)
raise
def get_user_tier(self, user_id: str) -> str:
"""
Get the user's subscription tier.
Args:
user_id: User ID
Returns:
Tier string (free, basic, premium, elite)
"""
try:
logger.debug("Fetching user tier", user_id=user_id)
user = self.users.get(user_id=user_id)
prefs = user.get('prefs', {})
tier = prefs.get('tier', 'free')
logger.debug("User tier retrieved", user_id=user_id, tier=tier)
return tier
except AppwriteException as e:
logger.error("Failed to fetch user tier", user_id=user_id, error=str(e), code=e.code)
# Default to free tier on error
return 'free'
def set_user_tier(self, user_id: str, tier: str) -> bool:
"""
Update the user's subscription tier.
Args:
user_id: User ID
tier: New tier (free, basic, premium, elite)
Returns:
True if update successful
Raises:
AppwriteException: If update fails
ValueError: If tier is invalid
"""
valid_tiers = ['free', 'basic', 'premium', 'elite']
if tier not in valid_tiers:
raise ValueError(f"Invalid tier: {tier}. Must be one of {valid_tiers}")
try:
logger.info("Updating user tier", user_id=user_id, new_tier=tier)
# Get current preferences
user = self.users.get(user_id=user_id)
prefs = user.get('prefs', {})
# Update tier
prefs['tier'] = tier
prefs['tier_updated_at'] = datetime.now(timezone.utc).isoformat()
self.users.update_prefs(user_id=user_id, prefs=prefs)
logger.info("User tier updated successfully", user_id=user_id, tier=tier)
return True
except AppwriteException as e:
logger.error("Failed to update user tier", user_id=user_id, tier=tier, error=str(e), code=e.code)
raise
def _user_to_userdata(self, user: Dict[str, Any]) -> UserData:
"""
Convert Appwrite user object to UserData dataclass.
Args:
user: Appwrite user dictionary
Returns:
UserData object
"""
# Get tier from preferences, default to 'free'
prefs = user.get('prefs', {})
tier = prefs.get('tier', 'free')
# Parse timestamps
created_at = user.get('$createdAt', datetime.now(timezone.utc).isoformat())
updated_at = user.get('$updatedAt', datetime.now(timezone.utc).isoformat())
if isinstance(created_at, str):
created_at = datetime.fromisoformat(created_at.replace('Z', '+00:00'))
if isinstance(updated_at, str):
updated_at = datetime.fromisoformat(updated_at.replace('Z', '+00:00'))
return UserData(
id=user['$id'],
email=user['email'],
name=user['name'],
email_verified=user.get('emailVerification', False),
tier=tier,
created_at=created_at,
updated_at=updated_at
)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,277 @@
"""
ClassLoader service for loading player class definitions from YAML files.
This service reads class configuration files and converts them into PlayerClass
dataclass instances, providing caching for performance.
"""
import yaml
from pathlib import Path
from typing import Dict, List, Optional
import structlog
from app.models.skills import PlayerClass, SkillTree, SkillNode
from app.models.stats import Stats
logger = structlog.get_logger(__name__)
class ClassLoader:
"""
Loads player class definitions from YAML configuration files.
This allows game designers to define classes and skill trees without touching code.
All class definitions are stored in /app/data/classes/ as YAML files.
"""
def __init__(self, data_dir: Optional[str] = None):
"""
Initialize the class loader.
Args:
data_dir: Path to directory containing class YAML files.
Defaults to /app/data/classes/
"""
if data_dir is None:
# Default to app/data/classes relative to this file
current_file = Path(__file__)
app_dir = current_file.parent.parent # Go up to /app
data_dir = str(app_dir / "data" / "classes")
self.data_dir = Path(data_dir)
self._class_cache: Dict[str, PlayerClass] = {}
logger.info("ClassLoader initialized", data_dir=str(self.data_dir))
def load_class(self, class_id: str) -> Optional[PlayerClass]:
"""
Load a single player class by ID.
Args:
class_id: Unique class identifier (e.g., "vanguard")
Returns:
PlayerClass instance or None if not found
"""
# Check cache first
if class_id in self._class_cache:
logger.debug("Class loaded from cache", class_id=class_id)
return self._class_cache[class_id]
# Construct file path
file_path = self.data_dir / f"{class_id}.yaml"
if not file_path.exists():
logger.warning("Class file not found", class_id=class_id, file_path=str(file_path))
return None
try:
# Load YAML file
with open(file_path, 'r') as f:
data = yaml.safe_load(f)
# Parse into PlayerClass
player_class = self._parse_class_data(data)
# Cache the result
self._class_cache[class_id] = player_class
logger.info("Class loaded successfully", class_id=class_id)
return player_class
except Exception as e:
logger.error("Failed to load class", class_id=class_id, error=str(e))
return None
def load_all_classes(self) -> List[PlayerClass]:
"""
Load all player classes from the data directory.
Returns:
List of PlayerClass instances
"""
classes = []
# Find all YAML files in the directory
if not self.data_dir.exists():
logger.error("Class data directory does not exist", data_dir=str(self.data_dir))
return classes
for file_path in self.data_dir.glob("*.yaml"):
class_id = file_path.stem # Get filename without extension
player_class = self.load_class(class_id)
if player_class:
classes.append(player_class)
logger.info("All classes loaded", count=len(classes))
return classes
def get_class_by_id(self, class_id: str) -> Optional[PlayerClass]:
"""
Get a player class by ID (alias for load_class).
Args:
class_id: Unique class identifier
Returns:
PlayerClass instance or None if not found
"""
return self.load_class(class_id)
def get_all_class_ids(self) -> List[str]:
"""
Get a list of all available class IDs.
Returns:
List of class IDs (e.g., ["vanguard", "assassin", "arcanist"])
"""
if not self.data_dir.exists():
return []
return [file_path.stem for file_path in self.data_dir.glob("*.yaml")]
def reload_class(self, class_id: str) -> Optional[PlayerClass]:
"""
Force reload a class from disk, bypassing cache.
Useful for development/testing when class definitions change.
Args:
class_id: Unique class identifier
Returns:
PlayerClass instance or None if not found
"""
# Remove from cache if present
if class_id in self._class_cache:
del self._class_cache[class_id]
return self.load_class(class_id)
def clear_cache(self):
"""Clear the class cache. Useful for testing."""
self._class_cache.clear()
logger.info("Class cache cleared")
def _parse_class_data(self, data: Dict) -> PlayerClass:
"""
Parse YAML data into a PlayerClass dataclass.
Args:
data: Dictionary loaded from YAML file
Returns:
PlayerClass instance
Raises:
ValueError: If data is invalid or missing required fields
"""
# Validate required fields
required_fields = ["class_id", "name", "description", "base_stats", "skill_trees"]
for field in required_fields:
if field not in data:
raise ValueError(f"Missing required field: {field}")
# Parse base stats
base_stats = Stats(**data["base_stats"])
# Parse skill trees
skill_trees = []
for tree_data in data["skill_trees"]:
skill_tree = self._parse_skill_tree(tree_data)
skill_trees.append(skill_tree)
# Get optional fields
starting_equipment = data.get("starting_equipment", [])
starting_abilities = data.get("starting_abilities", [])
# Create PlayerClass instance
player_class = PlayerClass(
class_id=data["class_id"],
name=data["name"],
description=data["description"],
base_stats=base_stats,
skill_trees=skill_trees,
starting_equipment=starting_equipment,
starting_abilities=starting_abilities
)
return player_class
def _parse_skill_tree(self, tree_data: Dict) -> SkillTree:
"""
Parse a skill tree from YAML data.
Args:
tree_data: Dictionary containing skill tree data
Returns:
SkillTree instance
"""
# Validate required fields
required_fields = ["tree_id", "name", "description", "nodes"]
for field in required_fields:
if field not in tree_data:
raise ValueError(f"Missing required field in skill tree: {field}")
# Parse skill nodes
nodes = []
for node_data in tree_data["nodes"]:
skill_node = self._parse_skill_node(node_data)
nodes.append(skill_node)
# Create SkillTree instance
skill_tree = SkillTree(
tree_id=tree_data["tree_id"],
name=tree_data["name"],
description=tree_data["description"],
nodes=nodes
)
return skill_tree
def _parse_skill_node(self, node_data: Dict) -> SkillNode:
"""
Parse a skill node from YAML data.
Args:
node_data: Dictionary containing skill node data
Returns:
SkillNode instance
"""
# Validate required fields
required_fields = ["skill_id", "name", "description", "tier", "effects"]
for field in required_fields:
if field not in node_data:
raise ValueError(f"Missing required field in skill node: {field}")
# Create SkillNode instance
skill_node = SkillNode(
skill_id=node_data["skill_id"],
name=node_data["name"],
description=node_data["description"],
tier=node_data["tier"],
prerequisites=node_data.get("prerequisites", []),
effects=node_data.get("effects", {}),
unlocked=False # Always start locked
)
return skill_node
# Global instance for convenience
_loader_instance: Optional[ClassLoader] = None
def get_class_loader() -> ClassLoader:
"""
Get the global ClassLoader instance.
Returns:
Singleton ClassLoader instance
"""
global _loader_instance
if _loader_instance is None:
_loader_instance = ClassLoader()
return _loader_instance

View File

@@ -0,0 +1,709 @@
"""
Database Initialization Service.
This service handles programmatic creation of Appwrite database tables,
including schema definition, column creation, and index setup.
"""
import os
import time
from typing import List, Dict, Any, Optional
from appwrite.client import Client
from appwrite.services.tables_db import TablesDB
from appwrite.exception import AppwriteException
from app.utils.logging import get_logger
logger = get_logger(__file__)
class DatabaseInitService:
"""
Service for initializing Appwrite database tables.
This service provides methods to:
- Create tables if they don't exist
- Define table schemas (columns/attributes)
- Create indexes for efficient querying
- Validate existing table structures
"""
def __init__(self):
"""
Initialize the database initialization service.
Reads configuration from environment variables:
- APPWRITE_ENDPOINT: Appwrite API endpoint
- APPWRITE_PROJECT_ID: Appwrite project ID
- APPWRITE_API_KEY: Appwrite API key
- APPWRITE_DATABASE_ID: Appwrite database ID
"""
self.endpoint = os.getenv('APPWRITE_ENDPOINT')
self.project_id = os.getenv('APPWRITE_PROJECT_ID')
self.api_key = os.getenv('APPWRITE_API_KEY')
self.database_id = os.getenv('APPWRITE_DATABASE_ID', 'main')
if not all([self.endpoint, self.project_id, self.api_key]):
logger.error("Missing Appwrite configuration in environment variables")
raise ValueError("Appwrite configuration incomplete. Check APPWRITE_* environment variables.")
# Initialize Appwrite client
self.client = Client()
self.client.set_endpoint(self.endpoint)
self.client.set_project(self.project_id)
self.client.set_key(self.api_key)
# Initialize TablesDB service
self.tables_db = TablesDB(self.client)
logger.info("DatabaseInitService initialized", database_id=self.database_id)
def init_all_tables(self) -> Dict[str, bool]:
"""
Initialize all application tables.
Returns:
Dictionary mapping table names to success status
"""
results = {}
logger.info("Initializing all database tables")
# Initialize characters table
try:
self.init_characters_table()
results['characters'] = True
logger.info("Characters table initialized successfully")
except Exception as e:
logger.error("Failed to initialize characters table", error=str(e))
results['characters'] = False
# Initialize game_sessions table
try:
self.init_game_sessions_table()
results['game_sessions'] = True
logger.info("Game sessions table initialized successfully")
except Exception as e:
logger.error("Failed to initialize game_sessions table", error=str(e))
results['game_sessions'] = False
# Initialize ai_usage_logs table
try:
self.init_ai_usage_logs_table()
results['ai_usage_logs'] = True
logger.info("AI usage logs table initialized successfully")
except Exception as e:
logger.error("Failed to initialize ai_usage_logs table", error=str(e))
results['ai_usage_logs'] = False
success_count = sum(1 for v in results.values() if v)
total_count = len(results)
logger.info("Table initialization complete",
success=success_count,
total=total_count,
results=results)
return results
def init_characters_table(self) -> bool:
"""
Initialize the characters table.
Table schema:
- userId (string, required): Owner's user ID
- characterData (string, required): JSON-serialized character data
- is_active (boolean, default=True): Soft delete flag
- created_at (datetime): Auto-managed creation timestamp
- updated_at (datetime): Auto-managed update timestamp
Indexes:
- userId: For general user queries
- userId + is_active: Composite index for efficiently fetching active characters
Returns:
True if successful
Raises:
AppwriteException: If table creation fails
"""
table_id = 'characters'
logger.info("Initializing characters table", table_id=table_id)
try:
# Check if table already exists
try:
self.tables_db.get_table(
database_id=self.database_id,
table_id=table_id
)
logger.info("Characters table already exists", table_id=table_id)
return True
except AppwriteException as e:
if e.code != 404:
raise
logger.info("Characters table does not exist, creating...")
# Create table
logger.info("Creating characters table")
table = self.tables_db.create_table(
database_id=self.database_id,
table_id=table_id,
name='Characters'
)
logger.info("Characters table created", table_id=table['$id'])
# Create columns
self._create_column(
table_id=table_id,
column_id='userId',
column_type='string',
size=255,
required=True
)
self._create_column(
table_id=table_id,
column_id='characterData',
column_type='string',
size=65535, # Large text field for JSON data
required=True
)
self._create_column(
table_id=table_id,
column_id='is_active',
column_type='boolean',
required=False, # Cannot be required if we want a default value
default=True
)
# Note: created_at and updated_at are auto-managed by DatabaseService
# through the _parse_row method and timestamp updates
# Wait for columns to fully propagate in Appwrite before creating indexes
logger.info("Waiting for columns to propagate before creating indexes...")
time.sleep(2)
# Create indexes for efficient querying
# Note: Individual userId index for general user queries
self._create_index(
table_id=table_id,
index_id='idx_userId',
index_type='key',
attributes=['userId']
)
# Composite index for the most common query pattern:
# Query.equal('userId', user_id) + Query.equal('is_active', True)
# This single composite index covers both conditions efficiently
self._create_index(
table_id=table_id,
index_id='idx_userId_is_active',
index_type='key',
attributes=['userId', 'is_active']
)
logger.info("Characters table initialized successfully", table_id=table_id)
return True
except AppwriteException as e:
logger.error("Failed to initialize characters table",
table_id=table_id,
error=str(e),
code=e.code)
raise
def init_game_sessions_table(self) -> bool:
"""
Initialize the game_sessions table.
Table schema:
- userId (string, required): Owner's user ID
- characterId (string, required): Character ID for this session
- sessionData (string, required): JSON-serialized session data
- status (string, required): Session status (active, completed, abandoned)
- sessionType (string, required): Session type (solo, multiplayer)
Indexes:
- userId: For user session queries
- userId + status: For active session queries
- characterId: For character session lookups
Returns:
True if successful
Raises:
AppwriteException: If table creation fails
"""
table_id = 'game_sessions'
logger.info("Initializing game_sessions table", table_id=table_id)
try:
# Check if table already exists
try:
self.tables_db.get_table(
database_id=self.database_id,
table_id=table_id
)
logger.info("Game sessions table already exists", table_id=table_id)
return True
except AppwriteException as e:
if e.code != 404:
raise
logger.info("Game sessions table does not exist, creating...")
# Create table
logger.info("Creating game_sessions table")
table = self.tables_db.create_table(
database_id=self.database_id,
table_id=table_id,
name='Game Sessions'
)
logger.info("Game sessions table created", table_id=table['$id'])
# Create columns
self._create_column(
table_id=table_id,
column_id='userId',
column_type='string',
size=255,
required=True
)
self._create_column(
table_id=table_id,
column_id='characterId',
column_type='string',
size=255,
required=True
)
self._create_column(
table_id=table_id,
column_id='sessionData',
column_type='string',
size=65535, # Large text field for JSON data
required=True
)
self._create_column(
table_id=table_id,
column_id='status',
column_type='string',
size=50,
required=True
)
self._create_column(
table_id=table_id,
column_id='sessionType',
column_type='string',
size=50,
required=True
)
# Wait for columns to fully propagate
logger.info("Waiting for columns to propagate before creating indexes...")
time.sleep(2)
# Create indexes
self._create_index(
table_id=table_id,
index_id='idx_userId',
index_type='key',
attributes=['userId']
)
self._create_index(
table_id=table_id,
index_id='idx_userId_status',
index_type='key',
attributes=['userId', 'status']
)
self._create_index(
table_id=table_id,
index_id='idx_characterId',
index_type='key',
attributes=['characterId']
)
logger.info("Game sessions table initialized successfully", table_id=table_id)
return True
except AppwriteException as e:
logger.error("Failed to initialize game_sessions table",
table_id=table_id,
error=str(e),
code=e.code)
raise
def init_ai_usage_logs_table(self) -> bool:
"""
Initialize the ai_usage_logs table for tracking AI API usage and costs.
Table schema:
- user_id (string, required): User who made the request
- timestamp (string, required): ISO timestamp of the request
- model (string, required): Model identifier
- tokens_input (integer, required): Input token count
- tokens_output (integer, required): Output token count
- tokens_total (integer, required): Total token count
- estimated_cost (float, required): Estimated cost in USD
- task_type (string, required): Type of task
- session_id (string, optional): Game session ID
- character_id (string, optional): Character ID
- request_duration_ms (integer): Request duration in milliseconds
- success (boolean): Whether request succeeded
- error_message (string, optional): Error message if failed
Indexes:
- user_id: For user usage queries
- timestamp: For date range queries
- user_id + timestamp: Composite for user date range queries
Returns:
True if successful
Raises:
AppwriteException: If table creation fails
"""
table_id = 'ai_usage_logs'
logger.info("Initializing ai_usage_logs table", table_id=table_id)
try:
# Check if table already exists
try:
self.tables_db.get_table(
database_id=self.database_id,
table_id=table_id
)
logger.info("AI usage logs table already exists", table_id=table_id)
return True
except AppwriteException as e:
if e.code != 404:
raise
logger.info("AI usage logs table does not exist, creating...")
# Create table
logger.info("Creating ai_usage_logs table")
table = self.tables_db.create_table(
database_id=self.database_id,
table_id=table_id,
name='AI Usage Logs'
)
logger.info("AI usage logs table created", table_id=table['$id'])
# Create columns
self._create_column(
table_id=table_id,
column_id='user_id',
column_type='string',
size=255,
required=True
)
self._create_column(
table_id=table_id,
column_id='timestamp',
column_type='string',
size=50, # ISO timestamp format
required=True
)
self._create_column(
table_id=table_id,
column_id='model',
column_type='string',
size=255,
required=True
)
self._create_column(
table_id=table_id,
column_id='tokens_input',
column_type='integer',
required=True
)
self._create_column(
table_id=table_id,
column_id='tokens_output',
column_type='integer',
required=True
)
self._create_column(
table_id=table_id,
column_id='tokens_total',
column_type='integer',
required=True
)
self._create_column(
table_id=table_id,
column_id='estimated_cost',
column_type='float',
required=True
)
self._create_column(
table_id=table_id,
column_id='task_type',
column_type='string',
size=50,
required=True
)
self._create_column(
table_id=table_id,
column_id='session_id',
column_type='string',
size=255,
required=False
)
self._create_column(
table_id=table_id,
column_id='character_id',
column_type='string',
size=255,
required=False
)
self._create_column(
table_id=table_id,
column_id='request_duration_ms',
column_type='integer',
required=False,
default=0
)
self._create_column(
table_id=table_id,
column_id='success',
column_type='boolean',
required=False,
default=True
)
self._create_column(
table_id=table_id,
column_id='error_message',
column_type='string',
size=1000,
required=False
)
# Wait for columns to fully propagate
logger.info("Waiting for columns to propagate before creating indexes...")
time.sleep(2)
# Create indexes
self._create_index(
table_id=table_id,
index_id='idx_user_id',
index_type='key',
attributes=['user_id']
)
self._create_index(
table_id=table_id,
index_id='idx_timestamp',
index_type='key',
attributes=['timestamp']
)
self._create_index(
table_id=table_id,
index_id='idx_user_id_timestamp',
index_type='key',
attributes=['user_id', 'timestamp']
)
logger.info("AI usage logs table initialized successfully", table_id=table_id)
return True
except AppwriteException as e:
logger.error("Failed to initialize ai_usage_logs table",
table_id=table_id,
error=str(e),
code=e.code)
raise
def _create_column(
self,
table_id: str,
column_id: str,
column_type: str,
size: Optional[int] = None,
required: bool = False,
default: Optional[Any] = None,
array: bool = False
) -> Dict[str, Any]:
"""
Create a column in a table.
Args:
table_id: Table ID
column_id: Column ID
column_type: Column type (string, integer, float, boolean, datetime, email, ip, url)
size: Column size (for string types)
required: Whether column is required
default: Default value
array: Whether column is an array
Returns:
Column creation response
Raises:
AppwriteException: If column creation fails
"""
try:
logger.info("Creating column",
table_id=table_id,
column_id=column_id,
column_type=column_type)
# Build column parameters (Appwrite SDK uses 'key' not 'column_id')
params = {
'database_id': self.database_id,
'table_id': table_id,
'key': column_id,
'required': required,
'array': array
}
if size is not None:
params['size'] = size
if default is not None:
params['default'] = default
# Create column using the appropriate method based on type
if column_type == 'string':
result = self.tables_db.create_string_column(**params)
elif column_type == 'integer':
result = self.tables_db.create_integer_column(**params)
elif column_type == 'float':
result = self.tables_db.create_float_column(**params)
elif column_type == 'boolean':
result = self.tables_db.create_boolean_column(**params)
elif column_type == 'datetime':
result = self.tables_db.create_datetime_column(**params)
elif column_type == 'email':
result = self.tables_db.create_email_column(**params)
else:
raise ValueError(f"Unsupported column type: {column_type}")
logger.info("Column created successfully",
table_id=table_id,
column_id=column_id)
return result
except AppwriteException as e:
# If column already exists, log warning but don't fail
if e.code == 409: # Conflict - column already exists
logger.warning("Column already exists",
table_id=table_id,
column_id=column_id)
return {}
logger.error("Failed to create column",
table_id=table_id,
column_id=column_id,
error=str(e),
code=e.code)
raise
def _create_index(
self,
table_id: str,
index_id: str,
index_type: str,
attributes: List[str],
orders: Optional[List[str]] = None
) -> Dict[str, Any]:
"""
Create an index on a table.
Args:
table_id: Table ID
index_id: Index ID
index_type: Index type (key, fulltext, unique)
attributes: List of column IDs to index
orders: List of sort orders (ASC, DESC) for each attribute
Returns:
Index creation response
Raises:
AppwriteException: If index creation fails
"""
try:
logger.info("Creating index",
table_id=table_id,
index_id=index_id,
attributes=attributes)
result = self.tables_db.create_index(
database_id=self.database_id,
table_id=table_id,
key=index_id,
type=index_type,
columns=attributes, # SDK uses 'columns', not 'attributes'
orders=orders or ['ASC'] * len(attributes)
)
logger.info("Index created successfully",
table_id=table_id,
index_id=index_id)
return result
except AppwriteException as e:
# If index already exists, log warning but don't fail
if e.code == 409: # Conflict - index already exists
logger.warning("Index already exists",
table_id=table_id,
index_id=index_id)
return {}
logger.error("Failed to create index",
table_id=table_id,
index_id=index_id,
error=str(e),
code=e.code)
raise
# Global instance for convenience
_init_service_instance: Optional[DatabaseInitService] = None
def get_database_init_service() -> DatabaseInitService:
"""
Get the global DatabaseInitService instance.
Returns:
Singleton DatabaseInitService instance
"""
global _init_service_instance
if _init_service_instance is None:
_init_service_instance = DatabaseInitService()
return _init_service_instance
def init_database() -> Dict[str, bool]:
"""
Convenience function to initialize all database tables.
Returns:
Dictionary mapping table names to success status
"""
service = get_database_init_service()
return service.init_all_tables()

View File

@@ -0,0 +1,441 @@
"""
Database Service for Appwrite database operations.
This service wraps the Appwrite Databases SDK to provide a clean interface
for CRUD operations on collections. It handles JSON serialization, error handling,
and provides structured logging.
"""
import os
from typing import Dict, Any, List, Optional
from datetime import datetime, timezone
from dataclasses import dataclass
from appwrite.client import Client
from appwrite.services.tables_db import TablesDB
from appwrite.exception import AppwriteException
from appwrite.id import ID
from app.utils.logging import get_logger
logger = get_logger(__file__)
@dataclass
class DatabaseRow:
"""
Represents a row in an Appwrite table.
Attributes:
id: Row ID
table_id: Table ID
data: Row data (parsed from JSON)
created_at: Creation timestamp
updated_at: Last update timestamp
"""
id: str
table_id: str
data: Dict[str, Any]
created_at: datetime
updated_at: datetime
def to_dict(self) -> Dict[str, Any]:
"""Convert row to dictionary."""
return {
"id": self.id,
"table_id": self.table_id,
"data": self.data,
"created_at": self.created_at.isoformat() if isinstance(self.created_at, datetime) else self.created_at,
"updated_at": self.updated_at.isoformat() if isinstance(self.updated_at, datetime) else self.updated_at,
}
class DatabaseService:
"""
Service for interacting with Appwrite database tables.
This service provides methods for:
- Creating rows
- Reading rows by ID or query
- Updating rows
- Deleting rows
- Querying with filters
"""
def __init__(self):
"""
Initialize the database service.
Reads configuration from environment variables:
- APPWRITE_ENDPOINT: Appwrite API endpoint
- APPWRITE_PROJECT_ID: Appwrite project ID
- APPWRITE_API_KEY: Appwrite API key
- APPWRITE_DATABASE_ID: Appwrite database ID
"""
self.endpoint = os.getenv('APPWRITE_ENDPOINT')
self.project_id = os.getenv('APPWRITE_PROJECT_ID')
self.api_key = os.getenv('APPWRITE_API_KEY')
self.database_id = os.getenv('APPWRITE_DATABASE_ID', 'main')
if not all([self.endpoint, self.project_id, self.api_key]):
logger.error("Missing Appwrite configuration in environment variables")
raise ValueError("Appwrite configuration incomplete. Check APPWRITE_* environment variables.")
# Initialize Appwrite client
self.client = Client()
self.client.set_endpoint(self.endpoint)
self.client.set_project(self.project_id)
self.client.set_key(self.api_key)
# Initialize TablesDB service
self.tables_db = TablesDB(self.client)
logger.info("DatabaseService initialized", database_id=self.database_id)
def create_row(
self,
table_id: str,
data: Dict[str, Any],
row_id: Optional[str] = None,
permissions: Optional[List[str]] = None
) -> DatabaseRow:
"""
Create a new row in a table.
Args:
table_id: Table ID (e.g., "characters")
data: Row data (will be JSON-serialized if needed)
row_id: Optional custom row ID (auto-generated if None)
permissions: Optional permissions array
Returns:
DatabaseRow with created row
Raises:
AppwriteException: If creation fails
"""
try:
logger.info("Creating row", table_id=table_id, has_custom_id=bool(row_id))
# Generate ID if not provided
if row_id is None:
row_id = ID.unique()
# Create row (Appwrite manages timestamps automatically via $createdAt/$updatedAt)
result = self.tables_db.create_row(
database_id=self.database_id,
table_id=table_id,
row_id=row_id,
data=data,
permissions=permissions or []
)
logger.info("Row created successfully",
table_id=table_id,
row_id=result['$id'])
return self._parse_row(result, table_id)
except AppwriteException as e:
logger.error("Failed to create row",
table_id=table_id,
error=str(e),
code=e.code)
raise
def get_row(self, table_id: str, row_id: str) -> Optional[DatabaseRow]:
"""
Get a row by ID.
Args:
table_id: Table ID
row_id: Row ID
Returns:
DatabaseRow or None if not found
Raises:
AppwriteException: If retrieval fails (except 404)
"""
try:
logger.debug("Fetching row", table_id=table_id, row_id=row_id)
result = self.tables_db.get_row(
database_id=self.database_id,
table_id=table_id,
row_id=row_id
)
return self._parse_row(result, table_id)
except AppwriteException as e:
if e.code == 404:
logger.warning("Row not found",
table_id=table_id,
row_id=row_id)
return None
logger.error("Failed to fetch row",
table_id=table_id,
row_id=row_id,
error=str(e),
code=e.code)
raise
def update_row(
self,
table_id: str,
row_id: str,
data: Dict[str, Any],
permissions: Optional[List[str]] = None
) -> DatabaseRow:
"""
Update an existing row.
Args:
table_id: Table ID
row_id: Row ID
data: New row data (partial updates supported)
permissions: Optional permissions array
Returns:
DatabaseRow with updated row
Raises:
AppwriteException: If update fails
"""
try:
logger.info("Updating row", table_id=table_id, row_id=row_id)
# Update row (Appwrite manages timestamps automatically via $updatedAt)
result = self.tables_db.update_row(
database_id=self.database_id,
table_id=table_id,
row_id=row_id,
data=data,
permissions=permissions
)
logger.info("Row updated successfully",
table_id=table_id,
row_id=row_id)
return self._parse_row(result, table_id)
except AppwriteException as e:
logger.error("Failed to update row",
table_id=table_id,
row_id=row_id,
error=str(e),
code=e.code)
raise
def delete_row(self, table_id: str, row_id: str) -> bool:
"""
Delete a row.
Args:
table_id: Table ID
row_id: Row ID
Returns:
True if deletion successful
Raises:
AppwriteException: If deletion fails
"""
try:
logger.info("Deleting row", table_id=table_id, row_id=row_id)
self.tables_db.delete_row(
database_id=self.database_id,
table_id=table_id,
row_id=row_id
)
logger.info("Row deleted successfully",
table_id=table_id,
row_id=row_id)
return True
except AppwriteException as e:
logger.error("Failed to delete row",
table_id=table_id,
row_id=row_id,
error=str(e),
code=e.code)
raise
def list_rows(
self,
table_id: str,
queries: Optional[List[str]] = None,
limit: int = 25,
offset: int = 0
) -> List[DatabaseRow]:
"""
List rows in a table with optional filtering.
Args:
table_id: Table ID
queries: Optional Appwrite query filters
limit: Maximum rows to return (default 25, max 100)
offset: Number of rows to skip
Returns:
List of DatabaseRow instances
Raises:
AppwriteException: If query fails
"""
try:
logger.debug("Listing rows",
table_id=table_id,
has_queries=bool(queries),
limit=limit,
offset=offset)
result = self.tables_db.list_rows(
database_id=self.database_id,
table_id=table_id,
queries=queries or []
)
rows = [self._parse_row(row, table_id) for row in result['rows']]
logger.debug("Rows listed successfully",
table_id=table_id,
count=len(rows),
total=result.get('total', len(rows)))
return rows
except AppwriteException as e:
logger.error("Failed to list rows",
table_id=table_id,
error=str(e),
code=e.code)
raise
def count_rows(self, table_id: str, queries: Optional[List[str]] = None) -> int:
"""
Count rows in a table with optional filtering.
Args:
table_id: Table ID
queries: Optional Appwrite query filters
Returns:
Row count
Raises:
AppwriteException: If query fails
"""
try:
logger.debug("Counting rows", table_id=table_id, has_queries=bool(queries))
result = self.tables_db.list_rows(
database_id=self.database_id,
table_id=table_id,
queries=queries or []
)
count = result.get('total', len(result.get('rows', [])))
logger.debug("Rows counted", table_id=table_id, count=count)
return count
except AppwriteException as e:
logger.error("Failed to count rows",
table_id=table_id,
error=str(e),
code=e.code)
raise
def _parse_row(self, row: Dict[str, Any], table_id: str) -> DatabaseRow:
"""
Parse Appwrite row into DatabaseRow.
Args:
row: Appwrite row dictionary
table_id: Table ID
Returns:
DatabaseRow instance
"""
# Extract metadata
row_id = row['$id']
created_at = row.get('$createdAt', datetime.now(timezone.utc).isoformat())
updated_at = row.get('$updatedAt', datetime.now(timezone.utc).isoformat())
# Parse timestamps
if isinstance(created_at, str):
created_at = datetime.fromisoformat(created_at.replace('Z', '+00:00'))
if isinstance(updated_at, str):
updated_at = datetime.fromisoformat(updated_at.replace('Z', '+00:00'))
# Remove Appwrite metadata from data
data = {k: v for k, v in row.items() if not k.startswith('$')}
return DatabaseRow(
id=row_id,
table_id=table_id,
data=data,
created_at=created_at,
updated_at=updated_at
)
# Backward compatibility aliases (deprecated, use new methods)
def create_document(self, collection_id: str, data: Dict[str, Any],
document_id: Optional[str] = None,
permissions: Optional[List[str]] = None) -> DatabaseRow:
"""Deprecated: Use create_row() instead."""
logger.warning("create_document() is deprecated, use create_row() instead")
return self.create_row(collection_id, data, document_id, permissions)
def get_document(self, collection_id: str, document_id: str) -> Optional[DatabaseRow]:
"""Deprecated: Use get_row() instead."""
logger.warning("get_document() is deprecated, use get_row() instead")
return self.get_row(collection_id, document_id)
def update_document(self, collection_id: str, document_id: str,
data: Dict[str, Any],
permissions: Optional[List[str]] = None) -> DatabaseRow:
"""Deprecated: Use update_row() instead."""
logger.warning("update_document() is deprecated, use update_row() instead")
return self.update_row(collection_id, document_id, data, permissions)
def delete_document(self, collection_id: str, document_id: str) -> bool:
"""Deprecated: Use delete_row() instead."""
logger.warning("delete_document() is deprecated, use delete_row() instead")
return self.delete_row(collection_id, document_id)
def list_documents(self, collection_id: str, queries: Optional[List[str]] = None,
limit: int = 25, offset: int = 0) -> List[DatabaseRow]:
"""Deprecated: Use list_rows() instead."""
logger.warning("list_documents() is deprecated, use list_rows() instead")
return self.list_rows(collection_id, queries, limit, offset)
def count_documents(self, collection_id: str, queries: Optional[List[str]] = None) -> int:
"""Deprecated: Use count_rows() instead."""
logger.warning("count_documents() is deprecated, use count_rows() instead")
return self.count_rows(collection_id, queries)
# Backward compatibility alias
DatabaseDocument = DatabaseRow
# Global instance for convenience
_service_instance: Optional[DatabaseService] = None
def get_database_service() -> DatabaseService:
"""
Get the global DatabaseService instance.
Returns:
Singleton DatabaseService instance
"""
global _service_instance
if _service_instance is None:
_service_instance = DatabaseService()
return _service_instance

View File

@@ -0,0 +1,351 @@
"""
Item validation service for AI-granted items.
This module validates and resolves items that the AI grants to players during
gameplay, ensuring they meet character requirements and game balance rules.
"""
import uuid
from pathlib import Path
from typing import Optional
import structlog
import yaml
from app.models.items import Item
from app.models.enums import ItemType
from app.models.character import Character
from app.ai.response_parser import ItemGrant
logger = structlog.get_logger(__name__)
class ItemValidationError(Exception):
"""
Exception raised when an item fails validation.
Attributes:
message: Human-readable error message
item_grant: The ItemGrant that failed validation
reason: Machine-readable reason code
"""
def __init__(self, message: str, item_grant: ItemGrant, reason: str):
super().__init__(message)
self.message = message
self.item_grant = item_grant
self.reason = reason
class ItemValidator:
"""
Validates and resolves items granted by the AI.
This service:
1. Resolves item references (by ID or creates generic items)
2. Validates items against character requirements
3. Logs validation failures for review
"""
# Map of generic item type strings to ItemType enums
TYPE_MAP = {
"weapon": ItemType.WEAPON,
"armor": ItemType.ARMOR,
"consumable": ItemType.CONSUMABLE,
"quest_item": ItemType.QUEST_ITEM,
}
def __init__(self, data_path: Optional[Path] = None):
"""
Initialize the item validator.
Args:
data_path: Path to game data directory. Defaults to app/data/
"""
if data_path is None:
# Default to api/app/data/
data_path = Path(__file__).parent.parent / "data"
self.data_path = data_path
self._item_registry: dict[str, dict] = {}
self._generic_templates: dict[str, dict] = {}
self._load_data()
logger.info(
"ItemValidator initialized",
items_loaded=len(self._item_registry),
generic_templates_loaded=len(self._generic_templates)
)
def _load_data(self) -> None:
"""Load item data from YAML files."""
# Load main item registry if it exists
items_file = self.data_path / "items.yaml"
if items_file.exists():
with open(items_file) as f:
data = yaml.safe_load(f) or {}
self._item_registry = data.get("items", {})
# Load generic item templates
generic_file = self.data_path / "generic_items.yaml"
if generic_file.exists():
with open(generic_file) as f:
data = yaml.safe_load(f) or {}
self._generic_templates = data.get("templates", {})
def resolve_item(self, item_grant: ItemGrant) -> Item:
"""
Resolve an ItemGrant to an actual Item instance.
For existing items (by item_id), looks up from item registry.
For generic items (by name/type), creates a new Item.
Args:
item_grant: The ItemGrant from AI response
Returns:
Resolved Item instance
Raises:
ItemValidationError: If item cannot be resolved
"""
if item_grant.is_existing_item():
return self._resolve_existing_item(item_grant)
elif item_grant.is_generic_item():
return self._create_generic_item(item_grant)
else:
raise ItemValidationError(
"ItemGrant has neither item_id nor name",
item_grant,
"INVALID_ITEM_GRANT"
)
def _resolve_existing_item(self, item_grant: ItemGrant) -> Item:
"""
Look up an existing item by ID.
Args:
item_grant: ItemGrant with item_id set
Returns:
Item instance from registry
Raises:
ItemValidationError: If item not found
"""
item_id = item_grant.item_id
if item_id not in self._item_registry:
logger.warning(
"Item not found in registry",
item_id=item_id
)
raise ItemValidationError(
f"Unknown item_id: {item_id}",
item_grant,
"ITEM_NOT_FOUND"
)
item_data = self._item_registry[item_id]
# Convert to Item instance
return Item.from_dict({
"item_id": item_id,
**item_data
})
def _create_generic_item(self, item_grant: ItemGrant) -> Item:
"""
Create a generic item from AI-provided details.
Generic items are simple items with no special stats,
suitable for mundane objects like torches, food, etc.
Args:
item_grant: ItemGrant with name, type, description
Returns:
New Item instance
Raises:
ItemValidationError: If item type is invalid
"""
# Validate item type
item_type_str = (item_grant.item_type or "consumable").lower()
if item_type_str not in self.TYPE_MAP:
logger.warning(
"Invalid item type from AI",
item_type=item_type_str,
item_name=item_grant.name
)
# Default to consumable for unknown types
item_type_str = "consumable"
item_type = self.TYPE_MAP[item_type_str]
# Generate unique ID for this item instance
item_id = f"generic_{uuid.uuid4().hex[:8]}"
# Check if we have a template for this item name
template = self._find_template(item_grant.name or "")
if template:
# Use template values as defaults
return Item(
item_id=item_id,
name=item_grant.name or template.get("name", "Unknown Item"),
item_type=item_type,
description=item_grant.description or template.get("description", ""),
value=item_grant.value or template.get("value", 0),
is_tradeable=template.get("is_tradeable", True),
required_level=template.get("required_level", 1),
)
else:
# Create with provided values only
return Item(
item_id=item_id,
name=item_grant.name or "Unknown Item",
item_type=item_type,
description=item_grant.description or "A simple item.",
value=item_grant.value,
is_tradeable=True,
required_level=1,
)
def _find_template(self, item_name: str) -> Optional[dict]:
"""
Find a generic item template by name.
Uses case-insensitive partial matching.
Args:
item_name: Name of the item to find
Returns:
Template dict or None if not found
"""
name_lower = item_name.lower()
# Exact match first
if name_lower in self._generic_templates:
return self._generic_templates[name_lower]
# Partial match
for template_name, template in self._generic_templates.items():
if template_name in name_lower or name_lower in template_name:
return template
return None
def validate_item_for_character(
self,
item: Item,
character: Character
) -> tuple[bool, Optional[str]]:
"""
Validate that a character can receive an item.
Checks:
- Level requirements
- Class restrictions
Args:
item: The Item to validate
character: The Character to receive the item
Returns:
Tuple of (is_valid, error_message)
"""
# Check level requirement
if item.required_level > character.level:
error_msg = (
f"Item '{item.name}' requires level {item.required_level}, "
f"but character is level {character.level}"
)
logger.warning(
"Item validation failed: level requirement",
item_name=item.name,
required_level=item.required_level,
character_level=character.level,
character_name=character.name
)
return False, error_msg
# Check class restriction
if item.required_class:
character_class = character.player_class.class_id
if item.required_class.lower() != character_class.lower():
error_msg = (
f"Item '{item.name}' requires class {item.required_class}, "
f"but character is {character_class}"
)
logger.warning(
"Item validation failed: class restriction",
item_name=item.name,
required_class=item.required_class,
character_class=character_class,
character_name=character.name
)
return False, error_msg
return True, None
def validate_and_resolve_item(
self,
item_grant: ItemGrant,
character: Character
) -> tuple[Optional[Item], Optional[str]]:
"""
Resolve an item grant and validate it for a character.
This is the main entry point for processing AI-granted items.
Args:
item_grant: The ItemGrant from AI response
character: The Character to receive the item
Returns:
Tuple of (Item if valid else None, error_message if invalid else None)
"""
try:
# Resolve the item
item = self.resolve_item(item_grant)
# Validate for character
is_valid, error_msg = self.validate_item_for_character(item, character)
if not is_valid:
return None, error_msg
logger.info(
"Item validated successfully",
item_name=item.name,
item_id=item.item_id,
character_name=character.name
)
return item, None
except ItemValidationError as e:
logger.warning(
"Item resolution failed",
error=e.message,
reason=e.reason
)
return None, e.message
# Global instance for convenience
_validator_instance: Optional[ItemValidator] = None
def get_item_validator() -> ItemValidator:
"""
Get or create the global ItemValidator instance.
Returns:
ItemValidator singleton instance
"""
global _validator_instance
if _validator_instance is None:
_validator_instance = ItemValidator()
return _validator_instance

View File

@@ -0,0 +1,326 @@
"""
LocationLoader service for loading location definitions from YAML files.
This service reads location configuration files and converts them into Location
dataclass instances, providing caching for performance. Locations are organized
by region subdirectories.
"""
import yaml
from pathlib import Path
from typing import Dict, List, Optional
import structlog
from app.models.location import Location, Region
from app.models.enums import LocationType
logger = structlog.get_logger(__name__)
class LocationLoader:
"""
Loads location definitions from YAML configuration files.
Locations are organized in region subdirectories:
/app/data/locations/
regions/
crossville.yaml
crossville/
crossville_village.yaml
crossville_tavern.yaml
This allows game designers to define world locations without touching code.
"""
def __init__(self, data_dir: Optional[str] = None):
"""
Initialize the location loader.
Args:
data_dir: Path to directory containing location YAML files.
Defaults to /app/data/locations/
"""
if data_dir is None:
# Default to app/data/locations relative to this file
current_file = Path(__file__)
app_dir = current_file.parent.parent # Go up to /app
data_dir = str(app_dir / "data" / "locations")
self.data_dir = Path(data_dir)
self._location_cache: Dict[str, Location] = {}
self._region_cache: Dict[str, Region] = {}
logger.info("LocationLoader initialized", data_dir=str(self.data_dir))
def load_location(self, location_id: str) -> Optional[Location]:
"""
Load a single location by ID.
Searches all region subdirectories for the location file.
Args:
location_id: Unique location identifier (e.g., "crossville_tavern")
Returns:
Location instance or None if not found
"""
# Check cache first
if location_id in self._location_cache:
logger.debug("Location loaded from cache", location_id=location_id)
return self._location_cache[location_id]
# Search in region subdirectories
if not self.data_dir.exists():
logger.error("Location data directory does not exist", data_dir=str(self.data_dir))
return None
for region_dir in self.data_dir.iterdir():
# Skip non-directories and the regions folder
if not region_dir.is_dir() or region_dir.name == "regions":
continue
file_path = region_dir / f"{location_id}.yaml"
if file_path.exists():
return self._load_location_file(file_path)
logger.warning("Location not found", location_id=location_id)
return None
def _load_location_file(self, file_path: Path) -> Optional[Location]:
"""
Load a location from a specific file.
Args:
file_path: Path to the YAML file
Returns:
Location instance or None if loading fails
"""
try:
with open(file_path, 'r') as f:
data = yaml.safe_load(f)
location = self._parse_location_data(data)
self._location_cache[location.location_id] = location
logger.info("Location loaded successfully", location_id=location.location_id)
return location
except Exception as e:
logger.error("Failed to load location", file=str(file_path), error=str(e))
return None
def _parse_location_data(self, data: Dict) -> Location:
"""
Parse YAML data into a Location dataclass.
Args:
data: Dictionary loaded from YAML file
Returns:
Location instance
Raises:
ValueError: If data is invalid or missing required fields
"""
# Validate required fields
required_fields = ["location_id", "name", "region_id", "description"]
for field in required_fields:
if field not in data:
raise ValueError(f"Missing required field: {field}")
# Parse location type - default to town
location_type_str = data.get("location_type", "town")
try:
location_type = LocationType(location_type_str)
except ValueError:
logger.warning(
"Invalid location type, defaulting to town",
location_id=data["location_id"],
invalid_type=location_type_str
)
location_type = LocationType.TOWN
return Location(
location_id=data["location_id"],
name=data["name"],
location_type=location_type,
region_id=data["region_id"],
description=data["description"],
lore=data.get("lore"),
ambient_description=data.get("ambient_description"),
available_quests=data.get("available_quests", []),
npc_ids=data.get("npc_ids", []),
discoverable_locations=data.get("discoverable_locations", []),
is_starting_location=data.get("is_starting_location", False),
tags=data.get("tags", []),
)
def load_all_locations(self) -> List[Location]:
"""
Load all locations from all region directories.
Returns:
List of Location instances
"""
locations = []
if not self.data_dir.exists():
logger.error("Location data directory does not exist", data_dir=str(self.data_dir))
return locations
for region_dir in self.data_dir.iterdir():
# Skip non-directories and the regions folder
if not region_dir.is_dir() or region_dir.name == "regions":
continue
for file_path in region_dir.glob("*.yaml"):
location = self._load_location_file(file_path)
if location:
locations.append(location)
logger.info("All locations loaded", count=len(locations))
return locations
def load_region(self, region_id: str) -> Optional[Region]:
"""
Load a region definition.
Args:
region_id: Unique region identifier (e.g., "crossville")
Returns:
Region instance or None if not found
"""
# Check cache first
if region_id in self._region_cache:
logger.debug("Region loaded from cache", region_id=region_id)
return self._region_cache[region_id]
file_path = self.data_dir / "regions" / f"{region_id}.yaml"
if not file_path.exists():
logger.warning("Region file not found", region_id=region_id)
return None
try:
with open(file_path, 'r') as f:
data = yaml.safe_load(f)
region = Region.from_dict(data)
self._region_cache[region_id] = region
logger.info("Region loaded successfully", region_id=region_id)
return region
except Exception as e:
logger.error("Failed to load region", region_id=region_id, error=str(e))
return None
def get_locations_in_region(self, region_id: str) -> List[Location]:
"""
Get all locations belonging to a specific region.
Args:
region_id: Region identifier
Returns:
List of Location instances in this region
"""
# Load all locations if cache is empty
if not self._location_cache:
self.load_all_locations()
return [
loc for loc in self._location_cache.values()
if loc.region_id == region_id
]
def get_starting_locations(self) -> List[Location]:
"""
Get all locations that can be starting points.
Returns:
List of Location instances marked as starting locations
"""
# Load all locations if cache is empty
if not self._location_cache:
self.load_all_locations()
return [
loc for loc in self._location_cache.values()
if loc.is_starting_location
]
def get_location_by_type(self, location_type: LocationType) -> List[Location]:
"""
Get all locations of a specific type.
Args:
location_type: Type to filter by
Returns:
List of Location instances of this type
"""
# Load all locations if cache is empty
if not self._location_cache:
self.load_all_locations()
return [
loc for loc in self._location_cache.values()
if loc.location_type == location_type
]
def get_all_location_ids(self) -> List[str]:
"""
Get a list of all available location IDs.
Returns:
List of location IDs
"""
# Load all locations if cache is empty
if not self._location_cache:
self.load_all_locations()
return list(self._location_cache.keys())
def reload_location(self, location_id: str) -> Optional[Location]:
"""
Force reload a location from disk, bypassing cache.
Useful for development/testing when location definitions change.
Args:
location_id: Unique location identifier
Returns:
Location instance or None if not found
"""
# Remove from cache if present
if location_id in self._location_cache:
del self._location_cache[location_id]
return self.load_location(location_id)
def clear_cache(self) -> None:
"""Clear all cached data. Useful for testing."""
self._location_cache.clear()
self._region_cache.clear()
logger.info("Location cache cleared")
# Global singleton instance
_loader_instance: Optional[LocationLoader] = None
def get_location_loader() -> LocationLoader:
"""
Get the global LocationLoader instance.
Returns:
Singleton LocationLoader instance
"""
global _loader_instance
if _loader_instance is None:
_loader_instance = LocationLoader()
return _loader_instance

View File

@@ -0,0 +1,385 @@
"""
NPCLoader service for loading NPC definitions from YAML files.
This service reads NPC configuration files and converts them into NPC
dataclass instances, providing caching for performance. NPCs are organized
by region subdirectories.
"""
import yaml
from pathlib import Path
from typing import Dict, List, Optional
import structlog
from app.models.npc import (
NPC,
NPCPersonality,
NPCAppearance,
NPCKnowledge,
NPCKnowledgeCondition,
NPCRelationship,
NPCInventoryItem,
NPCDialogueHooks,
)
logger = structlog.get_logger(__name__)
class NPCLoader:
"""
Loads NPC definitions from YAML configuration files.
NPCs are organized in region subdirectories:
/app/data/npcs/
crossville/
npc_grom_001.yaml
npc_mira_001.yaml
This allows game designers to define NPCs without touching code.
"""
def __init__(self, data_dir: Optional[str] = None):
"""
Initialize the NPC loader.
Args:
data_dir: Path to directory containing NPC YAML files.
Defaults to /app/data/npcs/
"""
if data_dir is None:
# Default to app/data/npcs relative to this file
current_file = Path(__file__)
app_dir = current_file.parent.parent # Go up to /app
data_dir = str(app_dir / "data" / "npcs")
self.data_dir = Path(data_dir)
self._npc_cache: Dict[str, NPC] = {}
self._location_npc_cache: Dict[str, List[str]] = {}
logger.info("NPCLoader initialized", data_dir=str(self.data_dir))
def load_npc(self, npc_id: str) -> Optional[NPC]:
"""
Load a single NPC by ID.
Searches all region subdirectories for the NPC file.
Args:
npc_id: Unique NPC identifier (e.g., "npc_grom_001")
Returns:
NPC instance or None if not found
"""
# Check cache first
if npc_id in self._npc_cache:
logger.debug("NPC loaded from cache", npc_id=npc_id)
return self._npc_cache[npc_id]
# Search in region subdirectories
if not self.data_dir.exists():
logger.error("NPC data directory does not exist", data_dir=str(self.data_dir))
return None
for region_dir in self.data_dir.iterdir():
if not region_dir.is_dir():
continue
file_path = region_dir / f"{npc_id}.yaml"
if file_path.exists():
return self._load_npc_file(file_path)
logger.warning("NPC not found", npc_id=npc_id)
return None
def _load_npc_file(self, file_path: Path) -> Optional[NPC]:
"""
Load an NPC from a specific file.
Args:
file_path: Path to the YAML file
Returns:
NPC instance or None if loading fails
"""
try:
with open(file_path, 'r') as f:
data = yaml.safe_load(f)
npc = self._parse_npc_data(data)
self._npc_cache[npc.npc_id] = npc
# Update location cache
if npc.location_id not in self._location_npc_cache:
self._location_npc_cache[npc.location_id] = []
if npc.npc_id not in self._location_npc_cache[npc.location_id]:
self._location_npc_cache[npc.location_id].append(npc.npc_id)
logger.info("NPC loaded successfully", npc_id=npc.npc_id)
return npc
except Exception as e:
logger.error("Failed to load NPC", file=str(file_path), error=str(e))
return None
def _parse_npc_data(self, data: Dict) -> NPC:
"""
Parse YAML data into an NPC dataclass.
Args:
data: Dictionary loaded from YAML file
Returns:
NPC instance
Raises:
ValueError: If data is invalid or missing required fields
"""
# Validate required fields
required_fields = ["npc_id", "name", "role", "location_id"]
for field in required_fields:
if field not in data:
raise ValueError(f"Missing required field: {field}")
# Parse personality
personality_data = data.get("personality", {})
personality = NPCPersonality(
traits=personality_data.get("traits", []),
speech_style=personality_data.get("speech_style", ""),
quirks=personality_data.get("quirks", []),
)
# Parse appearance
appearance_data = data.get("appearance", {})
if isinstance(appearance_data, str):
appearance = NPCAppearance(brief=appearance_data)
else:
appearance = NPCAppearance(
brief=appearance_data.get("brief", ""),
detailed=appearance_data.get("detailed"),
)
# Parse knowledge (optional)
knowledge = None
if data.get("knowledge"):
knowledge_data = data["knowledge"]
conditions = [
NPCKnowledgeCondition(
condition=c.get("condition", ""),
reveals=c.get("reveals", ""),
)
for c in knowledge_data.get("will_share_if", [])
]
knowledge = NPCKnowledge(
public=knowledge_data.get("public", []),
secret=knowledge_data.get("secret", []),
will_share_if=conditions,
)
# Parse relationships
relationships = [
NPCRelationship(
npc_id=r["npc_id"],
attitude=r["attitude"],
reason=r.get("reason"),
)
for r in data.get("relationships", [])
]
# Parse inventory
inventory = []
for item_data in data.get("inventory_for_sale", []):
# Handle shorthand format: { item: "ale", price: 2 }
item_id = item_data.get("item_id") or item_data.get("item", "")
inventory.append(NPCInventoryItem(
item_id=item_id,
price=item_data.get("price", 0),
quantity=item_data.get("quantity"),
))
# Parse dialogue hooks (optional)
dialogue_hooks = None
if data.get("dialogue_hooks"):
hooks_data = data["dialogue_hooks"]
dialogue_hooks = NPCDialogueHooks(
greeting=hooks_data.get("greeting"),
farewell=hooks_data.get("farewell"),
busy=hooks_data.get("busy"),
quest_complete=hooks_data.get("quest_complete"),
)
return NPC(
npc_id=data["npc_id"],
name=data["name"],
role=data["role"],
location_id=data["location_id"],
personality=personality,
appearance=appearance,
knowledge=knowledge,
relationships=relationships,
inventory_for_sale=inventory,
dialogue_hooks=dialogue_hooks,
quest_giver_for=data.get("quest_giver_for", []),
reveals_locations=data.get("reveals_locations", []),
tags=data.get("tags", []),
)
def load_all_npcs(self) -> List[NPC]:
"""
Load all NPCs from all region directories.
Returns:
List of NPC instances
"""
npcs = []
if not self.data_dir.exists():
logger.error("NPC data directory does not exist", data_dir=str(self.data_dir))
return npcs
for region_dir in self.data_dir.iterdir():
if not region_dir.is_dir():
continue
for file_path in region_dir.glob("*.yaml"):
npc = self._load_npc_file(file_path)
if npc:
npcs.append(npc)
logger.info("All NPCs loaded", count=len(npcs))
return npcs
def get_npcs_at_location(self, location_id: str) -> List[NPC]:
"""
Get all NPCs at a specific location.
Args:
location_id: Location identifier
Returns:
List of NPC instances at this location
"""
# Ensure all NPCs are loaded
if not self._npc_cache:
self.load_all_npcs()
npc_ids = self._location_npc_cache.get(location_id, [])
return [
self._npc_cache[npc_id]
for npc_id in npc_ids
if npc_id in self._npc_cache
]
def get_npc_ids_at_location(self, location_id: str) -> List[str]:
"""
Get NPC IDs at a specific location.
Args:
location_id: Location identifier
Returns:
List of NPC IDs at this location
"""
# Ensure all NPCs are loaded
if not self._npc_cache:
self.load_all_npcs()
return self._location_npc_cache.get(location_id, [])
def get_npcs_by_tag(self, tag: str) -> List[NPC]:
"""
Get all NPCs with a specific tag.
Args:
tag: Tag to filter by (e.g., "merchant", "quest_giver")
Returns:
List of NPC instances with this tag
"""
# Ensure all NPCs are loaded
if not self._npc_cache:
self.load_all_npcs()
return [
npc for npc in self._npc_cache.values()
if tag in npc.tags
]
def get_quest_givers(self, quest_id: str) -> List[NPC]:
"""
Get all NPCs that can give a specific quest.
Args:
quest_id: Quest identifier
Returns:
List of NPC instances that give this quest
"""
# Ensure all NPCs are loaded
if not self._npc_cache:
self.load_all_npcs()
return [
npc for npc in self._npc_cache.values()
if quest_id in npc.quest_giver_for
]
def get_all_npc_ids(self) -> List[str]:
"""
Get a list of all available NPC IDs.
Returns:
List of NPC IDs
"""
# Ensure all NPCs are loaded
if not self._npc_cache:
self.load_all_npcs()
return list(self._npc_cache.keys())
def reload_npc(self, npc_id: str) -> Optional[NPC]:
"""
Force reload an NPC from disk, bypassing cache.
Useful for development/testing when NPC definitions change.
Args:
npc_id: Unique NPC identifier
Returns:
NPC instance or None if not found
"""
# Remove from caches if present
if npc_id in self._npc_cache:
old_npc = self._npc_cache[npc_id]
# Remove from location cache
if old_npc.location_id in self._location_npc_cache:
self._location_npc_cache[old_npc.location_id] = [
n for n in self._location_npc_cache[old_npc.location_id]
if n != npc_id
]
del self._npc_cache[npc_id]
return self.load_npc(npc_id)
def clear_cache(self) -> None:
"""Clear all cached data. Useful for testing."""
self._npc_cache.clear()
self._location_npc_cache.clear()
logger.info("NPC cache cleared")
# Global singleton instance
_loader_instance: Optional[NPCLoader] = None
def get_npc_loader() -> NPCLoader:
"""
Get the global NPCLoader instance.
Returns:
Singleton NPCLoader instance
"""
global _loader_instance
if _loader_instance is None:
_loader_instance = NPCLoader()
return _loader_instance

View File

@@ -0,0 +1,236 @@
"""
OriginService for loading character origin definitions from YAML files.
This service reads origin configuration and converts it into Origin
dataclass instances, providing caching for performance.
"""
import yaml
from pathlib import Path
from typing import Dict, List, Optional
import structlog
from app.models.origins import Origin, StartingLocation, StartingBonus
logger = structlog.get_logger(__name__)
class OriginService:
"""
Loads character origin definitions from YAML configuration.
Origins define character backstories, starting locations, and narrative
hooks that the AI DM uses to create personalized gameplay experiences.
All origin definitions are stored in /app/data/origins.yaml.
"""
def __init__(self, data_file: Optional[str] = None):
"""
Initialize the origin service.
Args:
data_file: Path to origins YAML file.
Defaults to /app/data/origins.yaml
"""
if data_file is None:
# Default to app/data/origins.yaml relative to this file
current_file = Path(__file__)
app_dir = current_file.parent.parent # Go up to /app
data_file = str(app_dir / "data" / "origins.yaml")
self.data_file = Path(data_file)
self._origins_cache: Dict[str, Origin] = {}
self._all_origins_loaded = False
logger.info("OriginService initialized", data_file=str(self.data_file))
def load_origin(self, origin_id: str) -> Optional[Origin]:
"""
Load a single origin by ID.
Args:
origin_id: Unique origin identifier (e.g., "soul_revenant")
Returns:
Origin instance or None if not found
"""
# Check cache first
if origin_id in self._origins_cache:
logger.debug("Origin loaded from cache", origin_id=origin_id)
return self._origins_cache[origin_id]
# Load all origins if not already loaded
if not self._all_origins_loaded:
self._load_all_origins()
# Return from cache after loading
origin = self._origins_cache.get(origin_id)
if origin:
logger.info("Origin loaded successfully", origin_id=origin_id)
else:
logger.warning("Origin not found", origin_id=origin_id)
return origin
def load_all_origins(self) -> List[Origin]:
"""
Load all origins from the data file.
Returns:
List of Origin instances
"""
if self._all_origins_loaded and self._origins_cache:
logger.debug("All origins loaded from cache")
return list(self._origins_cache.values())
return self._load_all_origins()
def _load_all_origins(self) -> List[Origin]:
"""
Internal method to load all origins from YAML.
Returns:
List of Origin instances
"""
if not self.data_file.exists():
logger.error("Origins data file does not exist", data_file=str(self.data_file))
return []
try:
# Load YAML file
with open(self.data_file, 'r') as f:
data = yaml.safe_load(f)
origins_data = data.get("origins", {})
origins = []
# Parse each origin
for origin_id, origin_data in origins_data.items():
try:
origin = self._parse_origin_data(origin_id, origin_data)
self._origins_cache[origin_id] = origin
origins.append(origin)
except Exception as e:
logger.error("Failed to parse origin", origin_id=origin_id, error=str(e))
continue
self._all_origins_loaded = True
logger.info("All origins loaded successfully", count=len(origins))
return origins
except Exception as e:
logger.error("Failed to load origins file", error=str(e))
return []
def get_origin_by_id(self, origin_id: str) -> Optional[Origin]:
"""
Get an origin by ID (alias for load_origin).
Args:
origin_id: Unique origin identifier
Returns:
Origin instance or None if not found
"""
return self.load_origin(origin_id)
def get_all_origin_ids(self) -> List[str]:
"""
Get a list of all available origin IDs.
Returns:
List of origin IDs (e.g., ["soul_revenant", "memory_thief"])
"""
if not self._all_origins_loaded:
self._load_all_origins()
return list(self._origins_cache.keys())
def reload_origins(self) -> List[Origin]:
"""
Force reload all origins from disk, bypassing cache.
Useful for development/testing when origin definitions change.
Returns:
List of Origin instances
"""
self.clear_cache()
return self._load_all_origins()
def clear_cache(self):
"""Clear the origins cache. Useful for testing."""
self._origins_cache.clear()
self._all_origins_loaded = False
logger.info("Origins cache cleared")
def _parse_origin_data(self, origin_id: str, data: Dict) -> Origin:
"""
Parse YAML data into an Origin dataclass.
Args:
origin_id: The origin's unique identifier
data: Dictionary loaded from YAML file
Returns:
Origin instance
Raises:
ValueError: If data is invalid or missing required fields
"""
# Validate required fields
required_fields = ["name", "description", "starting_location"]
for field in required_fields:
if field not in data:
raise ValueError(f"Missing required field in origin '{origin_id}': {field}")
# Parse starting location
location_data = data["starting_location"]
starting_location = StartingLocation(
id=location_data.get("id", ""),
name=location_data.get("name", ""),
region=location_data.get("region", ""),
description=location_data.get("description", "")
)
# Parse starting bonus (optional)
starting_bonus = None
if "starting_bonus" in data:
bonus_data = data["starting_bonus"]
starting_bonus = StartingBonus(
trait=bonus_data.get("trait", ""),
description=bonus_data.get("description", ""),
effect=bonus_data.get("effect", "")
)
# Parse narrative hooks (optional)
narrative_hooks = data.get("narrative_hooks", [])
# Create Origin instance
origin = Origin(
id=data.get("id", origin_id), # Use provided ID or fall back to key
name=data["name"],
description=data["description"],
starting_location=starting_location,
narrative_hooks=narrative_hooks,
starting_bonus=starting_bonus
)
return origin
# Global instance for convenience
_service_instance: Optional[OriginService] = None
def get_origin_service() -> OriginService:
"""
Get the global OriginService instance.
Returns:
Singleton OriginService instance
"""
global _service_instance
if _service_instance is None:
_service_instance = OriginService()
return _service_instance

View File

@@ -0,0 +1,373 @@
"""
Outcome determination service for Code of Conquest.
This service handles all code-determined game outcomes before they're passed
to AI for narration. It uses the dice mechanics system to determine success/failure
and selects appropriate rewards from loot tables.
"""
import random
import yaml
import structlog
from pathlib import Path
from dataclasses import dataclass
from typing import Optional, List, Dict, Any
from app.models.character import Character
from app.game_logic.dice import (
CheckResult, SkillType, Difficulty,
skill_check, get_stat_for_skill, perception_check
)
logger = structlog.get_logger(__name__)
@dataclass
class ItemFound:
"""
Represents an item found during a search.
Uses template key from generic_items.yaml.
"""
template_key: str
name: str
description: str
value: int
def to_dict(self) -> dict:
"""Serialize for API response."""
return {
"template_key": self.template_key,
"name": self.name,
"description": self.description,
"value": self.value,
}
@dataclass
class SearchOutcome:
"""
Complete result of a search action.
Includes the dice check result and any items/gold found.
"""
check_result: CheckResult
items_found: List[ItemFound]
gold_found: int
def to_dict(self) -> dict:
"""Serialize for API response."""
return {
"check_result": self.check_result.to_dict(),
"items_found": [item.to_dict() for item in self.items_found],
"gold_found": self.gold_found,
}
@dataclass
class SkillCheckOutcome:
"""
Result of a generic skill check.
Used for persuasion, lockpicking, stealth, etc.
"""
check_result: CheckResult
context: Dict[str, Any] # Additional context for AI narration
def to_dict(self) -> dict:
"""Serialize for API response."""
return {
"check_result": self.check_result.to_dict(),
"context": self.context,
}
class OutcomeService:
"""
Service for determining game action outcomes.
Handles all dice rolls and loot selection before passing to AI.
"""
def __init__(self):
"""Initialize the outcome service with loot tables and item templates."""
self._loot_tables: Dict[str, Any] = {}
self._item_templates: Dict[str, Any] = {}
self._load_data()
def _load_data(self) -> None:
"""Load loot tables and item templates from YAML files."""
data_dir = Path(__file__).parent.parent / "data"
# Load loot tables
loot_path = data_dir / "loot_tables.yaml"
if loot_path.exists():
with open(loot_path, "r") as f:
self._loot_tables = yaml.safe_load(f)
logger.info("loaded_loot_tables", count=len(self._loot_tables))
else:
logger.warning("loot_tables_not_found", path=str(loot_path))
# Load generic item templates
items_path = data_dir / "generic_items.yaml"
if items_path.exists():
with open(items_path, "r") as f:
data = yaml.safe_load(f)
self._item_templates = data.get("templates", {})
logger.info("loaded_item_templates", count=len(self._item_templates))
else:
logger.warning("item_templates_not_found", path=str(items_path))
def determine_search_outcome(
self,
character: Character,
location_type: str,
dc: int = 12,
bonus: int = 0
) -> SearchOutcome:
"""
Determine the outcome of a search action.
Uses a perception check to determine success, then selects items
from the appropriate loot table based on the roll margin.
Args:
character: The character performing the search
location_type: Type of location (forest, cave, town, etc.)
dc: Difficulty class (default 12 = easy-medium)
bonus: Additional bonus to the check
Returns:
SearchOutcome with check result, items found, and gold found
"""
# Get character's effective wisdom for perception
effective_stats = character.get_effective_stats()
wisdom = effective_stats.wisdom
# Perform the perception check
check_result = perception_check(wisdom, dc, bonus)
# Determine loot based on result
items_found: List[ItemFound] = []
gold_found: int = 0
if check_result.success:
# Get loot table for this location (fall back to default)
loot_table = self._loot_tables.get(
location_type.lower(),
self._loot_tables.get("default", {})
)
# Select item rarity based on margin
if check_result.margin >= 10:
rarity = "rare"
elif check_result.margin >= 5:
rarity = "uncommon"
else:
rarity = "common"
# Get items for this rarity
item_keys = loot_table.get(rarity, [])
if item_keys:
# Select 1-2 items based on margin
num_items = 1 if check_result.margin < 8 else 2
selected_keys = random.sample(
item_keys,
min(num_items, len(item_keys))
)
for key in selected_keys:
template = self._item_templates.get(key)
if template:
items_found.append(ItemFound(
template_key=key,
name=template.get("name", key.title()),
description=template.get("description", ""),
value=template.get("value", 1),
))
# Calculate gold found
gold_config = loot_table.get("gold", {})
if gold_config:
min_gold = gold_config.get("min", 0)
max_gold = gold_config.get("max", 10)
bonus_per_margin = gold_config.get("bonus_per_margin", 0)
base_gold = random.randint(min_gold, max_gold)
margin_bonus = check_result.margin * bonus_per_margin
gold_found = base_gold + margin_bonus
logger.info(
"search_outcome_determined",
character_id=character.character_id,
location_type=location_type,
dc=dc,
success=check_result.success,
margin=check_result.margin,
items_count=len(items_found),
gold_found=gold_found
)
return SearchOutcome(
check_result=check_result,
items_found=items_found,
gold_found=gold_found
)
def determine_skill_check_outcome(
self,
character: Character,
skill_type: SkillType,
dc: int,
bonus: int = 0,
context: Optional[Dict[str, Any]] = None
) -> SkillCheckOutcome:
"""
Determine the outcome of a generic skill check.
Args:
character: The character performing the check
skill_type: The type of skill check (PERSUASION, STEALTH, etc.)
dc: Difficulty class to beat
bonus: Additional bonus to the check
context: Optional context for AI narration (e.g., NPC name, door type)
Returns:
SkillCheckOutcome with check result and context
"""
# Get the appropriate stat for this skill
stat_name = get_stat_for_skill(skill_type)
effective_stats = character.get_effective_stats()
stat_value = getattr(effective_stats, stat_name, 10)
# Perform the check
check_result = skill_check(stat_value, dc, skill_type, bonus)
# Build outcome context
outcome_context = context or {}
outcome_context["skill_used"] = skill_type.name.lower()
outcome_context["stat_used"] = stat_name
logger.info(
"skill_check_outcome_determined",
character_id=character.character_id,
skill=skill_type.name,
stat=stat_name,
dc=dc,
success=check_result.success,
margin=check_result.margin
)
return SkillCheckOutcome(
check_result=check_result,
context=outcome_context
)
def determine_persuasion_outcome(
self,
character: Character,
dc: int,
npc_name: Optional[str] = None,
bonus: int = 0
) -> SkillCheckOutcome:
"""
Convenience method for persuasion checks.
Args:
character: The character attempting persuasion
dc: Difficulty class based on NPC disposition
npc_name: Name of the NPC being persuaded
bonus: Additional bonus
Returns:
SkillCheckOutcome
"""
context = {"npc_name": npc_name} if npc_name else {}
return self.determine_skill_check_outcome(
character,
SkillType.PERSUASION,
dc,
bonus,
context
)
def determine_stealth_outcome(
self,
character: Character,
dc: int,
situation: Optional[str] = None,
bonus: int = 0
) -> SkillCheckOutcome:
"""
Convenience method for stealth checks.
Args:
character: The character attempting stealth
dc: Difficulty class based on environment/observers
situation: Description of what they're sneaking past
bonus: Additional bonus
Returns:
SkillCheckOutcome
"""
context = {"situation": situation} if situation else {}
return self.determine_skill_check_outcome(
character,
SkillType.STEALTH,
dc,
bonus,
context
)
def determine_lockpicking_outcome(
self,
character: Character,
dc: int,
lock_description: Optional[str] = None,
bonus: int = 0
) -> SkillCheckOutcome:
"""
Convenience method for lockpicking checks.
Args:
character: The character attempting to pick the lock
dc: Difficulty class based on lock quality
lock_description: Description of the lock/door
bonus: Additional bonus (e.g., from thieves' tools)
Returns:
SkillCheckOutcome
"""
context = {"lock_description": lock_description} if lock_description else {}
return self.determine_skill_check_outcome(
character,
SkillType.LOCKPICKING,
dc,
bonus,
context
)
def get_dc_for_difficulty(self, difficulty: str) -> int:
"""
Get the DC value for a named difficulty.
Args:
difficulty: Difficulty name (trivial, easy, medium, hard, very_hard)
Returns:
DC value
"""
difficulty_map = {
"trivial": Difficulty.TRIVIAL.value,
"easy": Difficulty.EASY.value,
"medium": Difficulty.MEDIUM.value,
"hard": Difficulty.HARD.value,
"very_hard": Difficulty.VERY_HARD.value,
"nearly_impossible": Difficulty.NEARLY_IMPOSSIBLE.value,
}
return difficulty_map.get(difficulty.lower(), Difficulty.MEDIUM.value)
# Global instance for use in API endpoints
outcome_service = OutcomeService()

View File

@@ -0,0 +1,602 @@
"""
Rate Limiter Service
This module implements tier-based rate limiting for AI requests using Redis
for distributed counting. Each user tier has a different daily limit for
AI-generated turns.
Usage:
from app.services.rate_limiter_service import RateLimiterService, RateLimitExceeded
from app.ai.model_selector import UserTier
# Initialize service
rate_limiter = RateLimiterService()
# Check and increment usage
try:
rate_limiter.check_rate_limit("user_123", UserTier.FREE)
rate_limiter.increment_usage("user_123")
except RateLimitExceeded as e:
print(f"Rate limit exceeded: {e}")
# Get remaining turns
remaining = rate_limiter.get_remaining_turns("user_123", UserTier.FREE)
"""
from datetime import date, datetime, timezone, timedelta
from typing import Optional
from app.services.redis_service import RedisService, RedisServiceError
from app.ai.model_selector import UserTier
from app.utils.logging import get_logger
# Initialize logger
logger = get_logger(__file__)
class RateLimitExceeded(Exception):
"""
Raised when a user has exceeded their daily rate limit.
Attributes:
user_id: The user who exceeded the limit
user_tier: The user's subscription tier
limit: The daily limit for their tier
current_usage: The current usage count
reset_time: UTC timestamp when the limit resets
"""
def __init__(
self,
user_id: str,
user_tier: UserTier,
limit: int,
current_usage: int,
reset_time: datetime
):
self.user_id = user_id
self.user_tier = user_tier
self.limit = limit
self.current_usage = current_usage
self.reset_time = reset_time
message = (
f"Rate limit exceeded for user {user_id} ({user_tier.value} tier). "
f"Used {current_usage}/{limit} turns. Resets at {reset_time.isoformat()}"
)
super().__init__(message)
class RateLimiterService:
"""
Service for managing tier-based rate limiting.
This service uses Redis to track daily AI usage per user and enforces
limits based on subscription tier. Counters reset daily at midnight UTC.
Tier Limits:
- Free: 20 turns/day
- Basic: 50 turns/day
- Premium: 100 turns/day
- Elite: 200 turns/day
Attributes:
redis: RedisService instance for counter storage
tier_limits: Mapping of tier to daily turn limit
"""
# Daily turn limits per tier
TIER_LIMITS = {
UserTier.FREE: 20,
UserTier.BASIC: 50,
UserTier.PREMIUM: 100,
UserTier.ELITE: 200,
}
# Daily DM question limits per tier
DM_QUESTION_LIMITS = {
UserTier.FREE: 10,
UserTier.BASIC: 20,
UserTier.PREMIUM: 50,
UserTier.ELITE: -1, # -1 means unlimited
}
# Redis key prefix for rate limit counters
KEY_PREFIX = "rate_limit:daily:"
DM_QUESTION_PREFIX = "rate_limit:dm_questions:"
def __init__(self, redis_service: Optional[RedisService] = None):
"""
Initialize the rate limiter service.
Args:
redis_service: Optional RedisService instance. If not provided,
a new instance will be created.
"""
self.redis = redis_service or RedisService()
logger.info(
"RateLimiterService initialized",
tier_limits=self.TIER_LIMITS
)
def _get_daily_key(self, user_id: str, day: Optional[date] = None) -> str:
"""
Generate the Redis key for a user's daily counter.
Args:
user_id: The user ID
day: The date (defaults to today UTC)
Returns:
Redis key in format "rate_limit:daily:user_id:YYYY-MM-DD"
"""
if day is None:
day = datetime.now(timezone.utc).date()
return f"{self.KEY_PREFIX}{user_id}:{day.isoformat()}"
def _get_seconds_until_midnight_utc(self) -> int:
"""
Calculate seconds remaining until midnight UTC.
Returns:
Number of seconds until the next UTC midnight
"""
now = datetime.now(timezone.utc)
tomorrow = datetime(
now.year, now.month, now.day,
tzinfo=timezone.utc
) + timedelta(days=1)
return int((tomorrow - now).total_seconds())
def _get_reset_time(self) -> datetime:
"""
Get the UTC datetime when the rate limit resets.
Returns:
Datetime of next midnight UTC
"""
now = datetime.now(timezone.utc)
return datetime(
now.year, now.month, now.day,
tzinfo=timezone.utc
) + timedelta(days=1)
def get_limit_for_tier(self, user_tier: UserTier) -> int:
"""
Get the daily turn limit for a specific tier.
Args:
user_tier: The user's subscription tier
Returns:
Daily turn limit for the tier
"""
return self.TIER_LIMITS.get(user_tier, self.TIER_LIMITS[UserTier.FREE])
def get_current_usage(self, user_id: str) -> int:
"""
Get the current daily usage count for a user.
Args:
user_id: The user ID to check
Returns:
Current usage count (0 if no usage today)
Raises:
RedisServiceError: If Redis operation fails
"""
key = self._get_daily_key(user_id)
try:
value = self.redis.get(key)
usage = int(value) if value else 0
logger.debug(
"Retrieved current usage",
user_id=user_id,
usage=usage
)
return usage
except (ValueError, TypeError) as e:
logger.error(
"Invalid usage value in Redis",
user_id=user_id,
error=str(e)
)
return 0
def check_rate_limit(self, user_id: str, user_tier: UserTier) -> None:
"""
Check if a user has exceeded their daily rate limit.
This method checks the current usage against the tier limit and
raises an exception if the limit has been reached.
Args:
user_id: The user ID to check
user_tier: The user's subscription tier
Raises:
RateLimitExceeded: If the user has reached their daily limit
RedisServiceError: If Redis operation fails
"""
current_usage = self.get_current_usage(user_id)
limit = self.get_limit_for_tier(user_tier)
if current_usage >= limit:
reset_time = self._get_reset_time()
logger.warning(
"Rate limit exceeded",
user_id=user_id,
user_tier=user_tier.value,
current_usage=current_usage,
limit=limit,
reset_time=reset_time.isoformat()
)
raise RateLimitExceeded(
user_id=user_id,
user_tier=user_tier,
limit=limit,
current_usage=current_usage,
reset_time=reset_time
)
logger.debug(
"Rate limit check passed",
user_id=user_id,
user_tier=user_tier.value,
current_usage=current_usage,
limit=limit
)
def increment_usage(self, user_id: str) -> int:
"""
Increment the daily usage counter for a user.
This method should be called after successfully processing an AI request.
The counter will automatically expire at midnight UTC.
Args:
user_id: The user ID to increment
Returns:
The new usage count after incrementing
Raises:
RedisServiceError: If Redis operation fails
"""
key = self._get_daily_key(user_id)
# Increment the counter
new_count = self.redis.incr(key)
# Set expiration if this is the first increment (new_count == 1)
# This ensures the key expires at midnight UTC
if new_count == 1:
ttl = self._get_seconds_until_midnight_utc()
self.redis.expire(key, ttl)
logger.debug(
"Set expiration on new rate limit key",
user_id=user_id,
ttl=ttl
)
logger.info(
"Incremented usage counter",
user_id=user_id,
new_count=new_count
)
return new_count
def get_remaining_turns(self, user_id: str, user_tier: UserTier) -> int:
"""
Get the number of remaining turns for a user today.
Args:
user_id: The user ID to check
user_tier: The user's subscription tier
Returns:
Number of turns remaining (0 if limit reached)
"""
current_usage = self.get_current_usage(user_id)
limit = self.get_limit_for_tier(user_tier)
remaining = max(0, limit - current_usage)
logger.debug(
"Calculated remaining turns",
user_id=user_id,
user_tier=user_tier.value,
current_usage=current_usage,
limit=limit,
remaining=remaining
)
return remaining
def get_usage_info(self, user_id: str, user_tier: UserTier) -> dict:
"""
Get comprehensive usage information for a user.
Args:
user_id: The user ID to check
user_tier: The user's subscription tier
Returns:
Dictionary with usage info:
- user_id: User identifier
- user_tier: Subscription tier
- current_usage: Current daily usage
- daily_limit: Daily limit for tier
- remaining: Remaining turns
- reset_time: ISO format UTC reset time
- is_limited: Whether limit has been reached
"""
current_usage = self.get_current_usage(user_id)
limit = self.get_limit_for_tier(user_tier)
remaining = max(0, limit - current_usage)
reset_time = self._get_reset_time()
info = {
"user_id": user_id,
"user_tier": user_tier.value,
"current_usage": current_usage,
"daily_limit": limit,
"remaining": remaining,
"reset_time": reset_time.isoformat(),
"is_limited": current_usage >= limit
}
logger.debug("Retrieved usage info", **info)
return info
def reset_usage(self, user_id: str) -> bool:
"""
Reset the daily usage counter for a user.
This is primarily for admin/testing purposes.
Args:
user_id: The user ID to reset
Returns:
True if the counter was deleted, False if it didn't exist
Raises:
RedisServiceError: If Redis operation fails
"""
key = self._get_daily_key(user_id)
deleted = self.redis.delete(key)
logger.info(
"Reset usage counter",
user_id=user_id,
deleted=deleted > 0
)
return deleted > 0
# ===== DM QUESTION RATE LIMITING =====
def _get_dm_question_key(self, user_id: str, day: Optional[date] = None) -> str:
"""
Generate the Redis key for a user's daily DM question counter.
Args:
user_id: The user ID
day: The date (defaults to today UTC)
Returns:
Redis key in format "rate_limit:dm_questions:user_id:YYYY-MM-DD"
"""
if day is None:
day = datetime.now(timezone.utc).date()
return f"{self.DM_QUESTION_PREFIX}{user_id}:{day.isoformat()}"
def get_dm_question_limit_for_tier(self, user_tier: UserTier) -> int:
"""
Get the daily DM question limit for a specific tier.
Args:
user_tier: The user's subscription tier
Returns:
Daily DM question limit for the tier (-1 for unlimited)
"""
return self.DM_QUESTION_LIMITS.get(user_tier, self.DM_QUESTION_LIMITS[UserTier.FREE])
def get_current_dm_usage(self, user_id: str) -> int:
"""
Get the current daily DM question usage count for a user.
Args:
user_id: The user ID to check
Returns:
Current DM question usage count (0 if no usage today)
Raises:
RedisServiceError: If Redis operation fails
"""
key = self._get_dm_question_key(user_id)
try:
value = self.redis.get(key)
usage = int(value) if value else 0
logger.debug(
"Retrieved current DM question usage",
user_id=user_id,
usage=usage
)
return usage
except (ValueError, TypeError) as e:
logger.error(
"Invalid DM question usage value in Redis",
user_id=user_id,
error=str(e)
)
return 0
def check_dm_question_limit(self, user_id: str, user_tier: UserTier) -> None:
"""
Check if a user has exceeded their daily DM question limit.
Args:
user_id: The user ID to check
user_tier: The user's subscription tier
Raises:
RateLimitExceeded: If the user has reached their daily DM question limit
RedisServiceError: If Redis operation fails
"""
limit = self.get_dm_question_limit_for_tier(user_tier)
# -1 means unlimited
if limit == -1:
logger.debug(
"DM question limit check passed (unlimited)",
user_id=user_id,
user_tier=user_tier.value
)
return
current_usage = self.get_current_dm_usage(user_id)
if current_usage >= limit:
reset_time = self._get_reset_time()
logger.warning(
"DM question limit exceeded",
user_id=user_id,
user_tier=user_tier.value,
current_usage=current_usage,
limit=limit,
reset_time=reset_time.isoformat()
)
raise RateLimitExceeded(
user_id=user_id,
user_tier=user_tier,
limit=limit,
current_usage=current_usage,
reset_time=reset_time
)
logger.debug(
"DM question limit check passed",
user_id=user_id,
user_tier=user_tier.value,
current_usage=current_usage,
limit=limit
)
def increment_dm_usage(self, user_id: str) -> int:
"""
Increment the daily DM question counter for a user.
Args:
user_id: The user ID to increment
Returns:
The new DM question usage count after incrementing
Raises:
RedisServiceError: If Redis operation fails
"""
key = self._get_dm_question_key(user_id)
# Increment the counter
new_count = self.redis.incr(key)
# Set expiration if this is the first increment
if new_count == 1:
ttl = self._get_seconds_until_midnight_utc()
self.redis.expire(key, ttl)
logger.debug(
"Set expiration on new DM question key",
user_id=user_id,
ttl=ttl
)
logger.info(
"Incremented DM question counter",
user_id=user_id,
new_count=new_count
)
return new_count
def get_remaining_dm_questions(self, user_id: str, user_tier: UserTier) -> int:
"""
Get the number of remaining DM questions for a user today.
Args:
user_id: The user ID to check
user_tier: The user's subscription tier
Returns:
Number of DM questions remaining (-1 if unlimited, 0 if limit reached)
"""
limit = self.get_dm_question_limit_for_tier(user_tier)
# -1 means unlimited
if limit == -1:
return -1
current_usage = self.get_current_dm_usage(user_id)
remaining = max(0, limit - current_usage)
logger.debug(
"Calculated remaining DM questions",
user_id=user_id,
user_tier=user_tier.value,
current_usage=current_usage,
limit=limit,
remaining=remaining
)
return remaining
def reset_dm_usage(self, user_id: str) -> bool:
"""
Reset the daily DM question counter for a user.
This is primarily for admin/testing purposes.
Args:
user_id: The user ID to reset
Returns:
True if the counter was deleted, False if it didn't exist
Raises:
RedisServiceError: If Redis operation fails
"""
key = self._get_dm_question_key(user_id)
deleted = self.redis.delete(key)
logger.info(
"Reset DM question counter",
user_id=user_id,
deleted=deleted > 0
)
return deleted > 0

View File

@@ -0,0 +1,505 @@
"""
Redis Service Wrapper
This module provides a wrapper around the redis-py client for handling caching,
job queue data, and temporary storage. It provides connection pooling, automatic
reconnection, and a clean interface for common Redis operations.
Usage:
from app.services.redis_service import RedisService
# Initialize service
redis = RedisService()
# Basic operations
redis.set("key", "value", ttl=3600) # Set with 1 hour TTL
value = redis.get("key")
redis.delete("key")
# Health check
if redis.health_check():
print("Redis is healthy")
"""
import os
import json
from typing import Optional, Any, Union
import redis
from redis.exceptions import RedisError, ConnectionError as RedisConnectionError
from app.utils.logging import get_logger
# Initialize logger
logger = get_logger(__file__)
class RedisServiceError(Exception):
"""Base exception for Redis service errors."""
pass
class RedisConnectionFailed(RedisServiceError):
"""Raised when Redis connection cannot be established."""
pass
class RedisService:
"""
Service class for interacting with Redis.
This class provides:
- Connection pooling for efficient connection management
- Basic operations: get, set, delete, exists
- TTL support for caching
- Health check for monitoring
- Automatic JSON serialization for complex objects
Attributes:
pool: Redis connection pool
client: Redis client instance
"""
def __init__(self, redis_url: Optional[str] = None):
"""
Initialize the Redis service.
Reads configuration from environment variables if not provided:
- REDIS_URL: Full Redis URL (e.g., redis://localhost:6379/0)
Args:
redis_url: Optional Redis URL to override environment variable
Raises:
RedisConnectionFailed: If connection to Redis fails
"""
self.redis_url = redis_url or os.getenv('REDIS_URL', 'redis://localhost:6379/0')
if not self.redis_url:
logger.error("Missing Redis URL configuration")
raise ValueError("Redis URL not configured. Set REDIS_URL environment variable.")
try:
# Create connection pool for efficient connection management
# Connection pooling allows multiple operations to share connections
# and automatically manages connection lifecycle
self.pool = redis.ConnectionPool.from_url(
self.redis_url,
max_connections=10,
decode_responses=True, # Return strings instead of bytes
socket_connect_timeout=5, # Connection timeout in seconds
socket_timeout=5, # Operation timeout in seconds
retry_on_timeout=True, # Retry on timeout
)
# Create client using the connection pool
self.client = redis.Redis(connection_pool=self.pool)
# Test connection
self.client.ping()
logger.info("Redis service initialized", redis_url=self._sanitize_url(self.redis_url))
except RedisConnectionError as e:
logger.error("Failed to connect to Redis", redis_url=self._sanitize_url(self.redis_url), error=str(e))
raise RedisConnectionFailed(f"Could not connect to Redis at {self._sanitize_url(self.redis_url)}: {e}")
except RedisError as e:
logger.error("Redis initialization error", error=str(e))
raise RedisServiceError(f"Redis initialization failed: {e}")
def get(self, key: str) -> Optional[str]:
"""
Get a value from Redis by key.
Args:
key: The key to retrieve
Returns:
The value as string if found, None if key doesn't exist
Raises:
RedisServiceError: If the operation fails
"""
try:
value = self.client.get(key)
if value is not None:
logger.debug("Redis GET", key=key, found=True)
else:
logger.debug("Redis GET", key=key, found=False)
return value
except RedisError as e:
logger.error("Redis GET failed", key=key, error=str(e))
raise RedisServiceError(f"Failed to get key '{key}': {e}")
def get_json(self, key: str) -> Optional[Any]:
"""
Get a value from Redis and deserialize it from JSON.
Args:
key: The key to retrieve
Returns:
The deserialized value if found, None if key doesn't exist
Raises:
RedisServiceError: If the operation fails or JSON is invalid
"""
value = self.get(key)
if value is None:
return None
try:
return json.loads(value)
except json.JSONDecodeError as e:
logger.error("Failed to decode JSON from Redis", key=key, error=str(e))
raise RedisServiceError(f"Failed to decode JSON for key '{key}': {e}")
def set(
self,
key: str,
value: str,
ttl: Optional[int] = None,
nx: bool = False,
xx: bool = False
) -> bool:
"""
Set a value in Redis.
Args:
key: The key to set
value: The value to store (must be string)
ttl: Time to live in seconds (None for no expiration)
nx: Only set if key does not exist (for locking)
xx: Only set if key already exists
Returns:
True if the key was set, False if not set (due to nx/xx conditions)
Raises:
RedisServiceError: If the operation fails
"""
try:
result = self.client.set(
key,
value,
ex=ttl, # Expiration in seconds
nx=nx, # Only set if not exists
xx=xx # Only set if exists
)
# set() returns True if set, None if not set due to nx/xx
success = result is True or result == 1
logger.debug("Redis SET", key=key, ttl=ttl, nx=nx, xx=xx, success=success)
return success
except RedisError as e:
logger.error("Redis SET failed", key=key, error=str(e))
raise RedisServiceError(f"Failed to set key '{key}': {e}")
def set_json(
self,
key: str,
value: Any,
ttl: Optional[int] = None,
nx: bool = False,
xx: bool = False
) -> bool:
"""
Serialize a value to JSON and store it in Redis.
Args:
key: The key to set
value: The value to serialize and store (must be JSON-serializable)
ttl: Time to live in seconds (None for no expiration)
nx: Only set if key does not exist
xx: Only set if key already exists
Returns:
True if the key was set, False if not set (due to nx/xx conditions)
Raises:
RedisServiceError: If the operation fails or value is not JSON-serializable
"""
try:
json_value = json.dumps(value)
except (TypeError, ValueError) as e:
logger.error("Failed to serialize value to JSON", key=key, error=str(e))
raise RedisServiceError(f"Failed to serialize value for key '{key}': {e}")
return self.set(key, json_value, ttl=ttl, nx=nx, xx=xx)
def delete(self, *keys: str) -> int:
"""
Delete one or more keys from Redis.
Args:
*keys: One or more keys to delete
Returns:
The number of keys that were deleted
Raises:
RedisServiceError: If the operation fails
"""
if not keys:
return 0
try:
deleted_count = self.client.delete(*keys)
logger.debug("Redis DELETE", keys=keys, deleted_count=deleted_count)
return deleted_count
except RedisError as e:
logger.error("Redis DELETE failed", keys=keys, error=str(e))
raise RedisServiceError(f"Failed to delete keys {keys}: {e}")
def exists(self, *keys: str) -> int:
"""
Check if one or more keys exist in Redis.
Args:
*keys: One or more keys to check
Returns:
The number of keys that exist
Raises:
RedisServiceError: If the operation fails
"""
if not keys:
return 0
try:
exists_count = self.client.exists(*keys)
logger.debug("Redis EXISTS", keys=keys, exists_count=exists_count)
return exists_count
except RedisError as e:
logger.error("Redis EXISTS failed", keys=keys, error=str(e))
raise RedisServiceError(f"Failed to check existence of keys {keys}: {e}")
def expire(self, key: str, ttl: int) -> bool:
"""
Set a TTL (time to live) on an existing key.
Args:
key: The key to set expiration on
ttl: Time to live in seconds
Returns:
True if the timeout was set, False if key doesn't exist
Raises:
RedisServiceError: If the operation fails
"""
try:
result = self.client.expire(key, ttl)
logger.debug("Redis EXPIRE", key=key, ttl=ttl, success=result)
return result
except RedisError as e:
logger.error("Redis EXPIRE failed", key=key, ttl=ttl, error=str(e))
raise RedisServiceError(f"Failed to set expiration for key '{key}': {e}")
def ttl(self, key: str) -> int:
"""
Get the remaining TTL (time to live) for a key.
Args:
key: The key to check
Returns:
TTL in seconds, -1 if key exists but has no expiry, -2 if key doesn't exist
Raises:
RedisServiceError: If the operation fails
"""
try:
remaining = self.client.ttl(key)
logger.debug("Redis TTL", key=key, remaining=remaining)
return remaining
except RedisError as e:
logger.error("Redis TTL failed", key=key, error=str(e))
raise RedisServiceError(f"Failed to get TTL for key '{key}': {e}")
def incr(self, key: str, amount: int = 1) -> int:
"""
Increment a key's value by the given amount.
If the key doesn't exist, it will be created with the increment value.
Args:
key: The key to increment
amount: Amount to increment by (default 1)
Returns:
The new value after incrementing
Raises:
RedisServiceError: If the operation fails or value is not an integer
"""
try:
new_value = self.client.incrby(key, amount)
logger.debug("Redis INCR", key=key, amount=amount, new_value=new_value)
return new_value
except RedisError as e:
logger.error("Redis INCR failed", key=key, amount=amount, error=str(e))
raise RedisServiceError(f"Failed to increment key '{key}': {e}")
def decr(self, key: str, amount: int = 1) -> int:
"""
Decrement a key's value by the given amount.
If the key doesn't exist, it will be created with the negative increment value.
Args:
key: The key to decrement
amount: Amount to decrement by (default 1)
Returns:
The new value after decrementing
Raises:
RedisServiceError: If the operation fails or value is not an integer
"""
try:
new_value = self.client.decrby(key, amount)
logger.debug("Redis DECR", key=key, amount=amount, new_value=new_value)
return new_value
except RedisError as e:
logger.error("Redis DECR failed", key=key, amount=amount, error=str(e))
raise RedisServiceError(f"Failed to decrement key '{key}': {e}")
def health_check(self) -> bool:
"""
Check if Redis connection is healthy.
This performs a PING command to verify the connection is working.
Returns:
True if Redis is healthy and responding, False otherwise
"""
try:
response = self.client.ping()
if response:
logger.debug("Redis health check passed")
return True
else:
logger.warning("Redis health check failed - unexpected response", response=response)
return False
except RedisError as e:
logger.error("Redis health check failed", error=str(e))
return False
def info(self) -> dict:
"""
Get Redis server information.
Returns:
Dictionary containing server info (version, memory, clients, etc.)
Raises:
RedisServiceError: If the operation fails
"""
try:
info = self.client.info()
logger.debug("Redis INFO retrieved", redis_version=info.get('redis_version'))
return info
except RedisError as e:
logger.error("Redis INFO failed", error=str(e))
raise RedisServiceError(f"Failed to get Redis info: {e}")
def flush_db(self) -> bool:
"""
Delete all keys in the current database.
WARNING: This is a destructive operation. Use with caution.
Returns:
True if successful
Raises:
RedisServiceError: If the operation fails
"""
try:
self.client.flushdb()
logger.warning("Redis database flushed")
return True
except RedisError as e:
logger.error("Redis FLUSHDB failed", error=str(e))
raise RedisServiceError(f"Failed to flush database: {e}")
def close(self) -> None:
"""
Close all connections in the pool.
Call this when shutting down the application to cleanly release connections.
"""
try:
self.pool.disconnect()
logger.info("Redis connection pool closed")
except Exception as e:
logger.error("Error closing Redis connection pool", error=str(e))
def _sanitize_url(self, url: str) -> str:
"""
Remove password from Redis URL for safe logging.
Args:
url: Redis URL that may contain password
Returns:
URL with password masked
"""
# Simple sanitization - mask password if present
# Format: redis://user:password@host:port/db
if '@' in url:
# Split on @ and mask everything before it except the protocol
parts = url.split('@')
protocol_and_creds = parts[0]
host_and_rest = parts[1]
if '://' in protocol_and_creds:
protocol = protocol_and_creds.split('://')[0]
return f"{protocol}://***@{host_and_rest}"
return url
def __enter__(self):
"""Context manager entry."""
return self
def __exit__(self, exc_type, exc_val, exc_tb):
"""Context manager exit - close connections."""
self.close()
return False

View File

@@ -0,0 +1,705 @@
"""
Session Service - CRUD operations for game sessions.
This service handles creating, reading, updating, and managing game sessions,
with support for both solo and multiplayer sessions.
"""
import json
from typing import List, Optional
from datetime import datetime, timezone
from appwrite.query import Query
from appwrite.id import ID
from app.models.session import GameSession, GameState, ConversationEntry, SessionConfig
from app.models.enums import SessionStatus, SessionType
from app.models.action_prompt import LocationType
from app.services.database_service import get_database_service
from app.services.appwrite_service import AppwriteService
from app.services.character_service import get_character_service, CharacterNotFound
from app.services.location_loader import get_location_loader
from app.utils.logging import get_logger
logger = get_logger(__file__)
# Session limits per user
MAX_ACTIVE_SESSIONS = 5
class SessionNotFound(Exception):
"""Raised when session ID doesn't exist or user doesn't own it."""
pass
class SessionLimitExceeded(Exception):
"""Raised when user tries to create more sessions than allowed."""
pass
class SessionValidationError(Exception):
"""Raised when session validation fails."""
pass
class SessionService:
"""
Service for managing game sessions.
This service provides:
- Session creation (solo and multiplayer)
- Session retrieval and listing
- Session state updates
- Conversation history management
- Game state tracking
"""
def __init__(self):
"""Initialize the session service with dependencies."""
self.db = get_database_service()
self.appwrite = AppwriteService()
self.character_service = get_character_service()
self.collection_id = "game_sessions"
logger.info("SessionService initialized")
def create_solo_session(
self,
user_id: str,
character_id: str,
starting_location: Optional[str] = None,
starting_location_type: Optional[LocationType] = None
) -> GameSession:
"""
Create a new solo game session.
This method:
1. Validates user owns the character
2. Validates user hasn't exceeded session limit
3. Determines starting location from location data
4. Creates session with initial game state
5. Stores in Appwrite database
Args:
user_id: Owner's user ID
character_id: Character ID for this session
starting_location: Initial location ID (optional, uses default starting location)
starting_location_type: Initial location type (optional, derived from location data)
Returns:
Created GameSession instance
Raises:
CharacterNotFound: If character doesn't exist or user doesn't own it
SessionLimitExceeded: If user has reached active session limit
"""
try:
logger.info("Creating solo session",
user_id=user_id,
character_id=character_id)
# Validate user owns the character
character = self.character_service.get_character(character_id, user_id)
if not character:
raise CharacterNotFound(f"Character not found: {character_id}")
# Determine starting location from location data if not provided
if not starting_location:
location_loader = get_location_loader()
starting_locations = location_loader.get_starting_locations()
if starting_locations:
# Use first starting location (usually crossville_village)
start_loc = starting_locations[0]
starting_location = start_loc.location_id
# Convert from enums.LocationType to action_prompt.LocationType via string value
starting_location_type = LocationType(start_loc.location_type.value)
logger.info("Using starting location from data",
location_id=starting_location,
location_type=starting_location_type.value)
else:
# Fallback to crossville_village
starting_location = "crossville_village"
starting_location_type = LocationType.TOWN
logger.warning("No starting locations found, using fallback",
location_id=starting_location)
# Ensure location type is set
if not starting_location_type:
starting_location_type = LocationType.TOWN
# Check session limit
active_count = self.count_user_sessions(user_id, active_only=True)
if active_count >= MAX_ACTIVE_SESSIONS:
logger.warning("Session limit exceeded",
user_id=user_id,
current=active_count,
limit=MAX_ACTIVE_SESSIONS)
raise SessionLimitExceeded(
f"Maximum active sessions reached ({active_count}/{MAX_ACTIVE_SESSIONS}). "
f"Please end an existing session to start a new one."
)
# Generate unique session ID
session_id = ID.unique()
# Create game state with starting location
game_state = GameState(
current_location=starting_location,
location_type=starting_location_type,
discovered_locations=[starting_location],
active_quests=[],
world_events=[]
)
# Create session instance
session = GameSession(
session_id=session_id,
session_type=SessionType.SOLO,
solo_character_id=character_id,
user_id=user_id,
party_member_ids=[],
config=SessionConfig(),
game_state=game_state,
turn_order=[character_id],
current_turn=0,
turn_number=0,
status=SessionStatus.ACTIVE
)
# Serialize and store
session_dict = session.to_dict()
session_json = json.dumps(session_dict)
document_data = {
'userId': user_id,
'characterId': character_id,
'sessionData': session_json,
'status': SessionStatus.ACTIVE.value,
'sessionType': SessionType.SOLO.value
}
self.db.create_document(
collection_id=self.collection_id,
data=document_data,
document_id=session_id
)
logger.info("Solo session created successfully",
session_id=session_id,
user_id=user_id,
character_id=character_id)
return session
except (CharacterNotFound, SessionLimitExceeded):
raise
except Exception as e:
logger.error("Failed to create solo session",
user_id=user_id,
character_id=character_id,
error=str(e))
raise
def get_session(self, session_id: str, user_id: Optional[str] = None) -> GameSession:
"""
Get a session by ID.
Args:
session_id: Session ID
user_id: Optional user ID for ownership validation
Returns:
GameSession instance
Raises:
SessionNotFound: If session doesn't exist or user doesn't own it
"""
try:
logger.debug("Fetching session", session_id=session_id)
# Get document from database
document = self.db.get_row(self.collection_id, session_id)
if not document:
logger.warning("Session not found", session_id=session_id)
raise SessionNotFound(f"Session not found: {session_id}")
# Verify ownership if user_id provided
if user_id and document.data.get('userId') != user_id:
logger.warning("Session ownership mismatch",
session_id=session_id,
expected_user=user_id,
actual_user=document.data.get('userId'))
raise SessionNotFound(f"Session not found: {session_id}")
# Parse session data
session_json = document.data.get('sessionData')
session_dict = json.loads(session_json)
session = GameSession.from_dict(session_dict)
logger.debug("Session fetched successfully", session_id=session_id)
return session
except SessionNotFound:
raise
except Exception as e:
logger.error("Failed to fetch session",
session_id=session_id,
error=str(e))
raise
def update_session(self, session: GameSession) -> GameSession:
"""
Update a session in the database.
Args:
session: GameSession instance with updated data
Returns:
Updated GameSession instance
"""
try:
logger.debug("Updating session", session_id=session.session_id)
# Serialize session
session_dict = session.to_dict()
session_json = json.dumps(session_dict)
# Update in database
self.db.update_document(
collection_id=self.collection_id,
document_id=session.session_id,
data={
'sessionData': session_json,
'status': session.status.value
}
)
logger.debug("Session updated successfully", session_id=session.session_id)
return session
except Exception as e:
logger.error("Failed to update session",
session_id=session.session_id,
error=str(e))
raise
def get_user_sessions(
self,
user_id: str,
active_only: bool = True,
limit: int = 25
) -> List[GameSession]:
"""
Get all sessions for a user.
Args:
user_id: User ID
active_only: If True, only return active sessions
limit: Maximum number of sessions to return
Returns:
List of GameSession instances
"""
try:
logger.debug("Fetching user sessions",
user_id=user_id,
active_only=active_only)
# Build query
queries = [Query.equal('userId', user_id)]
if active_only:
queries.append(Query.equal('status', SessionStatus.ACTIVE.value))
documents = self.db.list_rows(
table_id=self.collection_id,
queries=queries,
limit=limit
)
# Parse all sessions
sessions = []
for document in documents:
try:
session_json = document.data.get('sessionData')
session_dict = json.loads(session_json)
session = GameSession.from_dict(session_dict)
sessions.append(session)
except Exception as e:
logger.error("Failed to parse session",
document_id=document.id,
error=str(e))
continue
logger.debug("User sessions fetched",
user_id=user_id,
count=len(sessions))
return sessions
except Exception as e:
logger.error("Failed to fetch user sessions",
user_id=user_id,
error=str(e))
raise
def count_user_sessions(self, user_id: str, active_only: bool = True) -> int:
"""
Count sessions for a user.
Args:
user_id: User ID
active_only: If True, only count active sessions
Returns:
Number of sessions
"""
try:
queries = [Query.equal('userId', user_id)]
if active_only:
queries.append(Query.equal('status', SessionStatus.ACTIVE.value))
count = self.db.count_documents(
collection_id=self.collection_id,
queries=queries
)
logger.debug("Session count",
user_id=user_id,
active_only=active_only,
count=count)
return count
except Exception as e:
logger.error("Failed to count sessions",
user_id=user_id,
error=str(e))
return 0
def end_session(self, session_id: str, user_id: str) -> GameSession:
"""
End a session by marking it as completed.
Args:
session_id: Session ID
user_id: User ID for ownership validation
Returns:
Updated GameSession instance
Raises:
SessionNotFound: If session doesn't exist or user doesn't own it
"""
try:
logger.info("Ending session", session_id=session_id, user_id=user_id)
session = self.get_session(session_id, user_id)
session.status = SessionStatus.COMPLETED
session.update_activity()
return self.update_session(session)
except SessionNotFound:
raise
except Exception as e:
logger.error("Failed to end session",
session_id=session_id,
error=str(e))
raise
def add_conversation_entry(
self,
session_id: str,
character_id: str,
character_name: str,
action: str,
dm_response: str,
combat_log: Optional[List] = None,
quest_offered: Optional[dict] = None
) -> GameSession:
"""
Add an entry to the conversation history.
This method automatically:
- Increments turn number
- Adds timestamp
- Updates last activity
Args:
session_id: Session ID
character_id: Acting character's ID
character_name: Acting character's name
action: Player's action text
dm_response: AI DM's response
combat_log: Optional combat actions
quest_offered: Optional quest offering info
Returns:
Updated GameSession instance
"""
try:
logger.debug("Adding conversation entry",
session_id=session_id,
character_id=character_id)
session = self.get_session(session_id)
# Create conversation entry
entry = ConversationEntry(
turn=session.turn_number + 1,
character_id=character_id,
character_name=character_name,
action=action,
dm_response=dm_response,
combat_log=combat_log or [],
quest_offered=quest_offered
)
# Add entry and increment turn
session.conversation_history.append(entry)
session.turn_number += 1
session.update_activity()
# Save to database
return self.update_session(session)
except Exception as e:
logger.error("Failed to add conversation entry",
session_id=session_id,
error=str(e))
raise
def get_conversation_history(
self,
session_id: str,
limit: Optional[int] = None,
offset: int = 0
) -> List[ConversationEntry]:
"""
Get conversation history for a session.
Args:
session_id: Session ID
limit: Maximum entries to return (None for all)
offset: Number of entries to skip from end
Returns:
List of ConversationEntry instances
"""
try:
session = self.get_session(session_id)
history = session.conversation_history
# Apply offset (from end)
if offset > 0:
history = history[:-offset] if offset < len(history) else []
# Apply limit (from end)
if limit and len(history) > limit:
history = history[-limit:]
return history
except Exception as e:
logger.error("Failed to get conversation history",
session_id=session_id,
error=str(e))
raise
def get_recent_history(self, session_id: str, num_turns: int = 3) -> List[ConversationEntry]:
"""
Get the most recent conversation entries for AI context.
Args:
session_id: Session ID
num_turns: Number of recent turns to return
Returns:
List of most recent ConversationEntry instances
"""
return self.get_conversation_history(session_id, limit=num_turns)
def update_location(
self,
session_id: str,
new_location: str,
location_type: LocationType
) -> GameSession:
"""
Update the current location in the session.
Also adds location to discovered_locations if not already there.
Args:
session_id: Session ID
new_location: New location name
location_type: New location type
Returns:
Updated GameSession instance
"""
try:
logger.debug("Updating location",
session_id=session_id,
new_location=new_location)
session = self.get_session(session_id)
session.game_state.current_location = new_location
session.game_state.location_type = location_type
# Track discovered locations
if new_location not in session.game_state.discovered_locations:
session.game_state.discovered_locations.append(new_location)
session.update_activity()
return self.update_session(session)
except Exception as e:
logger.error("Failed to update location",
session_id=session_id,
error=str(e))
raise
def add_discovered_location(self, session_id: str, location: str) -> GameSession:
"""
Add a location to the discovered locations list.
Args:
session_id: Session ID
location: Location name to add
Returns:
Updated GameSession instance
"""
try:
session = self.get_session(session_id)
if location not in session.game_state.discovered_locations:
session.game_state.discovered_locations.append(location)
session.update_activity()
return self.update_session(session)
return session
except Exception as e:
logger.error("Failed to add discovered location",
session_id=session_id,
error=str(e))
raise
def add_active_quest(self, session_id: str, quest_id: str) -> GameSession:
"""
Add a quest to the active quests list.
Validates max 2 active quests limit.
Args:
session_id: Session ID
quest_id: Quest ID to add
Returns:
Updated GameSession instance
Raises:
SessionValidationError: If max quests limit exceeded
"""
try:
session = self.get_session(session_id)
# Check max active quests (2)
if len(session.game_state.active_quests) >= 2:
raise SessionValidationError(
"Maximum active quests reached (2/2). "
"Complete or abandon a quest to accept a new one."
)
if quest_id not in session.game_state.active_quests:
session.game_state.active_quests.append(quest_id)
session.update_activity()
return self.update_session(session)
return session
except SessionValidationError:
raise
except Exception as e:
logger.error("Failed to add active quest",
session_id=session_id,
quest_id=quest_id,
error=str(e))
raise
def remove_active_quest(self, session_id: str, quest_id: str) -> GameSession:
"""
Remove a quest from the active quests list.
Args:
session_id: Session ID
quest_id: Quest ID to remove
Returns:
Updated GameSession instance
"""
try:
session = self.get_session(session_id)
if quest_id in session.game_state.active_quests:
session.game_state.active_quests.remove(quest_id)
session.update_activity()
return self.update_session(session)
return session
except Exception as e:
logger.error("Failed to remove active quest",
session_id=session_id,
quest_id=quest_id,
error=str(e))
raise
def add_world_event(self, session_id: str, event: dict) -> GameSession:
"""
Add a world event to the session.
Args:
session_id: Session ID
event: Event dictionary with type, description, etc.
Returns:
Updated GameSession instance
"""
try:
session = self.get_session(session_id)
# Add timestamp if not present
if 'timestamp' not in event:
event['timestamp'] = datetime.now(timezone.utc).isoformat().replace("+00:00", "Z")
session.game_state.world_events.append(event)
session.update_activity()
return self.update_session(session)
except Exception as e:
logger.error("Failed to add world event",
session_id=session_id,
error=str(e))
raise
# Global instance for convenience
_service_instance: Optional[SessionService] = None
def get_session_service() -> SessionService:
"""
Get the global SessionService instance.
Returns:
Singleton SessionService instance
"""
global _service_instance
if _service_instance is None:
_service_instance = SessionService()
return _service_instance

View File

@@ -0,0 +1,528 @@
"""
Usage Tracking Service for AI cost and usage monitoring.
This service tracks all AI usage events, calculates costs, and provides
analytics for monitoring and rate limiting purposes.
Usage:
from app.services.usage_tracking_service import UsageTrackingService
tracker = UsageTrackingService()
# Log a usage event
tracker.log_usage(
user_id="user_123",
model="anthropic/claude-3.5-sonnet",
tokens_input=100,
tokens_output=350,
task_type=TaskType.STORY_PROGRESSION
)
# Get daily usage
usage = tracker.get_daily_usage("user_123", date.today())
print(f"Total requests: {usage.total_requests}")
print(f"Estimated cost: ${usage.estimated_cost:.4f}")
"""
import os
from datetime import datetime, timezone, date, timedelta
from typing import Dict, Any, List, Optional
from uuid import uuid4
from appwrite.client import Client
from appwrite.services.tables_db import TablesDB
from appwrite.exception import AppwriteException
from appwrite.id import ID
from appwrite.query import Query
from app.utils.logging import get_logger
from app.models.ai_usage import (
AIUsageLog,
DailyUsageSummary,
MonthlyUsageSummary,
TaskType
)
logger = get_logger(__file__)
# Cost per 1000 tokens by model (in USD)
# These are estimates based on Replicate pricing
MODEL_COSTS = {
# Llama models (via Replicate) - very cheap
"meta/meta-llama-3-8b-instruct": {
"input": 0.0001, # $0.0001 per 1K input tokens
"output": 0.0001, # $0.0001 per 1K output tokens
},
"meta/meta-llama-3-70b-instruct": {
"input": 0.0006,
"output": 0.0006,
},
# Claude models (via Replicate)
"anthropic/claude-3.5-haiku": {
"input": 0.001, # $0.001 per 1K input tokens
"output": 0.005, # $0.005 per 1K output tokens
},
"anthropic/claude-3-haiku": {
"input": 0.00025,
"output": 0.00125,
},
"anthropic/claude-3.5-sonnet": {
"input": 0.003, # $0.003 per 1K input tokens
"output": 0.015, # $0.015 per 1K output tokens
},
"anthropic/claude-4.5-sonnet": {
"input": 0.003,
"output": 0.015,
},
"anthropic/claude-3-opus": {
"input": 0.015, # $0.015 per 1K input tokens
"output": 0.075, # $0.075 per 1K output tokens
},
}
# Default cost for unknown models
DEFAULT_COST = {"input": 0.001, "output": 0.005}
class UsageTrackingService:
"""
Service for tracking AI usage and calculating costs.
This service provides:
- Logging individual AI usage events to Appwrite
- Calculating estimated costs based on model pricing
- Retrieving daily and monthly usage summaries
- Analytics for monitoring and rate limiting
The service stores usage logs in an Appwrite collection named 'ai_usage_logs'.
"""
# Collection ID for usage logs
COLLECTION_ID = "ai_usage_logs"
def __init__(self):
"""
Initialize the usage tracking service.
Reads configuration from environment variables:
- APPWRITE_ENDPOINT: Appwrite API endpoint
- APPWRITE_PROJECT_ID: Appwrite project ID
- APPWRITE_API_KEY: Appwrite API key
- APPWRITE_DATABASE_ID: Appwrite database ID
Raises:
ValueError: If required environment variables are missing
"""
self.endpoint = os.getenv('APPWRITE_ENDPOINT')
self.project_id = os.getenv('APPWRITE_PROJECT_ID')
self.api_key = os.getenv('APPWRITE_API_KEY')
self.database_id = os.getenv('APPWRITE_DATABASE_ID', 'main')
if not all([self.endpoint, self.project_id, self.api_key]):
logger.error("Missing Appwrite configuration in environment variables")
raise ValueError("Appwrite configuration incomplete. Check APPWRITE_* environment variables.")
# Initialize Appwrite client
self.client = Client()
self.client.set_endpoint(self.endpoint)
self.client.set_project(self.project_id)
self.client.set_key(self.api_key)
# Initialize TablesDB service
self.tables_db = TablesDB(self.client)
logger.info("UsageTrackingService initialized", database_id=self.database_id)
def log_usage(
self,
user_id: str,
model: str,
tokens_input: int,
tokens_output: int,
task_type: TaskType,
session_id: Optional[str] = None,
character_id: Optional[str] = None,
request_duration_ms: int = 0,
success: bool = True,
error_message: Optional[str] = None
) -> AIUsageLog:
"""
Log an AI usage event.
This method creates a new usage log entry in Appwrite with all
relevant information about the AI request including calculated
estimated cost.
Args:
user_id: User who made the request
model: Model identifier (e.g., "anthropic/claude-3.5-sonnet")
tokens_input: Number of input tokens (prompt)
tokens_output: Number of output tokens (response)
task_type: Type of task (story, combat, quest, npc)
session_id: Optional game session ID
character_id: Optional character ID
request_duration_ms: Request duration in milliseconds
success: Whether the request succeeded
error_message: Error message if failed
Returns:
AIUsageLog with the logged data
Raises:
AppwriteException: If storage fails
"""
# Calculate total tokens
tokens_total = tokens_input + tokens_output
# Calculate estimated cost
estimated_cost = self._calculate_cost(model, tokens_input, tokens_output)
# Generate log ID
log_id = str(uuid4())
# Create usage log
usage_log = AIUsageLog(
log_id=log_id,
user_id=user_id,
timestamp=datetime.now(timezone.utc),
model=model,
tokens_input=tokens_input,
tokens_output=tokens_output,
tokens_total=tokens_total,
estimated_cost=estimated_cost,
task_type=task_type,
session_id=session_id,
character_id=character_id,
request_duration_ms=request_duration_ms,
success=success,
error_message=error_message,
)
try:
# Store in Appwrite
result = self.tables_db.create_row(
database_id=self.database_id,
table_id=self.COLLECTION_ID,
row_id=log_id,
data=usage_log.to_dict()
)
logger.info(
"AI usage logged",
log_id=log_id,
user_id=user_id,
model=model,
tokens_total=tokens_total,
estimated_cost=estimated_cost,
task_type=task_type.value,
success=success
)
return usage_log
except AppwriteException as e:
logger.error(
"Failed to log AI usage",
user_id=user_id,
model=model,
error=str(e),
code=e.code
)
raise
def get_daily_usage(self, user_id: str, target_date: date) -> DailyUsageSummary:
"""
Get AI usage summary for a specific day.
Args:
user_id: User ID to get usage for
target_date: Date to get usage for
Returns:
DailyUsageSummary with aggregated usage data
Raises:
AppwriteException: If query fails
"""
try:
# Build date range for the target day (UTC)
start_of_day = datetime.combine(target_date, datetime.min.time()).replace(tzinfo=timezone.utc)
end_of_day = datetime.combine(target_date, datetime.max.time()).replace(tzinfo=timezone.utc)
# Query usage logs for this user and date
result = self.tables_db.list_rows(
database_id=self.database_id,
table_id=self.COLLECTION_ID,
queries=[
Query.equal("user_id", user_id),
Query.greater_than_equal("timestamp", start_of_day.isoformat()),
Query.less_than_equal("timestamp", end_of_day.isoformat()),
Query.limit(1000) # Cap at 1000 entries per day
]
)
# Aggregate the data
total_requests = 0
total_tokens = 0
total_input_tokens = 0
total_output_tokens = 0
total_cost = 0.0
requests_by_task: Dict[str, int] = {}
for doc in result['rows']:
total_requests += 1
total_tokens += doc.get('tokens_total', 0)
total_input_tokens += doc.get('tokens_input', 0)
total_output_tokens += doc.get('tokens_output', 0)
total_cost += doc.get('estimated_cost', 0.0)
task_type = doc.get('task_type', 'general')
requests_by_task[task_type] = requests_by_task.get(task_type, 0) + 1
summary = DailyUsageSummary(
date=target_date,
user_id=user_id,
total_requests=total_requests,
total_tokens=total_tokens,
total_input_tokens=total_input_tokens,
total_output_tokens=total_output_tokens,
estimated_cost=total_cost,
requests_by_task=requests_by_task
)
logger.debug(
"Daily usage retrieved",
user_id=user_id,
date=target_date.isoformat(),
total_requests=total_requests,
estimated_cost=total_cost
)
return summary
except AppwriteException as e:
logger.error(
"Failed to get daily usage",
user_id=user_id,
date=target_date.isoformat(),
error=str(e),
code=e.code
)
raise
def get_monthly_cost(self, user_id: str, year: int, month: int) -> MonthlyUsageSummary:
"""
Get AI usage cost summary for a specific month.
Args:
user_id: User ID to get cost for
year: Year (e.g., 2025)
month: Month (1-12)
Returns:
MonthlyUsageSummary with aggregated cost data
Raises:
AppwriteException: If query fails
ValueError: If month is invalid
"""
if not 1 <= month <= 12:
raise ValueError(f"Invalid month: {month}. Must be 1-12.")
try:
# Build date range for the month
start_of_month = datetime(year, month, 1, 0, 0, 0, tzinfo=timezone.utc)
# Calculate end of month
if month == 12:
end_of_month = datetime(year + 1, 1, 1, 0, 0, 0, tzinfo=timezone.utc) - timedelta(seconds=1)
else:
end_of_month = datetime(year, month + 1, 1, 0, 0, 0, tzinfo=timezone.utc) - timedelta(seconds=1)
# Query usage logs for this user and month
result = self.tables_db.list_rows(
database_id=self.database_id,
table_id=self.COLLECTION_ID,
queries=[
Query.equal("user_id", user_id),
Query.greater_than_equal("timestamp", start_of_month.isoformat()),
Query.less_than_equal("timestamp", end_of_month.isoformat()),
Query.limit(5000) # Cap at 5000 entries per month
]
)
# Aggregate the data
total_requests = 0
total_tokens = 0
total_cost = 0.0
for doc in result['rows']:
total_requests += 1
total_tokens += doc.get('tokens_total', 0)
total_cost += doc.get('estimated_cost', 0.0)
summary = MonthlyUsageSummary(
year=year,
month=month,
user_id=user_id,
total_requests=total_requests,
total_tokens=total_tokens,
estimated_cost=total_cost
)
logger.debug(
"Monthly cost retrieved",
user_id=user_id,
year=year,
month=month,
total_requests=total_requests,
estimated_cost=total_cost
)
return summary
except AppwriteException as e:
logger.error(
"Failed to get monthly cost",
user_id=user_id,
year=year,
month=month,
error=str(e),
code=e.code
)
raise
def get_total_daily_cost(self, target_date: date) -> float:
"""
Get the total AI cost across all users for a specific day.
Used for admin monitoring and alerting.
Args:
target_date: Date to get cost for
Returns:
Total estimated cost in USD
Raises:
AppwriteException: If query fails
"""
try:
# Build date range for the target day
start_of_day = datetime.combine(target_date, datetime.min.time()).replace(tzinfo=timezone.utc)
end_of_day = datetime.combine(target_date, datetime.max.time()).replace(tzinfo=timezone.utc)
# Query all usage logs for this date
result = self.tables_db.list_rows(
database_id=self.database_id,
table_id=self.COLLECTION_ID,
queries=[
Query.greater_than_equal("timestamp", start_of_day.isoformat()),
Query.less_than_equal("timestamp", end_of_day.isoformat()),
Query.limit(10000)
]
)
# Sum up costs
total_cost = sum(doc.get('estimated_cost', 0.0) for doc in result['rows'])
logger.debug(
"Total daily cost retrieved",
date=target_date.isoformat(),
total_cost=total_cost,
total_documents=len(result['rows'])
)
return total_cost
except AppwriteException as e:
logger.error(
"Failed to get total daily cost",
date=target_date.isoformat(),
error=str(e),
code=e.code
)
raise
def get_user_request_count_today(self, user_id: str) -> int:
"""
Get the number of AI requests a user has made today.
Used for rate limiting checks.
Args:
user_id: User ID to check
Returns:
Number of requests made today
Raises:
AppwriteException: If query fails
"""
try:
summary = self.get_daily_usage(user_id, date.today())
return summary.total_requests
except AppwriteException:
# If there's an error, return 0 to be safe (fail open)
logger.warning(
"Failed to get user request count, returning 0",
user_id=user_id
)
return 0
def _calculate_cost(self, model: str, tokens_input: int, tokens_output: int) -> float:
"""
Calculate the estimated cost for an AI request.
Args:
model: Model identifier
tokens_input: Number of input tokens
tokens_output: Number of output tokens
Returns:
Estimated cost in USD
"""
# Get cost per 1K tokens for this model
model_cost = MODEL_COSTS.get(model, DEFAULT_COST)
# Calculate cost (costs are per 1K tokens)
input_cost = (tokens_input / 1000) * model_cost["input"]
output_cost = (tokens_output / 1000) * model_cost["output"]
total_cost = input_cost + output_cost
return round(total_cost, 6) # Round to 6 decimal places
@staticmethod
def estimate_cost_for_model(model: str, tokens_input: int, tokens_output: int) -> float:
"""
Static method to estimate cost without needing a service instance.
Useful for pre-calculation and UI display.
Args:
model: Model identifier
tokens_input: Number of input tokens
tokens_output: Number of output tokens
Returns:
Estimated cost in USD
"""
model_cost = MODEL_COSTS.get(model, DEFAULT_COST)
input_cost = (tokens_input / 1000) * model_cost["input"]
output_cost = (tokens_output / 1000) * model_cost["output"]
return round(input_cost + output_cost, 6)
@staticmethod
def get_model_cost_info(model: str) -> Dict[str, float]:
"""
Get cost information for a model.
Args:
model: Model identifier
Returns:
Dictionary with 'input' and 'output' cost per 1K tokens
"""
return MODEL_COSTS.get(model, DEFAULT_COST)