first commit
0
api/app/services/__init__.py
Normal file
320
api/app/services/action_prompt_loader.py
Normal file
@@ -0,0 +1,320 @@
"""
Action Prompt Loader Service

This module provides a service for loading and filtering action prompts from YAML.
It implements a singleton pattern to cache loaded prompts in memory.

Usage:
    from app.services.action_prompt_loader import ActionPromptLoader

    loader = ActionPromptLoader()
    loader.load_from_yaml("app/data/action_prompts.yaml")

    # Get available actions for a user at a location
    actions = loader.get_available_actions(
        user_tier=UserTier.FREE,
        location_type=LocationType.TOWN
    )

    # Get specific action
    action = loader.get_action_by_id("ask_locals")
"""

import os
from typing import List, Optional, Dict

import yaml

from app.models.action_prompt import ActionPrompt, LocationType
from app.ai.model_selector import UserTier
from app.utils.logging import get_logger


# Initialize logger
logger = get_logger(__file__)


class ActionPromptLoaderError(Exception):
    """Base exception for action prompt loader errors."""
    pass


class ActionPromptNotFoundError(ActionPromptLoaderError):
    """Raised when a requested action prompt is not found."""
    pass


class ActionPromptLoader:
    """
    Service for loading and filtering action prompts.

    This class loads action prompts from YAML files and provides methods
    to filter them based on user tier and location type.

    Uses a singleton pattern to cache loaded prompts in memory.

    Attributes:
        _prompts: Dictionary of loaded action prompts keyed by prompt_id
        _loaded: Flag indicating if prompts have been loaded
    """

    _instance = None
    _prompts: Dict[str, ActionPrompt] = {}
    _loaded: bool = False

    def __new__(cls):
        """Implement singleton pattern."""
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance._prompts = {}
            cls._instance._loaded = False
        return cls._instance

    def load_from_yaml(self, filepath: str) -> int:
        """
        Load action prompts from a YAML file.

        Args:
            filepath: Path to the YAML file

        Returns:
            Number of prompts loaded

        Raises:
            ActionPromptLoaderError: If file cannot be read or parsed
        """
        if not os.path.exists(filepath):
            logger.error("Action prompts file not found", filepath=filepath)
            raise ActionPromptLoaderError(f"File not found: {filepath}")

        try:
            with open(filepath, 'r', encoding='utf-8') as f:
                data = yaml.safe_load(f)

        except yaml.YAMLError as e:
            logger.error("Failed to parse YAML", filepath=filepath, error=str(e))
            raise ActionPromptLoaderError(f"Invalid YAML in {filepath}: {e}")

        except IOError as e:
            logger.error("Failed to read file", filepath=filepath, error=str(e))
            raise ActionPromptLoaderError(f"Cannot read {filepath}: {e}")

        if not data or 'action_prompts' not in data:
            logger.error("No action_prompts key in YAML", filepath=filepath)
            raise ActionPromptLoaderError(f"Missing 'action_prompts' key in {filepath}")

        # Clear existing prompts
        self._prompts = {}

        # Parse each prompt
        prompts_data = data['action_prompts']
        errors = []

        for i, prompt_data in enumerate(prompts_data):
            try:
                prompt = ActionPrompt.from_dict(prompt_data)
                self._prompts[prompt.prompt_id] = prompt

            except (ValueError, KeyError) as e:
                errors.append(f"Prompt {i}: {e}")
                logger.warning(
                    "Failed to parse action prompt",
                    index=i,
                    error=str(e)
                )

        if errors:
            logger.warning(
                "Some action prompts failed to load",
                error_count=len(errors),
                errors=errors
            )

        self._loaded = True
        loaded_count = len(self._prompts)

        logger.info(
            "Action prompts loaded",
            filepath=filepath,
            count=loaded_count,
            errors=len(errors)
        )

        return loaded_count

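    # A minimal sketch of the YAML shape load_from_yaml() expects: a top-level
    # 'action_prompts' list whose entries are parsed by ActionPrompt.from_dict().
    # The per-entry fields shown here are illustrative assumptions; the real
    # schema is defined by ActionPrompt.from_dict(), not by this loader.
    #
    #   action_prompts:
    #     - prompt_id: ask_locals          # hypothetical entry
    #       category: ask_question         # hypothetical field
    #       # ...remaining fields as required by ActionPrompt.from_dict()
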
    def get_all_actions(self) -> List[ActionPrompt]:
        """
        Get all loaded action prompts.

        Returns:
            List of all action prompts
        """
        self._ensure_loaded()
        return list(self._prompts.values())

    def get_action_by_id(self, prompt_id: str) -> ActionPrompt:
        """
        Get a specific action prompt by ID.

        Args:
            prompt_id: The unique identifier of the action

        Returns:
            The ActionPrompt object

        Raises:
            ActionPromptNotFoundError: If action not found
        """
        self._ensure_loaded()

        if prompt_id not in self._prompts:
            logger.warning("Action prompt not found", prompt_id=prompt_id)
            raise ActionPromptNotFoundError(f"Action prompt '{prompt_id}' not found")

        return self._prompts[prompt_id]

    def get_available_actions(
        self,
        user_tier: UserTier,
        location_type: LocationType
    ) -> List[ActionPrompt]:
        """
        Get actions available to a user at a specific location.

        Args:
            user_tier: The user's subscription tier
            location_type: The current location type

        Returns:
            List of available action prompts
        """
        self._ensure_loaded()

        available = []
        for prompt in self._prompts.values():
            if prompt.is_available(user_tier, location_type):
                available.append(prompt)

        logger.debug(
            "Filtered available actions",
            user_tier=user_tier.value,
            location_type=location_type.value,
            count=len(available)
        )

        return available

    def get_actions_by_tier(self, user_tier: UserTier) -> List[ActionPrompt]:
        """
        Get all actions available to a user tier (ignoring location).

        Args:
            user_tier: The user's subscription tier

        Returns:
            List of action prompts available to the tier
        """
        self._ensure_loaded()

        available = []
        for prompt in self._prompts.values():
            if prompt._tier_meets_requirement(user_tier):
                available.append(prompt)

        return available

    def get_actions_by_category(self, category: str) -> List[ActionPrompt]:
        """
        Get all actions in a specific category.

        Args:
            category: The action category (e.g., "ask_question", "explore")

        Returns:
            List of action prompts in the category
        """
        self._ensure_loaded()

        return [
            prompt for prompt in self._prompts.values()
            if prompt.category.value == category
        ]

    def get_locked_actions(
        self,
        user_tier: UserTier,
        location_type: LocationType
    ) -> List[ActionPrompt]:
        """
        Get actions that are locked due to tier restrictions.

        Used to show locked actions with upgrade prompts in the UI.

        Args:
            user_tier: The user's subscription tier
            location_type: The current location type

        Returns:
            List of locked action prompts
        """
        self._ensure_loaded()

        locked = []
        for prompt in self._prompts.values():
            # Must match location but be tier-locked
            if prompt._location_matches_filter(location_type) and prompt.is_locked(user_tier):
                locked.append(prompt)

        return locked

    def reload(self, filepath: str) -> int:
        """
        Force reload prompts from a YAML file.

        Args:
            filepath: Path to the YAML file

        Returns:
            Number of prompts loaded
        """
        self._loaded = False
        return self.load_from_yaml(filepath)

    def is_loaded(self) -> bool:
        """Check if prompts have been loaded."""
        return self._loaded

    def get_prompt_count(self) -> int:
        """Get the number of loaded prompts."""
        return len(self._prompts)

    def _ensure_loaded(self) -> None:
        """
        Ensure prompts are loaded, auto-loading from the default path if not.

        Raises:
            ActionPromptLoaderError: If prompts cannot be loaded
        """
        if not self._loaded:
            # Try default path
            default_path = os.path.join(
                os.path.dirname(__file__),
                '..', 'data', 'action_prompts.yaml'
            )
            default_path = os.path.normpath(default_path)

            if os.path.exists(default_path):
                self.load_from_yaml(default_path)
            else:
                raise ActionPromptLoaderError(
                    "Action prompts not loaded. Call load_from_yaml() first."
                )

    @classmethod
    def reset_instance(cls) -> None:
        """
        Reset the singleton instance.

        Primarily for testing purposes.
        """
        cls._instance = None
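

# A minimal smoke-test sketch (not part of the service itself). It assumes the
# YAML file exists at app/data/action_prompts.yaml relative to the working
# directory; adjust the path for your environment.
if __name__ == "__main__":
    ActionPromptLoader.reset_instance()
    loader = ActionPromptLoader()
    count = loader.load_from_yaml("app/data/action_prompts.yaml")
    print(f"Loaded {count} prompts: {[p.prompt_id for p in loader.get_all_actions()]}")
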
588
api/app/services/appwrite_service.py
Normal file
@@ -0,0 +1,588 @@
"""
Appwrite Service Wrapper

This module provides a wrapper around the Appwrite SDK for handling user authentication,
session management, and user data operations. It abstracts Appwrite's API to provide
a clean interface for the application.

Usage:
    from app.services.appwrite_service import AppwriteService

    # Initialize service
    service = AppwriteService()

    # Register a new user
    user = service.register_user(
        email="player@example.com",
        password="SecurePass123!",
        name="Brave Adventurer"
    )

    # Login (returns the session and the user's data)
    session, user = service.login_user(
        email="player@example.com",
        password="SecurePass123!"
    )
"""

import os
from typing import Optional, Dict, Any
from dataclasses import dataclass
from datetime import datetime, timezone

from appwrite.client import Client
from appwrite.services.account import Account
from appwrite.services.users import Users
from appwrite.exception import AppwriteException
from appwrite.id import ID

from app.utils.logging import get_logger


# Initialize logger
logger = get_logger(__file__)


@dataclass
class UserData:
    """
    Data class representing a user in the system.

    Attributes:
        id: Unique user identifier
        email: User's email address
        name: User's display name
        email_verified: Whether the email has been verified
        tier: User's subscription tier (free, basic, premium, elite)
        created_at: When the user account was created
        updated_at: When the user account was last updated
    """
    id: str
    email: str
    name: str
    email_verified: bool
    tier: str
    created_at: datetime
    updated_at: datetime

    def to_dict(self) -> Dict[str, Any]:
        """Convert user data to a dictionary."""
        return {
            "id": self.id,
            "email": self.email,
            "name": self.name,
            "email_verified": self.email_verified,
            "tier": self.tier,
            "created_at": self.created_at.isoformat() if isinstance(self.created_at, datetime) else self.created_at,
            "updated_at": self.updated_at.isoformat() if isinstance(self.updated_at, datetime) else self.updated_at,
        }


@dataclass
class SessionData:
    """
    Data class representing a user session.

    Attributes:
        session_id: Unique session identifier
        user_id: User ID associated with this session
        provider: Authentication provider (email, oauth, etc.)
        expire: When the session expires
    """
    session_id: str
    user_id: str
    provider: str
    expire: datetime

    def to_dict(self) -> Dict[str, Any]:
        """Convert session data to a dictionary."""
        return {
            "session_id": self.session_id,
            "user_id": self.user_id,
            "provider": self.provider,
            "expire": self.expire.isoformat() if isinstance(self.expire, datetime) else self.expire,
        }


class AppwriteService:
    """
    Service class for interacting with Appwrite authentication and user management.

    This class provides methods for:
    - User registration and email verification
    - User login and logout
    - Session management
    - Password reset
    - User tier management
    """

    def __init__(self):
        """
        Initialize the Appwrite service.

        Reads configuration from environment variables:
        - APPWRITE_ENDPOINT: Appwrite API endpoint
        - APPWRITE_PROJECT_ID: Appwrite project ID
        - APPWRITE_API_KEY: Appwrite API key (for server-side operations)
        """
        self.endpoint = os.getenv('APPWRITE_ENDPOINT')
        self.project_id = os.getenv('APPWRITE_PROJECT_ID')
        self.api_key = os.getenv('APPWRITE_API_KEY')

        if not all([self.endpoint, self.project_id, self.api_key]):
            logger.error("Missing Appwrite configuration in environment variables")
            raise ValueError("Appwrite configuration incomplete. Check APPWRITE_* environment variables.")

        # Initialize Appwrite client
        self.client = Client()
        self.client.set_endpoint(self.endpoint)
        self.client.set_project(self.project_id)
        self.client.set_key(self.api_key)

        # Initialize services
        self.account = Account(self.client)
        self.users = Users(self.client)

        logger.info("Appwrite service initialized", endpoint=self.endpoint, project_id=self.project_id)

    def register_user(self, email: str, password: str, name: str) -> UserData:
        """
        Register a new user account.

        This method:
        1. Creates a new user in Appwrite Auth
        2. Sets the user's tier to 'free' in preferences
        3. Triggers email verification
        4. Returns user data

        Args:
            email: User's email address
            password: User's password (will be hashed by Appwrite)
            name: User's display name

        Returns:
            UserData object with user information

        Raises:
            AppwriteException: If registration fails (e.g., email already exists)
        """
        try:
            logger.info("Attempting to register new user", email=email, name=name)

            # Generate unique user ID
            user_id = ID.unique()

            # Create user account
            user = self.users.create(
                user_id=user_id,
                email=email,
                password=password,
                name=name
            )

            logger.info("User created successfully", user_id=user['$id'], email=email)

            # Set default tier to 'free' in user preferences
            self.users.update_prefs(
                user_id=user['$id'],
                prefs={
                    'tier': 'free',
                    'tier_updated_at': datetime.now(timezone.utc).isoformat()
                }
            )

            logger.info("User tier set to 'free'", user_id=user['$id'])

            # Note: Email verification is handled by Appwrite automatically
            # when email templates are configured in the Appwrite console.
            # For server-side user creation, verification emails are sent
            # automatically if the email provider is configured.
            #
            # To manually trigger verification, users can use the Account service
            # (client-side) after logging in, or configure email verification
            # settings in the Appwrite console.

            logger.info("User created, email verification handled by Appwrite", user_id=user['$id'], email=email)

            # Return user data
            return self._user_to_userdata(user)

        except AppwriteException as e:
            logger.error("Failed to register user", email=email, error=str(e), code=e.code)
            raise

    def login_user(self, email: str, password: str) -> tuple[SessionData, UserData]:
        """
        Authenticate a user and create a session.

        For server-side authentication, we create a temporary client with user
        credentials to verify them, then create a session using the server SDK.

        Args:
            email: User's email address
            password: User's password

        Returns:
            Tuple of (SessionData, UserData)

        Raises:
            AppwriteException: If login fails (invalid credentials, etc.)
        """
        try:
            logger.info("Attempting user login", email=email)

            # Use the admin client (with API key) to create the session.
            # This is required to get the session secret in the response.
            admin_account = Account(self.client)  # self.client already has the API key set

            # Create an email/password session using the admin client.
            # When using the admin client, the 'secret' field is populated in the response.
            user_session = admin_account.create_email_password_session(
                email=email,
                password=password
            )

            logger.info("Session created successfully",
                        user_id=user_session['userId'],
                        session_id=user_session['$id'])

            # Extract the session secret from the response.
            # The admin client populates this field, unlike the regular client.
            session_secret = user_session.get('secret', '')

            if not session_secret:
                logger.error("Session secret not found in response - this should not happen with admin client")
                raise AppwriteException("Failed to get session secret", code=500)

            # Get user data using the server SDK
            user = self.users.get(user_id=user_session['userId'])

            # Convert to our data classes
            session_data = SessionData(
                session_id=session_secret,  # Use the secret, not the session ID
                user_id=user_session['userId'],
                provider=user_session['provider'],
                expire=datetime.fromisoformat(user_session['expire'].replace('Z', '+00:00'))
            )

            user_data = self._user_to_userdata(user)

            return session_data, user_data

        except AppwriteException as e:
            logger.error("Failed to login user", email=email, error=str(e), code=e.code)
            raise
        except Exception as e:
            logger.error("Unexpected error during login", email=email, error=str(e), exc_info=True)
            raise AppwriteException(str(e), code=500)

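    # Hedged illustration of what callers receive from login_user(). Note that
    # SessionData.session_id carries the session *secret* (suitable for
    # Client.set_session), not the Appwrite session document ID:
    #
    #   session, user = service.login_user(email, password)
    #   session.session_id  # -> session secret string
    #   session.expire      # -> timezone-aware datetime
    #   user.tier           # -> 'free' by default
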
    def logout_user(self, session_id: str) -> bool:
        """
        Log out a user by deleting their session.

        Args:
            session_id: The session ID to delete

        Returns:
            True if logout successful

        Raises:
            AppwriteException: If logout fails
        """
        try:
            logger.info("Attempting to logout user", session_id=session_id)

            # Appwrite has no direct server-side session delete by session_id,
            # so we create a client authenticated with the session itself and
            # delete its current session.
            session_client = Client()
            session_client.set_endpoint(self.endpoint)
            session_client.set_project(self.project_id)
            session_client.set_session(session_id)

            session_account = Account(session_client)

            # Delete the current session
            session_account.delete_session('current')

            logger.info("User logged out successfully", session_id=session_id)
            return True

        except AppwriteException as e:
            logger.error("Failed to logout user", session_id=session_id, error=str(e), code=e.code)
            raise

    def verify_email(self, user_id: str, secret: str) -> bool:
        """
        Verify a user's email address.

        Note: Email verification with the server-side SDK requires updating
        the user's emailVerification status directly, or using Appwrite's
        built-in verification flow through the Account service (client-side).

        Args:
            user_id: User ID
            secret: Verification secret from the email link (not validated server-side)

        Returns:
            True if verification successful

        Raises:
            AppwriteException: If verification fails (invalid/expired secret)
        """
        try:
            logger.info("Attempting to verify email", user_id=user_id, secret_provided=bool(secret))

            # For server-side verification, we update the user's email verification status.
            # The secret validation should be done by Appwrite's verification flow.
            # For now, we mark the email as verified.
            # In production, you should validate the secret token before updating.
            self.users.update_email_verification(user_id=user_id, email_verification=True)

            logger.info("Email verified successfully", user_id=user_id)
            return True

        except AppwriteException as e:
            logger.error("Failed to verify email", user_id=user_id, error=str(e), code=e.code)
            raise

    def request_password_reset(self, email: str) -> bool:
        """
        Request a password reset for a user.

        This sends a password reset email to the user. For security,
        it always returns True even if the email doesn't exist.

        Note: Password reset is handled through Appwrite's built-in Account
        service recovery flow. For server-side operations, we would need to
        create a password recovery token manually.

        Args:
            email: User's email address

        Returns:
            Always True (for security - don't reveal whether the email exists)
        """
        try:
            logger.info("Password reset requested", email=email)

            # Note: Password reset with the server-side SDK requires creating
            # a recovery token. For now, we log this and return success.
            # In production, configure Appwrite's email templates and use
            # client-side Account.createRecovery() or implement custom token
            # generation and email sending.

            logger.warning("Password reset not fully implemented - requires Appwrite email configuration", email=email)

        except Exception as e:
            # Log the error but still return True for security;
            # don't reveal whether the email exists.
            logger.warning("Password reset request encountered error", email=email, error=str(e))

        # Always return True to not reveal whether the email exists
        return True

    def confirm_password_reset(self, user_id: str, secret: str, password: str) -> bool:
        """
        Confirm a password reset and update the user's password.

        Note: For server-side operations, we update the password directly
        using the Users service. Secret validation would be handled separately.

        Args:
            user_id: User ID
            secret: Reset secret from the email link (should be validated before calling)
            password: New password

        Returns:
            True if password reset successful

        Raises:
            AppwriteException: If reset fails
        """
        try:
            logger.info("Attempting to reset password", user_id=user_id, secret_provided=bool(secret))

            # For server-side password reset, update the password directly.
            # In production, you should validate the secret token first before calling this.
            # The secret parameter is kept for API compatibility but not validated here.
            self.users.update_password(user_id=user_id, password=password)

            logger.info("Password reset successfully", user_id=user_id)
            return True

        except AppwriteException as e:
            logger.error("Failed to reset password", user_id=user_id, error=str(e), code=e.code)
            raise

    def get_user(self, user_id: str) -> UserData:
        """
        Get user data by user ID.

        Args:
            user_id: User ID

        Returns:
            UserData object

        Raises:
            AppwriteException: If user not found
        """
        try:
            user = self.users.get(user_id=user_id)

            return self._user_to_userdata(user)

        except AppwriteException as e:
            logger.error("Failed to fetch user", user_id=user_id, error=str(e), code=e.code)
            raise

    def get_session(self, session_id: str) -> SessionData:
        """
        Get session data and validate that it is still active.

        Args:
            session_id: Session ID

        Returns:
            SessionData object

        Raises:
            AppwriteException: If the session is invalid or expired
        """
        try:
            # Create a client with the session to validate it
            session_client = Client()
            session_client.set_endpoint(self.endpoint)
            session_client.set_project(self.project_id)
            session_client.set_session(session_id)

            session_account = Account(session_client)

            # Get the current session (this validates that it exists and is active)
            session = session_account.get_session('current')

            # Check whether the session is expired
            expire_time = datetime.fromisoformat(session['expire'].replace('Z', '+00:00'))
            if expire_time < datetime.now(timezone.utc):
                logger.warning("Session expired", session_id=session_id, expired_at=expire_time)
                raise AppwriteException("Session expired", code=401)

            return SessionData(
                session_id=session['$id'],
                user_id=session['userId'],
                provider=session['provider'],
                expire=expire_time
            )

        except AppwriteException as e:
            logger.error("Failed to validate session", session_id=session_id, error=str(e), code=e.code)
            raise

    def get_user_tier(self, user_id: str) -> str:
        """
        Get the user's subscription tier.

        Args:
            user_id: User ID

        Returns:
            Tier string (free, basic, premium, elite)
        """
        try:
            logger.debug("Fetching user tier", user_id=user_id)

            user = self.users.get(user_id=user_id)
            prefs = user.get('prefs', {})
            tier = prefs.get('tier', 'free')

            logger.debug("User tier retrieved", user_id=user_id, tier=tier)
            return tier

        except AppwriteException as e:
            logger.error("Failed to fetch user tier", user_id=user_id, error=str(e), code=e.code)
            # Default to free tier on error
            return 'free'

    def set_user_tier(self, user_id: str, tier: str) -> bool:
        """
        Update the user's subscription tier.

        Args:
            user_id: User ID
            tier: New tier (free, basic, premium, elite)

        Returns:
            True if update successful

        Raises:
            AppwriteException: If the update fails
            ValueError: If the tier is invalid
        """
        valid_tiers = ['free', 'basic', 'premium', 'elite']
        if tier not in valid_tiers:
            raise ValueError(f"Invalid tier: {tier}. Must be one of {valid_tiers}")

        try:
            logger.info("Updating user tier", user_id=user_id, new_tier=tier)

            # Get current preferences
            user = self.users.get(user_id=user_id)
            prefs = user.get('prefs', {})

            # Update tier
            prefs['tier'] = tier
            prefs['tier_updated_at'] = datetime.now(timezone.utc).isoformat()

            self.users.update_prefs(user_id=user_id, prefs=prefs)

            logger.info("User tier updated successfully", user_id=user_id, tier=tier)
            return True

        except AppwriteException as e:
            logger.error("Failed to update user tier", user_id=user_id, tier=tier, error=str(e), code=e.code)
            raise

    def _user_to_userdata(self, user: Dict[str, Any]) -> UserData:
        """
        Convert an Appwrite user object to a UserData dataclass.

        Args:
            user: Appwrite user dictionary

        Returns:
            UserData object
        """
        # Get tier from preferences, default to 'free'
        prefs = user.get('prefs', {})
        tier = prefs.get('tier', 'free')

        # Parse timestamps
        created_at = user.get('$createdAt', datetime.now(timezone.utc).isoformat())
        updated_at = user.get('$updatedAt', datetime.now(timezone.utc).isoformat())

        if isinstance(created_at, str):
            created_at = datetime.fromisoformat(created_at.replace('Z', '+00:00'))
        if isinstance(updated_at, str):
            updated_at = datetime.fromisoformat(updated_at.replace('Z', '+00:00'))

        return UserData(
            id=user['$id'],
            email=user['email'],
            name=user['name'],
            email_verified=user.get('emailVerification', False),
            tier=tier,
            created_at=created_at,
            updated_at=updated_at
        )
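

# A minimal usage sketch, assuming APPWRITE_ENDPOINT, APPWRITE_PROJECT_ID and
# APPWRITE_API_KEY are set and the email below is not already registered.
if __name__ == "__main__":
    service = AppwriteService()
    service.register_user(
        email="player@example.com",
        password="SecurePass123!",
        name="Brave Adventurer"
    )
    session, user = service.login_user(email="player@example.com", password="SecurePass123!")
    print(f"Logged in {user.name} ({user.tier}); session expires {session.expire.isoformat()}")
    service.logout_user(session.session_id)
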
1049
api/app/services/character_service.py
Normal file
File diff suppressed because it is too large
277
api/app/services/class_loader.py
Normal file
@@ -0,0 +1,277 @@
"""
ClassLoader service for loading player class definitions from YAML files.

This service reads class configuration files and converts them into PlayerClass
dataclass instances, providing caching for performance.
"""

import yaml
from pathlib import Path
from typing import Dict, List, Optional
import structlog

from app.models.skills import PlayerClass, SkillTree, SkillNode
from app.models.stats import Stats

logger = structlog.get_logger(__name__)


class ClassLoader:
    """
    Loads player class definitions from YAML configuration files.

    This allows game designers to define classes and skill trees without touching code.
    All class definitions are stored in /app/data/classes/ as YAML files.
    """

    def __init__(self, data_dir: Optional[str] = None):
        """
        Initialize the class loader.

        Args:
            data_dir: Path to directory containing class YAML files.
                Defaults to /app/data/classes/
        """
        if data_dir is None:
            # Default to app/data/classes relative to this file
            current_file = Path(__file__)
            app_dir = current_file.parent.parent  # Go up to /app
            data_dir = str(app_dir / "data" / "classes")

        self.data_dir = Path(data_dir)
        self._class_cache: Dict[str, PlayerClass] = {}

        logger.info("ClassLoader initialized", data_dir=str(self.data_dir))

    def load_class(self, class_id: str) -> Optional[PlayerClass]:
        """
        Load a single player class by ID.

        Args:
            class_id: Unique class identifier (e.g., "vanguard")

        Returns:
            PlayerClass instance or None if not found
        """
        # Check cache first
        if class_id in self._class_cache:
            logger.debug("Class loaded from cache", class_id=class_id)
            return self._class_cache[class_id]

        # Construct file path
        file_path = self.data_dir / f"{class_id}.yaml"

        if not file_path.exists():
            logger.warning("Class file not found", class_id=class_id, file_path=str(file_path))
            return None

        try:
            # Load YAML file
            with open(file_path, 'r') as f:
                data = yaml.safe_load(f)

            # Parse into PlayerClass
            player_class = self._parse_class_data(data)

            # Cache the result
            self._class_cache[class_id] = player_class

            logger.info("Class loaded successfully", class_id=class_id)
            return player_class

        except Exception as e:
            logger.error("Failed to load class", class_id=class_id, error=str(e))
            return None

    def load_all_classes(self) -> List[PlayerClass]:
        """
        Load all player classes from the data directory.

        Returns:
            List of PlayerClass instances
        """
        classes = []

        # Find all YAML files in the directory
        if not self.data_dir.exists():
            logger.error("Class data directory does not exist", data_dir=str(self.data_dir))
            return classes

        for file_path in self.data_dir.glob("*.yaml"):
            class_id = file_path.stem  # Get filename without extension
            player_class = self.load_class(class_id)
            if player_class:
                classes.append(player_class)

        logger.info("All classes loaded", count=len(classes))
        return classes

    def get_class_by_id(self, class_id: str) -> Optional[PlayerClass]:
        """
        Get a player class by ID (alias for load_class).

        Args:
            class_id: Unique class identifier

        Returns:
            PlayerClass instance or None if not found
        """
        return self.load_class(class_id)

    def get_all_class_ids(self) -> List[str]:
        """
        Get a list of all available class IDs.

        Returns:
            List of class IDs (e.g., ["vanguard", "assassin", "arcanist"])
        """
        if not self.data_dir.exists():
            return []

        return [file_path.stem for file_path in self.data_dir.glob("*.yaml")]

    def reload_class(self, class_id: str) -> Optional[PlayerClass]:
        """
        Force reload a class from disk, bypassing the cache.

        Useful for development/testing when class definitions change.

        Args:
            class_id: Unique class identifier

        Returns:
            PlayerClass instance or None if not found
        """
        # Remove from cache if present
        if class_id in self._class_cache:
            del self._class_cache[class_id]

        return self.load_class(class_id)

    def clear_cache(self):
        """Clear the class cache. Useful for testing."""
        self._class_cache.clear()
        logger.info("Class cache cleared")

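    # Hedged sketch of how tests might exercise the cache-control helpers
    # (the fixture directory name is hypothetical):
    #
    #   loader = ClassLoader(data_dir="tests/fixtures/classes")
    #   first = loader.load_class("vanguard")
    #   same = loader.load_class("vanguard")      # served from _class_cache
    #   fresh = loader.reload_class("vanguard")   # bypasses the cache
    #   loader.clear_cache()
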
    def _parse_class_data(self, data: Dict) -> PlayerClass:
        """
        Parse YAML data into a PlayerClass dataclass.

        Args:
            data: Dictionary loaded from YAML file

        Returns:
            PlayerClass instance

        Raises:
            ValueError: If data is invalid or missing required fields
        """
        # Validate required fields
        required_fields = ["class_id", "name", "description", "base_stats", "skill_trees"]
        for field in required_fields:
            if field not in data:
                raise ValueError(f"Missing required field: {field}")

        # Parse base stats
        base_stats = Stats(**data["base_stats"])

        # Parse skill trees
        skill_trees = []
        for tree_data in data["skill_trees"]:
            skill_tree = self._parse_skill_tree(tree_data)
            skill_trees.append(skill_tree)

        # Get optional fields
        starting_equipment = data.get("starting_equipment", [])
        starting_abilities = data.get("starting_abilities", [])

        # Create PlayerClass instance
        player_class = PlayerClass(
            class_id=data["class_id"],
            name=data["name"],
            description=data["description"],
            base_stats=base_stats,
            skill_trees=skill_trees,
            starting_equipment=starting_equipment,
            starting_abilities=starting_abilities
        )

        return player_class

    def _parse_skill_tree(self, tree_data: Dict) -> SkillTree:
        """
        Parse a skill tree from YAML data.

        Args:
            tree_data: Dictionary containing skill tree data

        Returns:
            SkillTree instance
        """
        # Validate required fields
        required_fields = ["tree_id", "name", "description", "nodes"]
        for field in required_fields:
            if field not in tree_data:
                raise ValueError(f"Missing required field in skill tree: {field}")

        # Parse skill nodes
        nodes = []
        for node_data in tree_data["nodes"]:
            skill_node = self._parse_skill_node(node_data)
            nodes.append(skill_node)

        # Create SkillTree instance
        skill_tree = SkillTree(
            tree_id=tree_data["tree_id"],
            name=tree_data["name"],
            description=tree_data["description"],
            nodes=nodes
        )

        return skill_tree

    def _parse_skill_node(self, node_data: Dict) -> SkillNode:
        """
        Parse a skill node from YAML data.

        Args:
            node_data: Dictionary containing skill node data

        Returns:
            SkillNode instance
        """
        # Validate required fields
        required_fields = ["skill_id", "name", "description", "tier", "effects"]
        for field in required_fields:
            if field not in node_data:
                raise ValueError(f"Missing required field in skill node: {field}")

        # Create SkillNode instance
        skill_node = SkillNode(
            skill_id=node_data["skill_id"],
            name=node_data["name"],
            description=node_data["description"],
            tier=node_data["tier"],
            prerequisites=node_data.get("prerequisites", []),
            effects=node_data.get("effects", {}),
            unlocked=False  # Always start locked
        )

        return skill_node


# Global instance for convenience
_loader_instance: Optional[ClassLoader] = None


def get_class_loader() -> ClassLoader:
    """
    Get the global ClassLoader instance.

    Returns:
        Singleton ClassLoader instance
    """
    global _loader_instance
    if _loader_instance is None:
        _loader_instance = ClassLoader()
    return _loader_instance
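

# A hedged sketch of a class YAML file this loader can parse, derived from the
# required fields validated in _parse_class_data/_parse_skill_tree/_parse_skill_node.
# The concrete stat names under base_stats depend on the Stats dataclass and are
# assumptions here.
#
#   class_id: vanguard
#   name: Vanguard
#   description: A stalwart front-line fighter.
#   base_stats:
#     strength: 10        # field names depend on app.models.stats.Stats
#   starting_equipment: []
#   starting_abilities: []
#   skill_trees:
#     - tree_id: bulwark
#       name: Bulwark
#       description: Defensive techniques.
#       nodes:
#         - skill_id: shield_wall
#           name: Shield Wall
#           description: Brace behind your shield.
#           tier: 1
#           prerequisites: []
#           effects: {}
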
709
api/app/services/database_init.py
Normal file
@@ -0,0 +1,709 @@
"""
Database Initialization Service.

This service handles programmatic creation of Appwrite database tables,
including schema definition, column creation, and index setup.
"""

import os
import time
from typing import List, Dict, Any, Optional

from appwrite.client import Client
from appwrite.services.tables_db import TablesDB
from appwrite.exception import AppwriteException

from app.utils.logging import get_logger

logger = get_logger(__file__)


class DatabaseInitService:
    """
    Service for initializing Appwrite database tables.

    This service provides methods to:
    - Create tables if they don't exist
    - Define table schemas (columns/attributes)
    - Create indexes for efficient querying
    - Validate existing table structures
    """

    def __init__(self):
        """
        Initialize the database initialization service.

        Reads configuration from environment variables:
        - APPWRITE_ENDPOINT: Appwrite API endpoint
        - APPWRITE_PROJECT_ID: Appwrite project ID
        - APPWRITE_API_KEY: Appwrite API key
        - APPWRITE_DATABASE_ID: Appwrite database ID
        """
        self.endpoint = os.getenv('APPWRITE_ENDPOINT')
        self.project_id = os.getenv('APPWRITE_PROJECT_ID')
        self.api_key = os.getenv('APPWRITE_API_KEY')
        self.database_id = os.getenv('APPWRITE_DATABASE_ID', 'main')

        if not all([self.endpoint, self.project_id, self.api_key]):
            logger.error("Missing Appwrite configuration in environment variables")
            raise ValueError("Appwrite configuration incomplete. Check APPWRITE_* environment variables.")

        # Initialize Appwrite client
        self.client = Client()
        self.client.set_endpoint(self.endpoint)
        self.client.set_project(self.project_id)
        self.client.set_key(self.api_key)

        # Initialize TablesDB service
        self.tables_db = TablesDB(self.client)

        logger.info("DatabaseInitService initialized", database_id=self.database_id)

    def init_all_tables(self) -> Dict[str, bool]:
        """
        Initialize all application tables.

        Returns:
            Dictionary mapping table names to success status
        """
        results = {}

        logger.info("Initializing all database tables")

        # Initialize characters table
        try:
            self.init_characters_table()
            results['characters'] = True
            logger.info("Characters table initialized successfully")
        except Exception as e:
            logger.error("Failed to initialize characters table", error=str(e))
            results['characters'] = False

        # Initialize game_sessions table
        try:
            self.init_game_sessions_table()
            results['game_sessions'] = True
            logger.info("Game sessions table initialized successfully")
        except Exception as e:
            logger.error("Failed to initialize game_sessions table", error=str(e))
            results['game_sessions'] = False

        # Initialize ai_usage_logs table
        try:
            self.init_ai_usage_logs_table()
            results['ai_usage_logs'] = True
            logger.info("AI usage logs table initialized successfully")
        except Exception as e:
            logger.error("Failed to initialize ai_usage_logs table", error=str(e))
            results['ai_usage_logs'] = False

        success_count = sum(1 for v in results.values() if v)
        total_count = len(results)

        logger.info("Table initialization complete",
                    success=success_count,
                    total=total_count,
                    results=results)

        return results

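    # Hedged sketch of consuming the per-table results dictionary:
    #
    #   results = service.init_all_tables()
    #   failed = [name for name, ok in results.items() if not ok]
    #   if failed:
    #       raise RuntimeError(f"Table initialization failed for: {failed}")
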
    def init_characters_table(self) -> bool:
        """
        Initialize the characters table.

        Table schema:
        - userId (string, required): Owner's user ID
        - characterData (string, required): JSON-serialized character data
        - is_active (boolean, default=True): Soft delete flag
        - created_at (datetime): Auto-managed creation timestamp
        - updated_at (datetime): Auto-managed update timestamp

        Indexes:
        - userId: For general user queries
        - userId + is_active: Composite index for efficiently fetching active characters

        Returns:
            True if successful

        Raises:
            AppwriteException: If table creation fails
        """
        table_id = 'characters'

        logger.info("Initializing characters table", table_id=table_id)

        try:
            # Check if table already exists
            try:
                self.tables_db.get_table(
                    database_id=self.database_id,
                    table_id=table_id
                )
                logger.info("Characters table already exists", table_id=table_id)
                return True
            except AppwriteException as e:
                if e.code != 404:
                    raise
                logger.info("Characters table does not exist, creating...")

            # Create table
            logger.info("Creating characters table")
            table = self.tables_db.create_table(
                database_id=self.database_id,
                table_id=table_id,
                name='Characters'
            )
            logger.info("Characters table created", table_id=table['$id'])

            # Create columns
            self._create_column(
                table_id=table_id,
                column_id='userId',
                column_type='string',
                size=255,
                required=True
            )

            self._create_column(
                table_id=table_id,
                column_id='characterData',
                column_type='string',
                size=65535,  # Large text field for JSON data
                required=True
            )

            self._create_column(
                table_id=table_id,
                column_id='is_active',
                column_type='boolean',
                required=False,  # Cannot be required if we want a default value
                default=True
            )

            # Note: created_at and updated_at are auto-managed by DatabaseService
            # through the _parse_row method and timestamp updates

            # Wait for columns to fully propagate in Appwrite before creating indexes
            logger.info("Waiting for columns to propagate before creating indexes...")
            time.sleep(2)

            # Create indexes for efficient querying
            # Note: Individual userId index for general user queries
            self._create_index(
                table_id=table_id,
                index_id='idx_userId',
                index_type='key',
                attributes=['userId']
            )

            # Composite index for the most common query pattern:
            # Query.equal('userId', user_id) + Query.equal('is_active', True)
            # This single composite index covers both conditions efficiently
            self._create_index(
                table_id=table_id,
                index_id='idx_userId_is_active',
                index_type='key',
                attributes=['userId', 'is_active']
            )

            logger.info("Characters table initialized successfully", table_id=table_id)
            return True

        except AppwriteException as e:
            logger.error("Failed to initialize characters table",
                         table_id=table_id,
                         error=str(e),
                         code=e.code)
            raise

    def init_game_sessions_table(self) -> bool:
        """
        Initialize the game_sessions table.

        Table schema:
        - userId (string, required): Owner's user ID
        - characterId (string, required): Character ID for this session
        - sessionData (string, required): JSON-serialized session data
        - status (string, required): Session status (active, completed, abandoned)
        - sessionType (string, required): Session type (solo, multiplayer)

        Indexes:
        - userId: For user session queries
        - userId + status: For active session queries
        - characterId: For character session lookups

        Returns:
            True if successful

        Raises:
            AppwriteException: If table creation fails
        """
        table_id = 'game_sessions'

        logger.info("Initializing game_sessions table", table_id=table_id)

        try:
            # Check if table already exists
            try:
                self.tables_db.get_table(
                    database_id=self.database_id,
                    table_id=table_id
                )
                logger.info("Game sessions table already exists", table_id=table_id)
                return True
            except AppwriteException as e:
                if e.code != 404:
                    raise
                logger.info("Game sessions table does not exist, creating...")

            # Create table
            logger.info("Creating game_sessions table")
            table = self.tables_db.create_table(
                database_id=self.database_id,
                table_id=table_id,
                name='Game Sessions'
            )
            logger.info("Game sessions table created", table_id=table['$id'])

            # Create columns
            self._create_column(
                table_id=table_id,
                column_id='userId',
                column_type='string',
                size=255,
                required=True
            )

            self._create_column(
                table_id=table_id,
                column_id='characterId',
                column_type='string',
                size=255,
                required=True
            )

            self._create_column(
                table_id=table_id,
                column_id='sessionData',
                column_type='string',
                size=65535,  # Large text field for JSON data
                required=True
            )

            self._create_column(
                table_id=table_id,
                column_id='status',
                column_type='string',
                size=50,
                required=True
            )

            self._create_column(
                table_id=table_id,
                column_id='sessionType',
                column_type='string',
                size=50,
                required=True
            )

            # Wait for columns to fully propagate
            logger.info("Waiting for columns to propagate before creating indexes...")
            time.sleep(2)

            # Create indexes
            self._create_index(
                table_id=table_id,
                index_id='idx_userId',
                index_type='key',
                attributes=['userId']
            )

            self._create_index(
                table_id=table_id,
                index_id='idx_userId_status',
                index_type='key',
                attributes=['userId', 'status']
            )

            self._create_index(
                table_id=table_id,
                index_id='idx_characterId',
                index_type='key',
                attributes=['characterId']
            )

            logger.info("Game sessions table initialized successfully", table_id=table_id)
            return True

        except AppwriteException as e:
            logger.error("Failed to initialize game_sessions table",
                         table_id=table_id,
                         error=str(e),
                         code=e.code)
            raise

    def init_ai_usage_logs_table(self) -> bool:
        """
        Initialize the ai_usage_logs table for tracking AI API usage and costs.

        Table schema:
        - user_id (string, required): User who made the request
        - timestamp (string, required): ISO timestamp of the request
        - model (string, required): Model identifier
        - tokens_input (integer, required): Input token count
        - tokens_output (integer, required): Output token count
        - tokens_total (integer, required): Total token count
        - estimated_cost (float, required): Estimated cost in USD
        - task_type (string, required): Type of task
        - session_id (string, optional): Game session ID
        - character_id (string, optional): Character ID
        - request_duration_ms (integer): Request duration in milliseconds
        - success (boolean): Whether request succeeded
        - error_message (string, optional): Error message if failed

        Indexes:
        - user_id: For user usage queries
        - timestamp: For date range queries
        - user_id + timestamp: Composite for user date range queries

        Returns:
            True if successful

        Raises:
            AppwriteException: If table creation fails
        """
        table_id = 'ai_usage_logs'

        logger.info("Initializing ai_usage_logs table", table_id=table_id)

        try:
            # Check if table already exists
            try:
                self.tables_db.get_table(
                    database_id=self.database_id,
                    table_id=table_id
                )
                logger.info("AI usage logs table already exists", table_id=table_id)
                return True
            except AppwriteException as e:
                if e.code != 404:
                    raise
                logger.info("AI usage logs table does not exist, creating...")

            # Create table
            logger.info("Creating ai_usage_logs table")
            table = self.tables_db.create_table(
                database_id=self.database_id,
                table_id=table_id,
                name='AI Usage Logs'
            )
            logger.info("AI usage logs table created", table_id=table['$id'])

            # Create columns
            self._create_column(
                table_id=table_id,
                column_id='user_id',
                column_type='string',
                size=255,
                required=True
            )

            self._create_column(
                table_id=table_id,
                column_id='timestamp',
                column_type='string',
                size=50,  # ISO timestamp format
                required=True
            )

            self._create_column(
                table_id=table_id,
                column_id='model',
                column_type='string',
                size=255,
                required=True
            )

            self._create_column(
                table_id=table_id,
                column_id='tokens_input',
                column_type='integer',
                required=True
            )

            self._create_column(
                table_id=table_id,
                column_id='tokens_output',
                column_type='integer',
                required=True
            )

            self._create_column(
                table_id=table_id,
                column_id='tokens_total',
                column_type='integer',
                required=True
            )

            self._create_column(
                table_id=table_id,
                column_id='estimated_cost',
                column_type='float',
                required=True
            )

            self._create_column(
                table_id=table_id,
                column_id='task_type',
                column_type='string',
                size=50,
                required=True
            )

            self._create_column(
                table_id=table_id,
                column_id='session_id',
                column_type='string',
                size=255,
                required=False
            )

            self._create_column(
                table_id=table_id,
                column_id='character_id',
                column_type='string',
                size=255,
                required=False
            )

            self._create_column(
                table_id=table_id,
                column_id='request_duration_ms',
                column_type='integer',
                required=False,
                default=0
            )

            self._create_column(
                table_id=table_id,
                column_id='success',
                column_type='boolean',
                required=False,
                default=True
            )

            self._create_column(
                table_id=table_id,
                column_id='error_message',
                column_type='string',
                size=1000,
                required=False
            )

            # Wait for columns to fully propagate
            logger.info("Waiting for columns to propagate before creating indexes...")
            time.sleep(2)

            # Create indexes
            self._create_index(
                table_id=table_id,
                index_id='idx_user_id',
                index_type='key',
                attributes=['user_id']
            )

            self._create_index(
                table_id=table_id,
                index_id='idx_timestamp',
                index_type='key',
                attributes=['timestamp']
            )

            self._create_index(
                table_id=table_id,
                index_id='idx_user_id_timestamp',
                index_type='key',
                attributes=['user_id', 'timestamp']
            )

            logger.info("AI usage logs table initialized successfully", table_id=table_id)
            return True

        except AppwriteException as e:
            logger.error("Failed to initialize ai_usage_logs table",
                         table_id=table_id,
                         error=str(e),
                         code=e.code)
            raise

def _create_column(
|
||||
self,
|
||||
table_id: str,
|
||||
column_id: str,
|
||||
column_type: str,
|
||||
size: Optional[int] = None,
|
||||
required: bool = False,
|
||||
default: Optional[Any] = None,
|
||||
array: bool = False
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Create a column in a table.
|
||||
|
||||
Args:
|
||||
table_id: Table ID
|
||||
column_id: Column ID
|
||||
column_type: Column type (string, integer, float, boolean, datetime, email, ip, url)
|
||||
size: Column size (for string types)
|
||||
required: Whether column is required
|
||||
default: Default value
|
||||
array: Whether column is an array
|
||||
|
||||
Returns:
|
||||
Column creation response
|
||||
|
||||
Raises:
|
||||
AppwriteException: If column creation fails
|
||||
"""
|
||||
try:
|
||||
logger.info("Creating column",
|
||||
table_id=table_id,
|
||||
column_id=column_id,
|
||||
column_type=column_type)
|
||||
|
||||
# Build column parameters (Appwrite SDK uses 'key' not 'column_id')
|
||||
params = {
|
||||
'database_id': self.database_id,
|
||||
'table_id': table_id,
|
||||
'key': column_id,
|
||||
'required': required,
|
||||
'array': array
|
||||
}
|
||||
|
||||
if size is not None:
|
||||
params['size'] = size
|
||||
|
||||
if default is not None:
|
||||
params['default'] = default
|
||||
|
||||
# Create column using the appropriate method based on type
|
||||
if column_type == 'string':
|
||||
result = self.tables_db.create_string_column(**params)
|
||||
elif column_type == 'integer':
|
||||
result = self.tables_db.create_integer_column(**params)
|
||||
elif column_type == 'float':
|
||||
result = self.tables_db.create_float_column(**params)
|
||||
elif column_type == 'boolean':
|
||||
result = self.tables_db.create_boolean_column(**params)
|
||||
elif column_type == 'datetime':
|
||||
result = self.tables_db.create_datetime_column(**params)
|
||||
elif column_type == 'email':
|
||||
result = self.tables_db.create_email_column(**params)
|
||||
else:
|
||||
raise ValueError(f"Unsupported column type: {column_type}")
|
||||
|
||||
logger.info("Column created successfully",
|
||||
table_id=table_id,
|
||||
column_id=column_id)
|
||||
|
||||
return result
|
||||
|
||||
except AppwriteException as e:
|
||||
# If column already exists, log warning but don't fail
|
||||
if e.code == 409: # Conflict - column already exists
|
||||
logger.warning("Column already exists",
|
||||
table_id=table_id,
|
||||
column_id=column_id)
|
||||
return {}
|
||||
logger.error("Failed to create column",
|
||||
table_id=table_id,
|
||||
column_id=column_id,
|
||||
error=str(e),
|
||||
code=e.code)
|
||||
raise
|
||||
|
||||
def _create_index(
|
||||
self,
|
||||
table_id: str,
|
||||
index_id: str,
|
||||
index_type: str,
|
||||
attributes: List[str],
|
||||
orders: Optional[List[str]] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Create an index on a table.
|
||||
|
||||
Args:
|
||||
table_id: Table ID
|
||||
index_id: Index ID
|
||||
index_type: Index type (key, fulltext, unique)
|
||||
attributes: List of column IDs to index
|
||||
orders: List of sort orders (ASC, DESC) for each attribute
|
||||
|
||||
Returns:
|
||||
Index creation response
|
||||
|
||||
Raises:
|
||||
AppwriteException: If index creation fails
|
||||
"""
|
||||
try:
|
||||
logger.info("Creating index",
|
||||
table_id=table_id,
|
||||
index_id=index_id,
|
||||
attributes=attributes)
|
||||
|
||||
result = self.tables_db.create_index(
|
||||
database_id=self.database_id,
|
||||
table_id=table_id,
|
||||
key=index_id,
|
||||
type=index_type,
|
||||
columns=attributes, # SDK uses 'columns', not 'attributes'
|
||||
orders=orders or ['ASC'] * len(attributes)
|
||||
)
|
||||
|
||||
logger.info("Index created successfully",
|
||||
table_id=table_id,
|
||||
index_id=index_id)
|
||||
|
||||
return result
|
||||
|
||||
except AppwriteException as e:
|
||||
# If index already exists, log warning but don't fail
|
||||
if e.code == 409: # Conflict - index already exists
|
||||
logger.warning("Index already exists",
|
||||
table_id=table_id,
|
||||
index_id=index_id)
|
||||
return {}
|
||||
logger.error("Failed to create index",
|
||||
table_id=table_id,
|
||||
index_id=index_id,
|
||||
error=str(e),
|
||||
code=e.code)
|
||||
raise
|
||||
|
||||
|
||||
# Global instance for convenience
|
||||
_init_service_instance: Optional[DatabaseInitService] = None
|
||||
|
||||
|
||||
def get_database_init_service() -> DatabaseInitService:
|
||||
"""
|
||||
Get the global DatabaseInitService instance.
|
||||
|
||||
Returns:
|
||||
Singleton DatabaseInitService instance
|
||||
"""
|
||||
global _init_service_instance
|
||||
if _init_service_instance is None:
|
||||
_init_service_instance = DatabaseInitService()
|
||||
return _init_service_instance
|
||||
|
||||
|
||||
def init_database() -> Dict[str, bool]:
|
||||
"""
|
||||
Convenience function to initialize all database tables.
|
||||
|
||||
Returns:
|
||||
Dictionary mapping table names to success status
|
||||
"""
|
||||
service = get_database_init_service()
|
||||
return service.init_all_tables()
|
||||
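
# Illustrative usage sketch (not part of this module): how a deployment or
# startup script might run the initializer. init_all_tables() is the service
# method referenced above; the printed labels are made up for the example.
#
#     from app.services.database_init_service import init_database
#
#     results = init_database()
#     for table_name, ok in results.items():
#         print(f"{table_name}: {'ready' if ok else 'failed'}")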
441
api/app/services/database_service.py
Normal file
@@ -0,0 +1,441 @@
"""
Database Service for Appwrite database operations.

This service wraps the Appwrite Databases SDK to provide a clean interface
for CRUD operations on collections. It handles JSON serialization, error handling,
and provides structured logging.
"""

import os
from typing import Dict, Any, List, Optional
from datetime import datetime, timezone
from dataclasses import dataclass

from appwrite.client import Client
from appwrite.services.tables_db import TablesDB
from appwrite.exception import AppwriteException
from appwrite.id import ID
from appwrite.query import Query

from app.utils.logging import get_logger

logger = get_logger(__file__)


@dataclass
class DatabaseRow:
    """
    Represents a row in an Appwrite table.

    Attributes:
        id: Row ID
        table_id: Table ID
        data: Row data (parsed from JSON)
        created_at: Creation timestamp
        updated_at: Last update timestamp
    """
    id: str
    table_id: str
    data: Dict[str, Any]
    created_at: datetime
    updated_at: datetime

    def to_dict(self) -> Dict[str, Any]:
        """Convert row to dictionary."""
        return {
            "id": self.id,
            "table_id": self.table_id,
            "data": self.data,
            "created_at": self.created_at.isoformat() if isinstance(self.created_at, datetime) else self.created_at,
            "updated_at": self.updated_at.isoformat() if isinstance(self.updated_at, datetime) else self.updated_at,
        }


class DatabaseService:
    """
    Service for interacting with Appwrite database tables.

    This service provides methods for:
    - Creating rows
    - Reading rows by ID or query
    - Updating rows
    - Deleting rows
    - Querying with filters
    """

    def __init__(self):
        """
        Initialize the database service.

        Reads configuration from environment variables:
        - APPWRITE_ENDPOINT: Appwrite API endpoint
        - APPWRITE_PROJECT_ID: Appwrite project ID
        - APPWRITE_API_KEY: Appwrite API key
        - APPWRITE_DATABASE_ID: Appwrite database ID
        """
        self.endpoint = os.getenv('APPWRITE_ENDPOINT')
        self.project_id = os.getenv('APPWRITE_PROJECT_ID')
        self.api_key = os.getenv('APPWRITE_API_KEY')
        self.database_id = os.getenv('APPWRITE_DATABASE_ID', 'main')

        if not all([self.endpoint, self.project_id, self.api_key]):
            logger.error("Missing Appwrite configuration in environment variables")
            raise ValueError("Appwrite configuration incomplete. Check APPWRITE_* environment variables.")

        # Initialize Appwrite client
        self.client = Client()
        self.client.set_endpoint(self.endpoint)
        self.client.set_project(self.project_id)
        self.client.set_key(self.api_key)

        # Initialize TablesDB service
        self.tables_db = TablesDB(self.client)

        logger.info("DatabaseService initialized", database_id=self.database_id)

    def create_row(
        self,
        table_id: str,
        data: Dict[str, Any],
        row_id: Optional[str] = None,
        permissions: Optional[List[str]] = None
    ) -> DatabaseRow:
        """
        Create a new row in a table.

        Args:
            table_id: Table ID (e.g., "characters")
            data: Row data (will be JSON-serialized if needed)
            row_id: Optional custom row ID (auto-generated if None)
            permissions: Optional permissions array

        Returns:
            DatabaseRow with the created row

        Raises:
            AppwriteException: If creation fails
        """
        try:
            logger.info("Creating row", table_id=table_id, has_custom_id=bool(row_id))

            # Generate ID if not provided
            if row_id is None:
                row_id = ID.unique()

            # Create row (Appwrite manages timestamps automatically via $createdAt/$updatedAt)
            result = self.tables_db.create_row(
                database_id=self.database_id,
                table_id=table_id,
                row_id=row_id,
                data=data,
                permissions=permissions or []
            )

            logger.info("Row created successfully",
                        table_id=table_id,
                        row_id=result['$id'])

            return self._parse_row(result, table_id)

        except AppwriteException as e:
            logger.error("Failed to create row",
                         table_id=table_id,
                         error=str(e),
                         code=e.code)
            raise

    def get_row(self, table_id: str, row_id: str) -> Optional[DatabaseRow]:
        """
        Get a row by ID.

        Args:
            table_id: Table ID
            row_id: Row ID

        Returns:
            DatabaseRow or None if not found

        Raises:
            AppwriteException: If retrieval fails (except 404)
        """
        try:
            logger.debug("Fetching row", table_id=table_id, row_id=row_id)

            result = self.tables_db.get_row(
                database_id=self.database_id,
                table_id=table_id,
                row_id=row_id
            )

            return self._parse_row(result, table_id)

        except AppwriteException as e:
            if e.code == 404:
                logger.warning("Row not found",
                               table_id=table_id,
                               row_id=row_id)
                return None
            logger.error("Failed to fetch row",
                         table_id=table_id,
                         row_id=row_id,
                         error=str(e),
                         code=e.code)
            raise

    def update_row(
        self,
        table_id: str,
        row_id: str,
        data: Dict[str, Any],
        permissions: Optional[List[str]] = None
    ) -> DatabaseRow:
        """
        Update an existing row.

        Args:
            table_id: Table ID
            row_id: Row ID
            data: New row data (partial updates supported)
            permissions: Optional permissions array

        Returns:
            DatabaseRow with the updated row

        Raises:
            AppwriteException: If the update fails
        """
        try:
            logger.info("Updating row", table_id=table_id, row_id=row_id)

            # Update row (Appwrite manages timestamps automatically via $updatedAt)
            result = self.tables_db.update_row(
                database_id=self.database_id,
                table_id=table_id,
                row_id=row_id,
                data=data,
                permissions=permissions
            )

            logger.info("Row updated successfully",
                        table_id=table_id,
                        row_id=row_id)

            return self._parse_row(result, table_id)

        except AppwriteException as e:
            logger.error("Failed to update row",
                         table_id=table_id,
                         row_id=row_id,
                         error=str(e),
                         code=e.code)
            raise

    def delete_row(self, table_id: str, row_id: str) -> bool:
        """
        Delete a row.

        Args:
            table_id: Table ID
            row_id: Row ID

        Returns:
            True if deletion succeeded

        Raises:
            AppwriteException: If deletion fails
        """
        try:
            logger.info("Deleting row", table_id=table_id, row_id=row_id)

            self.tables_db.delete_row(
                database_id=self.database_id,
                table_id=table_id,
                row_id=row_id
            )

            logger.info("Row deleted successfully",
                        table_id=table_id,
                        row_id=row_id)
            return True

        except AppwriteException as e:
            logger.error("Failed to delete row",
                         table_id=table_id,
                         row_id=row_id,
                         error=str(e),
                         code=e.code)
            raise

    def list_rows(
        self,
        table_id: str,
        queries: Optional[List[str]] = None,
        limit: int = 25,
        offset: int = 0
    ) -> List[DatabaseRow]:
        """
        List rows in a table with optional filtering.

        Args:
            table_id: Table ID
            queries: Optional Appwrite query filters
            limit: Maximum rows to return (default 25, max 100)
            offset: Number of rows to skip

        Returns:
            List of DatabaseRow instances

        Raises:
            AppwriteException: If the query fails
        """
        try:
            logger.debug("Listing rows",
                         table_id=table_id,
                         has_queries=bool(queries),
                         limit=limit,
                         offset=offset)

            # Pagination is expressed as Appwrite Query helpers; list_rows
            # takes them in the queries list rather than as direct parameters
            all_queries = list(queries or [])
            all_queries.append(Query.limit(limit))
            all_queries.append(Query.offset(offset))

            result = self.tables_db.list_rows(
                database_id=self.database_id,
                table_id=table_id,
                queries=all_queries
            )

            rows = [self._parse_row(row, table_id) for row in result['rows']]

            logger.debug("Rows listed successfully",
                         table_id=table_id,
                         count=len(rows),
                         total=result.get('total', len(rows)))

            return rows

        except AppwriteException as e:
            logger.error("Failed to list rows",
                         table_id=table_id,
                         error=str(e),
                         code=e.code)
            raise

    def count_rows(self, table_id: str, queries: Optional[List[str]] = None) -> int:
        """
        Count rows in a table with optional filtering.

        Args:
            table_id: Table ID
            queries: Optional Appwrite query filters

        Returns:
            Row count

        Raises:
            AppwriteException: If the query fails
        """
        try:
            logger.debug("Counting rows", table_id=table_id, has_queries=bool(queries))

            result = self.tables_db.list_rows(
                database_id=self.database_id,
                table_id=table_id,
                queries=queries or []
            )

            count = result.get('total', len(result.get('rows', [])))
            logger.debug("Rows counted", table_id=table_id, count=count)
            return count

        except AppwriteException as e:
            logger.error("Failed to count rows",
                         table_id=table_id,
                         error=str(e),
                         code=e.code)
            raise

    def _parse_row(self, row: Dict[str, Any], table_id: str) -> DatabaseRow:
        """
        Parse an Appwrite row into a DatabaseRow.

        Args:
            row: Appwrite row dictionary
            table_id: Table ID

        Returns:
            DatabaseRow instance
        """
        # Extract metadata
        row_id = row['$id']
        created_at = row.get('$createdAt', datetime.now(timezone.utc).isoformat())
        updated_at = row.get('$updatedAt', datetime.now(timezone.utc).isoformat())

        # Parse timestamps
        if isinstance(created_at, str):
            created_at = datetime.fromisoformat(created_at.replace('Z', '+00:00'))
        if isinstance(updated_at, str):
            updated_at = datetime.fromisoformat(updated_at.replace('Z', '+00:00'))

        # Remove Appwrite metadata from data
        data = {k: v for k, v in row.items() if not k.startswith('$')}

        return DatabaseRow(
            id=row_id,
            table_id=table_id,
            data=data,
            created_at=created_at,
            updated_at=updated_at
        )

    # Backward compatibility aliases (deprecated, use new methods)
    def create_document(self, collection_id: str, data: Dict[str, Any],
                        document_id: Optional[str] = None,
                        permissions: Optional[List[str]] = None) -> DatabaseRow:
        """Deprecated: Use create_row() instead."""
        logger.warning("create_document() is deprecated, use create_row() instead")
        return self.create_row(collection_id, data, document_id, permissions)

    def get_document(self, collection_id: str, document_id: str) -> Optional[DatabaseRow]:
        """Deprecated: Use get_row() instead."""
        logger.warning("get_document() is deprecated, use get_row() instead")
        return self.get_row(collection_id, document_id)

    def update_document(self, collection_id: str, document_id: str,
                        data: Dict[str, Any],
                        permissions: Optional[List[str]] = None) -> DatabaseRow:
        """Deprecated: Use update_row() instead."""
        logger.warning("update_document() is deprecated, use update_row() instead")
        return self.update_row(collection_id, document_id, data, permissions)

    def delete_document(self, collection_id: str, document_id: str) -> bool:
        """Deprecated: Use delete_row() instead."""
        logger.warning("delete_document() is deprecated, use delete_row() instead")
        return self.delete_row(collection_id, document_id)

    def list_documents(self, collection_id: str, queries: Optional[List[str]] = None,
                       limit: int = 25, offset: int = 0) -> List[DatabaseRow]:
        """Deprecated: Use list_rows() instead."""
        logger.warning("list_documents() is deprecated, use list_rows() instead")
        return self.list_rows(collection_id, queries, limit, offset)

    def count_documents(self, collection_id: str, queries: Optional[List[str]] = None) -> int:
        """Deprecated: Use count_rows() instead."""
        logger.warning("count_documents() is deprecated, use count_rows() instead")
        return self.count_rows(collection_id, queries)


# Backward compatibility alias
DatabaseDocument = DatabaseRow


# Global instance for convenience
_service_instance: Optional[DatabaseService] = None


def get_database_service() -> DatabaseService:
    """
    Get the global DatabaseService instance.

    Returns:
        Singleton DatabaseService instance
    """
    global _service_instance
    if _service_instance is None:
        _service_instance = DatabaseService()
    return _service_instance
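
# Illustrative usage sketch (assumptions: the APPWRITE_* environment variables
# are set and a "characters" table already exists; the field values are made
# up for the example):
#
#     from app.services.database_service import get_database_service
#
#     db = get_database_service()
#     row = db.create_row("characters", {"name": "Aria", "level": 1})
#     same_row = db.get_row("characters", row.id)
#     db.update_row("characters", row.id, {"level": 2})
#     print(db.count_rows("characters"))
#     db.delete_row("characters", row.id)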
351
api/app/services/item_validator.py
Normal file
@@ -0,0 +1,351 @@
"""
Item validation service for AI-granted items.

This module validates and resolves items that the AI grants to players during
gameplay, ensuring they meet character requirements and game balance rules.
"""

import uuid
from pathlib import Path
from typing import Optional

import structlog
import yaml

from app.models.items import Item
from app.models.enums import ItemType
from app.models.character import Character
from app.ai.response_parser import ItemGrant

logger = structlog.get_logger(__name__)


class ItemValidationError(Exception):
    """
    Exception raised when an item fails validation.

    Attributes:
        message: Human-readable error message
        item_grant: The ItemGrant that failed validation
        reason: Machine-readable reason code
    """

    def __init__(self, message: str, item_grant: ItemGrant, reason: str):
        super().__init__(message)
        self.message = message
        self.item_grant = item_grant
        self.reason = reason


class ItemValidator:
    """
    Validates and resolves items granted by the AI.

    This service:
    1. Resolves item references (by ID or creates generic items)
    2. Validates items against character requirements
    3. Logs validation failures for review
    """

    # Map of generic item type strings to ItemType enums
    TYPE_MAP = {
        "weapon": ItemType.WEAPON,
        "armor": ItemType.ARMOR,
        "consumable": ItemType.CONSUMABLE,
        "quest_item": ItemType.QUEST_ITEM,
    }

    def __init__(self, data_path: Optional[Path] = None):
        """
        Initialize the item validator.

        Args:
            data_path: Path to game data directory. Defaults to app/data/
        """
        if data_path is None:
            # Default to api/app/data/
            data_path = Path(__file__).parent.parent / "data"

        self.data_path = data_path
        self._item_registry: dict[str, dict] = {}
        self._generic_templates: dict[str, dict] = {}
        self._load_data()

        logger.info(
            "ItemValidator initialized",
            items_loaded=len(self._item_registry),
            generic_templates_loaded=len(self._generic_templates)
        )

    def _load_data(self) -> None:
        """Load item data from YAML files."""
        # Load main item registry if it exists
        items_file = self.data_path / "items.yaml"
        if items_file.exists():
            with open(items_file) as f:
                data = yaml.safe_load(f) or {}
                self._item_registry = data.get("items", {})

        # Load generic item templates
        generic_file = self.data_path / "generic_items.yaml"
        if generic_file.exists():
            with open(generic_file) as f:
                data = yaml.safe_load(f) or {}
                self._generic_templates = data.get("templates", {})

    def resolve_item(self, item_grant: ItemGrant) -> Item:
        """
        Resolve an ItemGrant to an actual Item instance.

        For existing items (by item_id), looks up from the item registry.
        For generic items (by name/type), creates a new Item.

        Args:
            item_grant: The ItemGrant from the AI response

        Returns:
            Resolved Item instance

        Raises:
            ItemValidationError: If the item cannot be resolved
        """
        if item_grant.is_existing_item():
            return self._resolve_existing_item(item_grant)
        elif item_grant.is_generic_item():
            return self._create_generic_item(item_grant)
        else:
            raise ItemValidationError(
                "ItemGrant has neither item_id nor name",
                item_grant,
                "INVALID_ITEM_GRANT"
            )

    def _resolve_existing_item(self, item_grant: ItemGrant) -> Item:
        """
        Look up an existing item by ID.

        Args:
            item_grant: ItemGrant with item_id set

        Returns:
            Item instance from the registry

        Raises:
            ItemValidationError: If the item is not found
        """
        item_id = item_grant.item_id

        if item_id not in self._item_registry:
            logger.warning(
                "Item not found in registry",
                item_id=item_id
            )
            raise ItemValidationError(
                f"Unknown item_id: {item_id}",
                item_grant,
                "ITEM_NOT_FOUND"
            )

        item_data = self._item_registry[item_id]

        # Convert to Item instance
        return Item.from_dict({
            "item_id": item_id,
            **item_data
        })

    def _create_generic_item(self, item_grant: ItemGrant) -> Item:
        """
        Create a generic item from AI-provided details.

        Generic items are simple items with no special stats,
        suitable for mundane objects like torches, food, etc.

        Args:
            item_grant: ItemGrant with name, type, description

        Returns:
            New Item instance

        Raises:
            ItemValidationError: If the item type is invalid
        """
        # Validate item type
        item_type_str = (item_grant.item_type or "consumable").lower()
        if item_type_str not in self.TYPE_MAP:
            logger.warning(
                "Invalid item type from AI",
                item_type=item_type_str,
                item_name=item_grant.name
            )
            # Default to consumable for unknown types
            item_type_str = "consumable"

        item_type = self.TYPE_MAP[item_type_str]

        # Generate unique ID for this item instance
        item_id = f"generic_{uuid.uuid4().hex[:8]}"

        # Check if we have a template for this item name
        template = self._find_template(item_grant.name or "")

        if template:
            # Use template values as defaults (explicit None checks so a
            # granted value of 0 is not overridden by the template)
            return Item(
                item_id=item_id,
                name=item_grant.name or template.get("name", "Unknown Item"),
                item_type=item_type,
                description=item_grant.description or template.get("description", ""),
                value=item_grant.value if item_grant.value is not None else template.get("value", 0),
                is_tradeable=template.get("is_tradeable", True),
                required_level=template.get("required_level", 1),
            )
        else:
            # Create with provided values only, defaulting a missing value to 0
            return Item(
                item_id=item_id,
                name=item_grant.name or "Unknown Item",
                item_type=item_type,
                description=item_grant.description or "A simple item.",
                value=item_grant.value if item_grant.value is not None else 0,
                is_tradeable=True,
                required_level=1,
            )

    def _find_template(self, item_name: str) -> Optional[dict]:
        """
        Find a generic item template by name.

        Uses case-insensitive partial matching.

        Args:
            item_name: Name of the item to find

        Returns:
            Template dict or None if not found
        """
        name_lower = item_name.lower()

        # Exact match first
        if name_lower in self._generic_templates:
            return self._generic_templates[name_lower]

        # Partial match
        for template_name, template in self._generic_templates.items():
            if template_name in name_lower or name_lower in template_name:
                return template

        return None

    def validate_item_for_character(
        self,
        item: Item,
        character: Character
    ) -> tuple[bool, Optional[str]]:
        """
        Validate that a character can receive an item.

        Checks:
        - Level requirements
        - Class restrictions

        Args:
            item: The Item to validate
            character: The Character to receive the item

        Returns:
            Tuple of (is_valid, error_message)
        """
        # Check level requirement
        if item.required_level > character.level:
            error_msg = (
                f"Item '{item.name}' requires level {item.required_level}, "
                f"but character is level {character.level}"
            )
            logger.warning(
                "Item validation failed: level requirement",
                item_name=item.name,
                required_level=item.required_level,
                character_level=character.level,
                character_name=character.name
            )
            return False, error_msg

        # Check class restriction
        if item.required_class:
            character_class = character.player_class.class_id
            if item.required_class.lower() != character_class.lower():
                error_msg = (
                    f"Item '{item.name}' requires class {item.required_class}, "
                    f"but character is {character_class}"
                )
                logger.warning(
                    "Item validation failed: class restriction",
                    item_name=item.name,
                    required_class=item.required_class,
                    character_class=character_class,
                    character_name=character.name
                )
                return False, error_msg

        return True, None

    def validate_and_resolve_item(
        self,
        item_grant: ItemGrant,
        character: Character
    ) -> tuple[Optional[Item], Optional[str]]:
        """
        Resolve an item grant and validate it for a character.

        This is the main entry point for processing AI-granted items.

        Args:
            item_grant: The ItemGrant from the AI response
            character: The Character to receive the item

        Returns:
            Tuple of (Item if valid else None, error_message if invalid else None)
        """
        try:
            # Resolve the item
            item = self.resolve_item(item_grant)

            # Validate for the character
            is_valid, error_msg = self.validate_item_for_character(item, character)

            if not is_valid:
                return None, error_msg

            logger.info(
                "Item validated successfully",
                item_name=item.name,
                item_id=item.item_id,
                character_name=character.name
            )
            return item, None

        except ItemValidationError as e:
            logger.warning(
                "Item resolution failed",
                error=e.message,
                reason=e.reason
            )
            return None, e.message


# Global instance for convenience
_validator_instance: Optional[ItemValidator] = None


def get_item_validator() -> ItemValidator:
    """
    Get or create the global ItemValidator instance.

    Returns:
        ItemValidator singleton instance
    """
    global _validator_instance
    if _validator_instance is None:
        _validator_instance = ItemValidator()
    return _validator_instance
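
# Illustrative usage sketch (assumes an ItemGrant already produced by the
# response parser and a loaded Character; both objects here are hypothetical):
#
#     from app.services.item_validator import get_item_validator
#
#     validator = get_item_validator()
#     item, error = validator.validate_and_resolve_item(item_grant, character)
#     if item is None:
#         # Surface the rejection to the narration layer instead of granting
#         print(f"Item rejected: {error}")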
326
api/app/services/location_loader.py
Normal file
@@ -0,0 +1,326 @@
"""
LocationLoader service for loading location definitions from YAML files.

This service reads location configuration files and converts them into Location
dataclass instances, providing caching for performance. Locations are organized
by region subdirectories.
"""

import yaml
from pathlib import Path
from typing import Dict, List, Optional
import structlog

from app.models.location import Location, Region
from app.models.enums import LocationType

logger = structlog.get_logger(__name__)


class LocationLoader:
    """
    Loads location definitions from YAML configuration files.

    Locations are organized in region subdirectories:
        /app/data/locations/
            regions/
                crossville.yaml
            crossville/
                crossville_village.yaml
                crossville_tavern.yaml

    This allows game designers to define world locations without touching code.
    """

    def __init__(self, data_dir: Optional[str] = None):
        """
        Initialize the location loader.

        Args:
            data_dir: Path to directory containing location YAML files.
                      Defaults to /app/data/locations/
        """
        if data_dir is None:
            # Default to app/data/locations relative to this file
            current_file = Path(__file__)
            app_dir = current_file.parent.parent  # Go up to /app
            data_dir = str(app_dir / "data" / "locations")

        self.data_dir = Path(data_dir)
        self._location_cache: Dict[str, Location] = {}
        self._region_cache: Dict[str, Region] = {}

        logger.info("LocationLoader initialized", data_dir=str(self.data_dir))

    def load_location(self, location_id: str) -> Optional[Location]:
        """
        Load a single location by ID.

        Searches all region subdirectories for the location file.

        Args:
            location_id: Unique location identifier (e.g., "crossville_tavern")

        Returns:
            Location instance or None if not found
        """
        # Check cache first
        if location_id in self._location_cache:
            logger.debug("Location loaded from cache", location_id=location_id)
            return self._location_cache[location_id]

        # Search in region subdirectories
        if not self.data_dir.exists():
            logger.error("Location data directory does not exist", data_dir=str(self.data_dir))
            return None

        for region_dir in self.data_dir.iterdir():
            # Skip non-directories and the regions folder
            if not region_dir.is_dir() or region_dir.name == "regions":
                continue

            file_path = region_dir / f"{location_id}.yaml"
            if file_path.exists():
                return self._load_location_file(file_path)

        logger.warning("Location not found", location_id=location_id)
        return None

    def _load_location_file(self, file_path: Path) -> Optional[Location]:
        """
        Load a location from a specific file.

        Args:
            file_path: Path to the YAML file

        Returns:
            Location instance or None if loading fails
        """
        try:
            with open(file_path, 'r') as f:
                data = yaml.safe_load(f)

            location = self._parse_location_data(data)
            self._location_cache[location.location_id] = location

            logger.info("Location loaded successfully", location_id=location.location_id)
            return location

        except Exception as e:
            logger.error("Failed to load location", file=str(file_path), error=str(e))
            return None

    def _parse_location_data(self, data: Dict) -> Location:
        """
        Parse YAML data into a Location dataclass.

        Args:
            data: Dictionary loaded from YAML file

        Returns:
            Location instance

        Raises:
            ValueError: If data is invalid or missing required fields
        """
        # Validate required fields
        required_fields = ["location_id", "name", "region_id", "description"]
        for field in required_fields:
            if field not in data:
                raise ValueError(f"Missing required field: {field}")

        # Parse location type - default to town
        location_type_str = data.get("location_type", "town")
        try:
            location_type = LocationType(location_type_str)
        except ValueError:
            logger.warning(
                "Invalid location type, defaulting to town",
                location_id=data["location_id"],
                invalid_type=location_type_str
            )
            location_type = LocationType.TOWN

        return Location(
            location_id=data["location_id"],
            name=data["name"],
            location_type=location_type,
            region_id=data["region_id"],
            description=data["description"],
            lore=data.get("lore"),
            ambient_description=data.get("ambient_description"),
            available_quests=data.get("available_quests", []),
            npc_ids=data.get("npc_ids", []),
            discoverable_locations=data.get("discoverable_locations", []),
            is_starting_location=data.get("is_starting_location", False),
            tags=data.get("tags", []),
        )

    def load_all_locations(self) -> List[Location]:
        """
        Load all locations from all region directories.

        Returns:
            List of Location instances
        """
        locations = []

        if not self.data_dir.exists():
            logger.error("Location data directory does not exist", data_dir=str(self.data_dir))
            return locations

        for region_dir in self.data_dir.iterdir():
            # Skip non-directories and the regions folder
            if not region_dir.is_dir() or region_dir.name == "regions":
                continue

            for file_path in region_dir.glob("*.yaml"):
                location = self._load_location_file(file_path)
                if location:
                    locations.append(location)

        logger.info("All locations loaded", count=len(locations))
        return locations

    def load_region(self, region_id: str) -> Optional[Region]:
        """
        Load a region definition.

        Args:
            region_id: Unique region identifier (e.g., "crossville")

        Returns:
            Region instance or None if not found
        """
        # Check cache first
        if region_id in self._region_cache:
            logger.debug("Region loaded from cache", region_id=region_id)
            return self._region_cache[region_id]

        file_path = self.data_dir / "regions" / f"{region_id}.yaml"

        if not file_path.exists():
            logger.warning("Region file not found", region_id=region_id)
            return None

        try:
            with open(file_path, 'r') as f:
                data = yaml.safe_load(f)

            region = Region.from_dict(data)
            self._region_cache[region_id] = region

            logger.info("Region loaded successfully", region_id=region_id)
            return region

        except Exception as e:
            logger.error("Failed to load region", region_id=region_id, error=str(e))
            return None

    def get_locations_in_region(self, region_id: str) -> List[Location]:
        """
        Get all locations belonging to a specific region.

        Args:
            region_id: Region identifier

        Returns:
            List of Location instances in this region
        """
        # Load all locations if cache is empty
        if not self._location_cache:
            self.load_all_locations()

        return [
            loc for loc in self._location_cache.values()
            if loc.region_id == region_id
        ]

    def get_starting_locations(self) -> List[Location]:
        """
        Get all locations that can be starting points.

        Returns:
            List of Location instances marked as starting locations
        """
        # Load all locations if cache is empty
        if not self._location_cache:
            self.load_all_locations()

        return [
            loc for loc in self._location_cache.values()
            if loc.is_starting_location
        ]

    def get_location_by_type(self, location_type: LocationType) -> List[Location]:
        """
        Get all locations of a specific type.

        Args:
            location_type: Type to filter by

        Returns:
            List of Location instances of this type
        """
        # Load all locations if cache is empty
        if not self._location_cache:
            self.load_all_locations()

        return [
            loc for loc in self._location_cache.values()
            if loc.location_type == location_type
        ]

    def get_all_location_ids(self) -> List[str]:
        """
        Get a list of all available location IDs.

        Returns:
            List of location IDs
        """
        # Load all locations if cache is empty
        if not self._location_cache:
            self.load_all_locations()

        return list(self._location_cache.keys())

    def reload_location(self, location_id: str) -> Optional[Location]:
        """
        Force reload a location from disk, bypassing cache.

        Useful for development/testing when location definitions change.

        Args:
            location_id: Unique location identifier

        Returns:
            Location instance or None if not found
        """
        # Remove from cache if present
        if location_id in self._location_cache:
            del self._location_cache[location_id]

        return self.load_location(location_id)

    def clear_cache(self) -> None:
        """Clear all cached data. Useful for testing."""
        self._location_cache.clear()
        self._region_cache.clear()
        logger.info("Location cache cleared")


# Global singleton instance
_loader_instance: Optional[LocationLoader] = None


def get_location_loader() -> LocationLoader:
    """
    Get the global LocationLoader instance.

    Returns:
        Singleton LocationLoader instance
    """
    global _loader_instance
    if _loader_instance is None:
        _loader_instance = LocationLoader()
    return _loader_instance
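
# Illustrative usage sketch (assumes the YAML data shipped under
# app/data/locations/; "crossville_tavern" is the example ID used in the
# docstrings above):
#
#     from app.services.location_loader import get_location_loader
#
#     loader = get_location_loader()
#     tavern = loader.load_location("crossville_tavern")
#     if tavern:
#         print(tavern.name, tavern.location_type)
#     starts = loader.get_starting_locations()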
385
api/app/services/npc_loader.py
Normal file
@@ -0,0 +1,385 @@
"""
NPCLoader service for loading NPC definitions from YAML files.

This service reads NPC configuration files and converts them into NPC
dataclass instances, providing caching for performance. NPCs are organized
by region subdirectories.
"""

import yaml
from pathlib import Path
from typing import Dict, List, Optional
import structlog

from app.models.npc import (
    NPC,
    NPCPersonality,
    NPCAppearance,
    NPCKnowledge,
    NPCKnowledgeCondition,
    NPCRelationship,
    NPCInventoryItem,
    NPCDialogueHooks,
)

logger = structlog.get_logger(__name__)


class NPCLoader:
    """
    Loads NPC definitions from YAML configuration files.

    NPCs are organized in region subdirectories:
        /app/data/npcs/
            crossville/
                npc_grom_001.yaml
                npc_mira_001.yaml

    This allows game designers to define NPCs without touching code.
    """

    def __init__(self, data_dir: Optional[str] = None):
        """
        Initialize the NPC loader.

        Args:
            data_dir: Path to directory containing NPC YAML files.
                      Defaults to /app/data/npcs/
        """
        if data_dir is None:
            # Default to app/data/npcs relative to this file
            current_file = Path(__file__)
            app_dir = current_file.parent.parent  # Go up to /app
            data_dir = str(app_dir / "data" / "npcs")

        self.data_dir = Path(data_dir)
        self._npc_cache: Dict[str, NPC] = {}
        self._location_npc_cache: Dict[str, List[str]] = {}

        logger.info("NPCLoader initialized", data_dir=str(self.data_dir))

    def load_npc(self, npc_id: str) -> Optional[NPC]:
        """
        Load a single NPC by ID.

        Searches all region subdirectories for the NPC file.

        Args:
            npc_id: Unique NPC identifier (e.g., "npc_grom_001")

        Returns:
            NPC instance or None if not found
        """
        # Check cache first
        if npc_id in self._npc_cache:
            logger.debug("NPC loaded from cache", npc_id=npc_id)
            return self._npc_cache[npc_id]

        # Search in region subdirectories
        if not self.data_dir.exists():
            logger.error("NPC data directory does not exist", data_dir=str(self.data_dir))
            return None

        for region_dir in self.data_dir.iterdir():
            if not region_dir.is_dir():
                continue

            file_path = region_dir / f"{npc_id}.yaml"
            if file_path.exists():
                return self._load_npc_file(file_path)

        logger.warning("NPC not found", npc_id=npc_id)
        return None

    def _load_npc_file(self, file_path: Path) -> Optional[NPC]:
        """
        Load an NPC from a specific file.

        Args:
            file_path: Path to the YAML file

        Returns:
            NPC instance or None if loading fails
        """
        try:
            with open(file_path, 'r') as f:
                data = yaml.safe_load(f)

            npc = self._parse_npc_data(data)
            self._npc_cache[npc.npc_id] = npc

            # Update location cache
            if npc.location_id not in self._location_npc_cache:
                self._location_npc_cache[npc.location_id] = []
            if npc.npc_id not in self._location_npc_cache[npc.location_id]:
                self._location_npc_cache[npc.location_id].append(npc.npc_id)

            logger.info("NPC loaded successfully", npc_id=npc.npc_id)
            return npc

        except Exception as e:
            logger.error("Failed to load NPC", file=str(file_path), error=str(e))
            return None

    def _parse_npc_data(self, data: Dict) -> NPC:
        """
        Parse YAML data into an NPC dataclass.

        Args:
            data: Dictionary loaded from YAML file

        Returns:
            NPC instance

        Raises:
            ValueError: If data is invalid or missing required fields
        """
        # Validate required fields
        required_fields = ["npc_id", "name", "role", "location_id"]
        for field in required_fields:
            if field not in data:
                raise ValueError(f"Missing required field: {field}")

        # Parse personality
        personality_data = data.get("personality", {})
        personality = NPCPersonality(
            traits=personality_data.get("traits", []),
            speech_style=personality_data.get("speech_style", ""),
            quirks=personality_data.get("quirks", []),
        )

        # Parse appearance
        appearance_data = data.get("appearance", {})
        if isinstance(appearance_data, str):
            appearance = NPCAppearance(brief=appearance_data)
        else:
            appearance = NPCAppearance(
                brief=appearance_data.get("brief", ""),
                detailed=appearance_data.get("detailed"),
            )

        # Parse knowledge (optional)
        knowledge = None
        if data.get("knowledge"):
            knowledge_data = data["knowledge"]
            conditions = [
                NPCKnowledgeCondition(
                    condition=c.get("condition", ""),
                    reveals=c.get("reveals", ""),
                )
                for c in knowledge_data.get("will_share_if", [])
            ]
            knowledge = NPCKnowledge(
                public=knowledge_data.get("public", []),
                secret=knowledge_data.get("secret", []),
                will_share_if=conditions,
            )

        # Parse relationships
        relationships = [
            NPCRelationship(
                npc_id=r["npc_id"],
                attitude=r["attitude"],
                reason=r.get("reason"),
            )
            for r in data.get("relationships", [])
        ]

        # Parse inventory
        inventory = []
        for item_data in data.get("inventory_for_sale", []):
            # Handle shorthand format: { item: "ale", price: 2 }
            item_id = item_data.get("item_id") or item_data.get("item", "")
            inventory.append(NPCInventoryItem(
                item_id=item_id,
                price=item_data.get("price", 0),
                quantity=item_data.get("quantity"),
            ))

        # Parse dialogue hooks (optional)
        dialogue_hooks = None
        if data.get("dialogue_hooks"):
            hooks_data = data["dialogue_hooks"]
            dialogue_hooks = NPCDialogueHooks(
                greeting=hooks_data.get("greeting"),
                farewell=hooks_data.get("farewell"),
                busy=hooks_data.get("busy"),
                quest_complete=hooks_data.get("quest_complete"),
            )

        return NPC(
            npc_id=data["npc_id"],
            name=data["name"],
            role=data["role"],
            location_id=data["location_id"],
            personality=personality,
            appearance=appearance,
            knowledge=knowledge,
            relationships=relationships,
            inventory_for_sale=inventory,
            dialogue_hooks=dialogue_hooks,
            quest_giver_for=data.get("quest_giver_for", []),
            reveals_locations=data.get("reveals_locations", []),
            tags=data.get("tags", []),
        )

    def load_all_npcs(self) -> List[NPC]:
        """
        Load all NPCs from all region directories.

        Returns:
            List of NPC instances
        """
        npcs = []

        if not self.data_dir.exists():
            logger.error("NPC data directory does not exist", data_dir=str(self.data_dir))
            return npcs

        for region_dir in self.data_dir.iterdir():
            if not region_dir.is_dir():
                continue

            for file_path in region_dir.glob("*.yaml"):
                npc = self._load_npc_file(file_path)
                if npc:
                    npcs.append(npc)

        logger.info("All NPCs loaded", count=len(npcs))
        return npcs

    def get_npcs_at_location(self, location_id: str) -> List[NPC]:
        """
        Get all NPCs at a specific location.

        Args:
            location_id: Location identifier

        Returns:
            List of NPC instances at this location
        """
        # Ensure all NPCs are loaded
        if not self._npc_cache:
            self.load_all_npcs()

        npc_ids = self._location_npc_cache.get(location_id, [])
        return [
            self._npc_cache[npc_id]
            for npc_id in npc_ids
            if npc_id in self._npc_cache
        ]

    def get_npc_ids_at_location(self, location_id: str) -> List[str]:
        """
        Get NPC IDs at a specific location.

        Args:
            location_id: Location identifier

        Returns:
            List of NPC IDs at this location
        """
        # Ensure all NPCs are loaded
        if not self._npc_cache:
            self.load_all_npcs()

        return self._location_npc_cache.get(location_id, [])

    def get_npcs_by_tag(self, tag: str) -> List[NPC]:
        """
        Get all NPCs with a specific tag.

        Args:
            tag: Tag to filter by (e.g., "merchant", "quest_giver")

        Returns:
            List of NPC instances with this tag
        """
        # Ensure all NPCs are loaded
        if not self._npc_cache:
            self.load_all_npcs()

        return [
            npc for npc in self._npc_cache.values()
            if tag in npc.tags
        ]

    def get_quest_givers(self, quest_id: str) -> List[NPC]:
        """
        Get all NPCs that can give a specific quest.

        Args:
            quest_id: Quest identifier

        Returns:
            List of NPC instances that give this quest
        """
        # Ensure all NPCs are loaded
        if not self._npc_cache:
            self.load_all_npcs()

        return [
            npc for npc in self._npc_cache.values()
            if quest_id in npc.quest_giver_for
        ]

    def get_all_npc_ids(self) -> List[str]:
        """
        Get a list of all available NPC IDs.

        Returns:
            List of NPC IDs
        """
        # Ensure all NPCs are loaded
        if not self._npc_cache:
            self.load_all_npcs()

        return list(self._npc_cache.keys())

    def reload_npc(self, npc_id: str) -> Optional[NPC]:
        """
        Force reload an NPC from disk, bypassing cache.

        Useful for development/testing when NPC definitions change.

        Args:
            npc_id: Unique NPC identifier

        Returns:
            NPC instance or None if not found
        """
        # Remove from caches if present
        if npc_id in self._npc_cache:
            old_npc = self._npc_cache[npc_id]
            # Remove from location cache
            if old_npc.location_id in self._location_npc_cache:
                self._location_npc_cache[old_npc.location_id] = [
                    n for n in self._location_npc_cache[old_npc.location_id]
                    if n != npc_id
                ]
            del self._npc_cache[npc_id]

        return self.load_npc(npc_id)

    def clear_cache(self) -> None:
        """Clear all cached data. Useful for testing."""
        self._npc_cache.clear()
        self._location_npc_cache.clear()
        logger.info("NPC cache cleared")


# Global singleton instance
_loader_instance: Optional[NPCLoader] = None


def get_npc_loader() -> NPCLoader:
    """
    Get the global NPCLoader instance.

    Returns:
        Singleton NPCLoader instance
    """
    global _loader_instance
    if _loader_instance is None:
        _loader_instance = NPCLoader()
    return _loader_instance
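
# Illustrative usage sketch (assumes the YAML data under app/data/npcs/;
# "npc_grom_001" and "crossville_tavern" are the example IDs from the
# docstrings above, and the "merchant" tag is hypothetical):
#
#     from app.services.npc_loader import get_npc_loader
#
#     npcs = get_npc_loader()
#     grom = npcs.load_npc("npc_grom_001")
#     tavern_npcs = npcs.get_npcs_at_location("crossville_tavern")
#     merchants = npcs.get_npcs_by_tag("merchant")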
236
api/app/services/origin_service.py
Normal file
@@ -0,0 +1,236 @@
"""
OriginService for loading character origin definitions from YAML files.

This service reads origin configuration and converts it into Origin
dataclass instances, providing caching for performance.
"""

import yaml
from pathlib import Path
from typing import Dict, List, Optional
import structlog

from app.models.origins import Origin, StartingLocation, StartingBonus

logger = structlog.get_logger(__name__)


class OriginService:
    """
    Loads character origin definitions from YAML configuration.

    Origins define character backstories, starting locations, and narrative
    hooks that the AI DM uses to create personalized gameplay experiences.
    All origin definitions are stored in /app/data/origins.yaml.
    """

    def __init__(self, data_file: Optional[str] = None):
        """
        Initialize the origin service.

        Args:
            data_file: Path to origins YAML file.
                       Defaults to /app/data/origins.yaml
        """
        if data_file is None:
            # Default to app/data/origins.yaml relative to this file
            current_file = Path(__file__)
            app_dir = current_file.parent.parent  # Go up to /app
            data_file = str(app_dir / "data" / "origins.yaml")

        self.data_file = Path(data_file)
        self._origins_cache: Dict[str, Origin] = {}
        self._all_origins_loaded = False

        logger.info("OriginService initialized", data_file=str(self.data_file))

    def load_origin(self, origin_id: str) -> Optional[Origin]:
        """
        Load a single origin by ID.

        Args:
            origin_id: Unique origin identifier (e.g., "soul_revenant")

        Returns:
            Origin instance or None if not found
        """
        # Check cache first
        if origin_id in self._origins_cache:
            logger.debug("Origin loaded from cache", origin_id=origin_id)
            return self._origins_cache[origin_id]

        # Load all origins if not already loaded
        if not self._all_origins_loaded:
            self._load_all_origins()

        # Return from cache after loading
        origin = self._origins_cache.get(origin_id)
        if origin:
            logger.info("Origin loaded successfully", origin_id=origin_id)
        else:
            logger.warning("Origin not found", origin_id=origin_id)

        return origin

    def load_all_origins(self) -> List[Origin]:
        """
        Load all origins from the data file.

        Returns:
            List of Origin instances
        """
        if self._all_origins_loaded and self._origins_cache:
            logger.debug("All origins loaded from cache")
            return list(self._origins_cache.values())

        return self._load_all_origins()

    def _load_all_origins(self) -> List[Origin]:
        """
        Internal method to load all origins from YAML.

        Returns:
            List of Origin instances
        """
        if not self.data_file.exists():
            logger.error("Origins data file does not exist", data_file=str(self.data_file))
            return []

        try:
            # Load YAML file
            with open(self.data_file, 'r') as f:
                data = yaml.safe_load(f)

            origins_data = data.get("origins", {})
            origins = []

            # Parse each origin
            for origin_id, origin_data in origins_data.items():
                try:
                    origin = self._parse_origin_data(origin_id, origin_data)
                    self._origins_cache[origin_id] = origin
                    origins.append(origin)
                except Exception as e:
                    logger.error("Failed to parse origin", origin_id=origin_id, error=str(e))
                    continue

            self._all_origins_loaded = True
            logger.info("All origins loaded successfully", count=len(origins))
            return origins

        except Exception as e:
            logger.error("Failed to load origins file", error=str(e))
            return []

    def get_origin_by_id(self, origin_id: str) -> Optional[Origin]:
        """
        Get an origin by ID (alias for load_origin).

        Args:
            origin_id: Unique origin identifier

        Returns:
            Origin instance or None if not found
        """
        return self.load_origin(origin_id)

    def get_all_origin_ids(self) -> List[str]:
        """
        Get a list of all available origin IDs.

        Returns:
            List of origin IDs (e.g., ["soul_revenant", "memory_thief"])
        """
        if not self._all_origins_loaded:
            self._load_all_origins()

        return list(self._origins_cache.keys())

    def reload_origins(self) -> List[Origin]:
        """
        Force reload all origins from disk, bypassing cache.

        Useful for development/testing when origin definitions change.

        Returns:
            List of Origin instances
        """
        self.clear_cache()
        return self._load_all_origins()

    def clear_cache(self):
        """Clear the origins cache. Useful for testing."""
        self._origins_cache.clear()
        self._all_origins_loaded = False
        logger.info("Origins cache cleared")

    def _parse_origin_data(self, origin_id: str, data: Dict) -> Origin:
        """
        Parse YAML data into an Origin dataclass.

        Args:
            origin_id: The origin's unique identifier
            data: Dictionary loaded from YAML file

        Returns:
            Origin instance

        Raises:
            ValueError: If data is invalid or missing required fields
        """
        # Validate required fields
        required_fields = ["name", "description", "starting_location"]
        for field in required_fields:
            if field not in data:
                raise ValueError(f"Missing required field in origin '{origin_id}': {field}")

        # Parse starting location
        location_data = data["starting_location"]
        starting_location = StartingLocation(
            id=location_data.get("id", ""),
            name=location_data.get("name", ""),
            region=location_data.get("region", ""),
            description=location_data.get("description", "")
        )

        # Parse starting bonus (optional)
        starting_bonus = None
        if "starting_bonus" in data:
            bonus_data = data["starting_bonus"]
            starting_bonus = StartingBonus(
                trait=bonus_data.get("trait", ""),
                description=bonus_data.get("description", ""),
                effect=bonus_data.get("effect", "")
            )

        # Parse narrative hooks (optional)
        narrative_hooks = data.get("narrative_hooks", [])

        # Create Origin instance
        origin = Origin(
            id=data.get("id", origin_id),  # Use provided ID or fall back to key
            name=data["name"],
            description=data["description"],
            starting_location=starting_location,
            narrative_hooks=narrative_hooks,
            starting_bonus=starting_bonus
        )

        return origin


# Global instance for convenience
_service_instance: Optional[OriginService] = None


def get_origin_service() -> OriginService:
    """
    Get the global OriginService instance.

    Returns:
        Singleton OriginService instance
    """
    global _service_instance
    if _service_instance is None:
        _service_instance = OriginService()
    return _service_instance
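
# Illustrative usage sketch (assumes app/data/origins.yaml exists;
# "soul_revenant" is the example ID from the docstrings above):
#
#     from app.services.origin_service import get_origin_service
#
#     origins = get_origin_service()
#     revenant = origins.load_origin("soul_revenant")
#     if revenant:
#         print(revenant.starting_location.name)
#     print(origins.get_all_origin_ids())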
373
api/app/services/outcome_service.py
Normal file
@@ -0,0 +1,373 @@
|
||||
"""
|
||||
Outcome determination service for Code of Conquest.
|
||||
|
||||
This service handles all code-determined game outcomes before they're passed
|
||||
to AI for narration. It uses the dice mechanics system to determine success/failure
|
||||
and selects appropriate rewards from loot tables.
|
||||
"""
|
||||
|
||||
import random
|
||||
import yaml
|
||||
import structlog
|
||||
from pathlib import Path
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, List, Dict, Any
|
||||
|
||||
from app.models.character import Character
|
||||
from app.game_logic.dice import (
|
||||
CheckResult, SkillType, Difficulty,
|
||||
skill_check, get_stat_for_skill, perception_check
|
||||
)
|
||||
|
||||
|
||||
logger = structlog.get_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ItemFound:
|
||||
"""
|
||||
Represents an item found during a search.
|
||||
|
||||
Uses template key from generic_items.yaml.
|
||||
"""
|
||||
template_key: str
|
||||
name: str
|
||||
description: str
|
||||
value: int
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""Serialize for API response."""
|
||||
return {
|
||||
"template_key": self.template_key,
|
||||
"name": self.name,
|
||||
"description": self.description,
|
||||
"value": self.value,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class SearchOutcome:
|
||||
"""
|
||||
Complete result of a search action.
|
||||
|
||||
Includes the dice check result and any items/gold found.
|
||||
"""
|
||||
check_result: CheckResult
|
||||
items_found: List[ItemFound]
|
||||
gold_found: int
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""Serialize for API response."""
|
||||
return {
|
||||
"check_result": self.check_result.to_dict(),
|
||||
"items_found": [item.to_dict() for item in self.items_found],
|
||||
"gold_found": self.gold_found,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class SkillCheckOutcome:
|
||||
"""
|
||||
Result of a generic skill check.
|
||||
|
||||
Used for persuasion, lockpicking, stealth, etc.
|
||||
"""
|
||||
check_result: CheckResult
|
||||
context: Dict[str, Any] # Additional context for AI narration
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""Serialize for API response."""
|
||||
return {
|
||||
"check_result": self.check_result.to_dict(),
|
||||
"context": self.context,
|
||||
}
|
||||
|
||||
|
||||
class OutcomeService:
|
||||
"""
|
||||
Service for determining game action outcomes.
|
||||
|
||||
Handles all dice rolls and loot selection before passing to AI.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the outcome service with loot tables and item templates."""
|
||||
self._loot_tables: Dict[str, Any] = {}
|
||||
self._item_templates: Dict[str, Any] = {}
|
||||
self._load_data()
|
||||
|
||||
def _load_data(self) -> None:
|
||||
"""Load loot tables and item templates from YAML files."""
|
||||
data_dir = Path(__file__).parent.parent / "data"
|
||||
|
||||
# Load loot tables
|
||||
loot_path = data_dir / "loot_tables.yaml"
|
||||
if loot_path.exists():
|
||||
with open(loot_path, "r") as f:
|
||||
self._loot_tables = yaml.safe_load(f)
|
||||
logger.info("loaded_loot_tables", count=len(self._loot_tables))
|
||||
else:
|
||||
logger.warning("loot_tables_not_found", path=str(loot_path))
|
||||
|
||||
# Load generic item templates
|
||||
items_path = data_dir / "generic_items.yaml"
|
||||
if items_path.exists():
|
||||
with open(items_path, "r") as f:
|
||||
data = yaml.safe_load(f)
|
||||
self._item_templates = data.get("templates", {})
|
||||
logger.info("loaded_item_templates", count=len(self._item_templates))
|
||||
else:
|
||||
logger.warning("item_templates_not_found", path=str(items_path))
|
||||
|
||||
def determine_search_outcome(
|
||||
self,
|
||||
character: Character,
|
||||
location_type: str,
|
||||
dc: int = 12,
|
||||
bonus: int = 0
|
||||
) -> SearchOutcome:
|
||||
"""
|
||||
Determine the outcome of a search action.
|
||||
|
||||
Uses a perception check to determine success, then selects items
|
||||
from the appropriate loot table based on the roll margin.
|
||||
|
||||
Args:
|
||||
character: The character performing the search
|
||||
location_type: Type of location (forest, cave, town, etc.)
|
||||
dc: Difficulty class (default 12 = easy-medium)
|
||||
bonus: Additional bonus to the check
|
||||
|
||||
Returns:
|
||||
SearchOutcome with check result, items found, and gold found
|
||||
"""
|
||||
# Get character's effective wisdom for perception
|
||||
effective_stats = character.get_effective_stats()
|
||||
wisdom = effective_stats.wisdom
|
||||
|
||||
# Perform the perception check
|
||||
check_result = perception_check(wisdom, dc, bonus)
|
||||
|
||||
# Determine loot based on result
|
||||
items_found: List[ItemFound] = []
|
||||
gold_found: int = 0
|
||||
|
||||
if check_result.success:
|
||||
# Get loot table for this location (fall back to default)
|
||||
loot_table = self._loot_tables.get(
|
||||
location_type.lower(),
|
||||
self._loot_tables.get("default", {})
|
||||
)
|
||||
|
||||
# Select item rarity based on margin
|
||||
if check_result.margin >= 10:
|
||||
rarity = "rare"
|
||||
elif check_result.margin >= 5:
|
||||
rarity = "uncommon"
|
||||
else:
|
||||
rarity = "common"
|
||||
|
||||
# Get items for this rarity
|
||||
item_keys = loot_table.get(rarity, [])
|
||||
if item_keys:
|
||||
# Select 1-2 items based on margin
|
||||
num_items = 1 if check_result.margin < 8 else 2
|
||||
selected_keys = random.sample(
|
||||
item_keys,
|
||||
min(num_items, len(item_keys))
|
||||
)
|
||||
|
||||
for key in selected_keys:
|
||||
template = self._item_templates.get(key)
|
||||
if template:
|
||||
items_found.append(ItemFound(
|
||||
template_key=key,
|
||||
name=template.get("name", key.title()),
|
||||
description=template.get("description", ""),
|
||||
value=template.get("value", 1),
|
||||
))
|
||||
|
||||
# Calculate gold found
|
||||
gold_config = loot_table.get("gold", {})
|
||||
if gold_config:
|
||||
min_gold = gold_config.get("min", 0)
|
||||
max_gold = gold_config.get("max", 10)
|
||||
bonus_per_margin = gold_config.get("bonus_per_margin", 0)
|
||||
|
||||
base_gold = random.randint(min_gold, max_gold)
|
||||
margin_bonus = check_result.margin * bonus_per_margin
|
||||
gold_found = base_gold + margin_bonus
|
||||
|
||||
logger.info(
|
||||
"search_outcome_determined",
|
||||
character_id=character.character_id,
|
||||
location_type=location_type,
|
||||
dc=dc,
|
||||
success=check_result.success,
|
||||
margin=check_result.margin,
|
||||
items_count=len(items_found),
|
||||
gold_found=gold_found
|
||||
)
|
||||
|
||||
return SearchOutcome(
|
||||
check_result=check_result,
|
||||
items_found=items_found,
|
||||
gold_found=gold_found
|
||||
)
|
||||
|
||||
def determine_skill_check_outcome(
|
||||
self,
|
||||
character: Character,
|
||||
skill_type: SkillType,
|
||||
dc: int,
|
||||
bonus: int = 0,
|
||||
context: Optional[Dict[str, Any]] = None
|
||||
) -> SkillCheckOutcome:
|
||||
"""
|
||||
Determine the outcome of a generic skill check.
|
||||
|
||||
Args:
|
||||
character: The character performing the check
|
||||
skill_type: The type of skill check (PERSUASION, STEALTH, etc.)
|
||||
dc: Difficulty class to beat
|
||||
bonus: Additional bonus to the check
|
||||
context: Optional context for AI narration (e.g., NPC name, door type)
|
||||
|
||||
Returns:
|
||||
SkillCheckOutcome with check result and context
|
||||
"""
|
||||
# Get the appropriate stat for this skill
|
||||
stat_name = get_stat_for_skill(skill_type)
|
||||
effective_stats = character.get_effective_stats()
|
||||
stat_value = getattr(effective_stats, stat_name, 10)
|
||||
|
||||
# Perform the check
|
||||
check_result = skill_check(stat_value, dc, skill_type, bonus)
|
||||
|
||||
# Build outcome context
|
||||
outcome_context = context or {}
|
||||
outcome_context["skill_used"] = skill_type.name.lower()
|
||||
outcome_context["stat_used"] = stat_name
|
||||
|
||||
logger.info(
|
||||
"skill_check_outcome_determined",
|
||||
character_id=character.character_id,
|
||||
skill=skill_type.name,
|
||||
stat=stat_name,
|
||||
dc=dc,
|
||||
success=check_result.success,
|
||||
margin=check_result.margin
|
||||
)
|
||||
|
||||
return SkillCheckOutcome(
|
||||
check_result=check_result,
|
||||
context=outcome_context
|
||||
)
|
||||
|
||||
def determine_persuasion_outcome(
|
||||
self,
|
||||
character: Character,
|
||||
dc: int,
|
||||
npc_name: Optional[str] = None,
|
||||
bonus: int = 0
|
||||
) -> SkillCheckOutcome:
|
||||
"""
|
||||
Convenience method for persuasion checks.
|
||||
|
||||
Args:
|
||||
character: The character attempting persuasion
|
||||
dc: Difficulty class based on NPC disposition
|
||||
npc_name: Name of the NPC being persuaded
|
||||
bonus: Additional bonus
|
||||
|
||||
Returns:
|
||||
SkillCheckOutcome
|
||||
"""
|
||||
context = {"npc_name": npc_name} if npc_name else {}
|
||||
return self.determine_skill_check_outcome(
|
||||
character,
|
||||
SkillType.PERSUASION,
|
||||
dc,
|
||||
bonus,
|
||||
context
|
||||
)
|
||||
|
||||
def determine_stealth_outcome(
|
||||
self,
|
||||
character: Character,
|
||||
dc: int,
|
||||
situation: Optional[str] = None,
|
||||
bonus: int = 0
|
||||
) -> SkillCheckOutcome:
|
||||
"""
|
||||
Convenience method for stealth checks.
|
||||
|
||||
Args:
|
||||
character: The character attempting stealth
|
||||
dc: Difficulty class based on environment/observers
|
||||
situation: Description of what they're sneaking past
|
||||
bonus: Additional bonus
|
||||
|
||||
Returns:
|
||||
SkillCheckOutcome
|
||||
"""
|
||||
context = {"situation": situation} if situation else {}
|
||||
return self.determine_skill_check_outcome(
|
||||
character,
|
||||
SkillType.STEALTH,
|
||||
dc,
|
||||
bonus,
|
||||
context
|
||||
)
|
||||
|
||||
def determine_lockpicking_outcome(
|
||||
self,
|
||||
character: Character,
|
||||
dc: int,
|
||||
lock_description: Optional[str] = None,
|
||||
bonus: int = 0
|
||||
) -> SkillCheckOutcome:
|
||||
"""
|
||||
Convenience method for lockpicking checks.
|
||||
|
||||
Args:
|
||||
character: The character attempting to pick the lock
|
||||
dc: Difficulty class based on lock quality
|
||||
lock_description: Description of the lock/door
|
||||
bonus: Additional bonus (e.g., from thieves' tools)
|
||||
|
||||
Returns:
|
||||
SkillCheckOutcome
|
||||
"""
|
||||
context = {"lock_description": lock_description} if lock_description else {}
|
||||
return self.determine_skill_check_outcome(
|
||||
character,
|
||||
SkillType.LOCKPICKING,
|
||||
dc,
|
||||
bonus,
|
||||
context
|
||||
)
|
||||
|
||||
def get_dc_for_difficulty(self, difficulty: str) -> int:
|
||||
"""
|
||||
Get the DC value for a named difficulty.
|
||||
|
||||
Args:
|
||||
difficulty: Difficulty name (trivial, easy, medium, hard, very_hard)
|
||||
|
||||
Returns:
|
||||
DC value
|
||||
"""
|
||||
difficulty_map = {
|
||||
"trivial": Difficulty.TRIVIAL.value,
|
||||
"easy": Difficulty.EASY.value,
|
||||
"medium": Difficulty.MEDIUM.value,
|
||||
"hard": Difficulty.HARD.value,
|
||||
"very_hard": Difficulty.VERY_HARD.value,
|
||||
"nearly_impossible": Difficulty.NEARLY_IMPOSSIBLE.value,
|
||||
}
|
||||
return difficulty_map.get(difficulty.lower(), Difficulty.MEDIUM.value)
|
||||
|
||||
|
||||
# Global instance for use in API endpoints
|
||||
outcome_service = OutcomeService()
|
||||
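A minimal sketch of how an endpoint might drive the module-level instance above. Here `character` is assumed to be a Character loaded elsewhere, and "Mira" is a hypothetical NPC name; neither is defined in this commit:

from app.services.outcome_service import outcome_service

# Search: perception check plus loot selection keyed by roll margin.
outcome = outcome_service.determine_search_outcome(character, location_type="forest")
if outcome.check_result.success:
    payload = outcome.to_dict()  # items_found / gold_found for the AI narrator

# Persuasion: translate a named difficulty into a DC, then roll.
dc = outcome_service.get_dc_for_difficulty("hard")
persuasion = outcome_service.determine_persuasion_outcome(character, dc, npc_name="Mira")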
602
api/app/services/rate_limiter_service.py
Normal file
@@ -0,0 +1,602 @@
"""
|
||||
Rate Limiter Service
|
||||
|
||||
This module implements tier-based rate limiting for AI requests using Redis
|
||||
for distributed counting. Each user tier has a different daily limit for
|
||||
AI-generated turns.
|
||||
|
||||
Usage:
|
||||
from app.services.rate_limiter_service import RateLimiterService, RateLimitExceeded
|
||||
from app.ai.model_selector import UserTier
|
||||
|
||||
# Initialize service
|
||||
rate_limiter = RateLimiterService()
|
||||
|
||||
# Check and increment usage
|
||||
try:
|
||||
rate_limiter.check_rate_limit("user_123", UserTier.FREE)
|
||||
rate_limiter.increment_usage("user_123")
|
||||
except RateLimitExceeded as e:
|
||||
print(f"Rate limit exceeded: {e}")
|
||||
|
||||
# Get remaining turns
|
||||
remaining = rate_limiter.get_remaining_turns("user_123", UserTier.FREE)
|
||||
"""
|
||||
|
||||
from datetime import date, datetime, timezone, timedelta
|
||||
from typing import Optional
|
||||
|
||||
from app.services.redis_service import RedisService, RedisServiceError
|
||||
from app.ai.model_selector import UserTier
|
||||
from app.utils.logging import get_logger
|
||||
|
||||
|
||||
# Initialize logger
|
||||
logger = get_logger(__file__)
|
||||
|
||||
|
||||
class RateLimitExceeded(Exception):
|
||||
"""
|
||||
Raised when a user has exceeded their daily rate limit.
|
||||
|
||||
Attributes:
|
||||
user_id: The user who exceeded the limit
|
||||
user_tier: The user's subscription tier
|
||||
limit: The daily limit for their tier
|
||||
current_usage: The current usage count
|
||||
reset_time: UTC timestamp when the limit resets
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
user_id: str,
|
||||
user_tier: UserTier,
|
||||
limit: int,
|
||||
current_usage: int,
|
||||
reset_time: datetime
|
||||
):
|
||||
self.user_id = user_id
|
||||
self.user_tier = user_tier
|
||||
self.limit = limit
|
||||
self.current_usage = current_usage
|
||||
self.reset_time = reset_time
|
||||
|
||||
message = (
|
||||
f"Rate limit exceeded for user {user_id} ({user_tier.value} tier). "
|
||||
f"Used {current_usage}/{limit} turns. Resets at {reset_time.isoformat()}"
|
||||
)
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class RateLimiterService:
|
||||
"""
|
||||
Service for managing tier-based rate limiting.
|
||||
|
||||
This service uses Redis to track daily AI usage per user and enforces
|
||||
limits based on subscription tier. Counters reset daily at midnight UTC.
|
||||
|
||||
Tier Limits:
|
||||
- Free: 20 turns/day
|
||||
- Basic: 50 turns/day
|
||||
- Premium: 100 turns/day
|
||||
- Elite: 200 turns/day
|
||||
|
||||
Attributes:
|
||||
redis: RedisService instance for counter storage
|
||||
tier_limits: Mapping of tier to daily turn limit
|
||||
"""
|
||||
|
||||
# Daily turn limits per tier
|
||||
TIER_LIMITS = {
|
||||
UserTier.FREE: 20,
|
||||
UserTier.BASIC: 50,
|
||||
UserTier.PREMIUM: 100,
|
||||
UserTier.ELITE: 200,
|
||||
}
|
||||
|
||||
# Daily DM question limits per tier
|
||||
DM_QUESTION_LIMITS = {
|
||||
UserTier.FREE: 10,
|
||||
UserTier.BASIC: 20,
|
||||
UserTier.PREMIUM: 50,
|
||||
UserTier.ELITE: -1, # -1 means unlimited
|
||||
}
|
||||
|
||||
# Redis key prefix for rate limit counters
|
||||
KEY_PREFIX = "rate_limit:daily:"
|
||||
DM_QUESTION_PREFIX = "rate_limit:dm_questions:"
|
||||
|
||||
def __init__(self, redis_service: Optional[RedisService] = None):
|
||||
"""
|
||||
Initialize the rate limiter service.
|
||||
|
||||
Args:
|
||||
redis_service: Optional RedisService instance. If not provided,
|
||||
a new instance will be created.
|
||||
"""
|
||||
self.redis = redis_service or RedisService()
|
||||
|
||||
logger.info(
|
||||
"RateLimiterService initialized",
|
||||
tier_limits=self.TIER_LIMITS
|
||||
)
|
||||
|
||||
def _get_daily_key(self, user_id: str, day: Optional[date] = None) -> str:
|
||||
"""
|
||||
Generate the Redis key for a user's daily counter.
|
||||
|
||||
Args:
|
||||
user_id: The user ID
|
||||
day: The date (defaults to today UTC)
|
||||
|
||||
Returns:
|
||||
Redis key in format "rate_limit:daily:user_id:YYYY-MM-DD"
|
||||
"""
|
||||
if day is None:
|
||||
day = datetime.now(timezone.utc).date()
|
||||
|
||||
return f"{self.KEY_PREFIX}{user_id}:{day.isoformat()}"
|
||||
|
||||
def _get_seconds_until_midnight_utc(self) -> int:
|
||||
"""
|
||||
Calculate seconds remaining until midnight UTC.
|
||||
|
||||
Returns:
|
||||
Number of seconds until the next UTC midnight
|
||||
"""
|
||||
now = datetime.now(timezone.utc)
|
||||
tomorrow = datetime(
|
||||
now.year, now.month, now.day,
|
||||
tzinfo=timezone.utc
|
||||
) + timedelta(days=1)
|
||||
|
||||
return int((tomorrow - now).total_seconds())
|
||||
|
||||
def _get_reset_time(self) -> datetime:
|
||||
"""
|
||||
Get the UTC datetime when the rate limit resets.
|
||||
|
||||
Returns:
|
||||
Datetime of next midnight UTC
|
||||
"""
|
||||
now = datetime.now(timezone.utc)
|
||||
return datetime(
|
||||
now.year, now.month, now.day,
|
||||
tzinfo=timezone.utc
|
||||
) + timedelta(days=1)
|
||||
|
||||
def get_limit_for_tier(self, user_tier: UserTier) -> int:
|
||||
"""
|
||||
Get the daily turn limit for a specific tier.
|
||||
|
||||
Args:
|
||||
user_tier: The user's subscription tier
|
||||
|
||||
Returns:
|
||||
Daily turn limit for the tier
|
||||
"""
|
||||
return self.TIER_LIMITS.get(user_tier, self.TIER_LIMITS[UserTier.FREE])
|
||||
|
||||
def get_current_usage(self, user_id: str) -> int:
|
||||
"""
|
||||
Get the current daily usage count for a user.
|
||||
|
||||
Args:
|
||||
user_id: The user ID to check
|
||||
|
||||
Returns:
|
||||
Current usage count (0 if no usage today)
|
||||
|
||||
Raises:
|
||||
RedisServiceError: If Redis operation fails
|
||||
"""
|
||||
key = self._get_daily_key(user_id)
|
||||
|
||||
try:
|
||||
value = self.redis.get(key)
|
||||
usage = int(value) if value else 0
|
||||
|
||||
logger.debug(
|
||||
"Retrieved current usage",
|
||||
user_id=user_id,
|
||||
usage=usage
|
||||
)
|
||||
|
||||
return usage
|
||||
|
||||
except (ValueError, TypeError) as e:
|
||||
logger.error(
|
||||
"Invalid usage value in Redis",
|
||||
user_id=user_id,
|
||||
error=str(e)
|
||||
)
|
||||
return 0
|
||||
|
||||
def check_rate_limit(self, user_id: str, user_tier: UserTier) -> None:
|
||||
"""
|
||||
Check if a user has exceeded their daily rate limit.
|
||||
|
||||
This method checks the current usage against the tier limit and
|
||||
raises an exception if the limit has been reached.
|
||||
|
||||
Args:
|
||||
user_id: The user ID to check
|
||||
user_tier: The user's subscription tier
|
||||
|
||||
Raises:
|
||||
RateLimitExceeded: If the user has reached their daily limit
|
||||
RedisServiceError: If Redis operation fails
|
||||
"""
|
||||
current_usage = self.get_current_usage(user_id)
|
||||
limit = self.get_limit_for_tier(user_tier)
|
||||
|
||||
if current_usage >= limit:
|
||||
reset_time = self._get_reset_time()
|
||||
|
||||
logger.warning(
|
||||
"Rate limit exceeded",
|
||||
user_id=user_id,
|
||||
user_tier=user_tier.value,
|
||||
current_usage=current_usage,
|
||||
limit=limit,
|
||||
reset_time=reset_time.isoformat()
|
||||
)
|
||||
|
||||
raise RateLimitExceeded(
|
||||
user_id=user_id,
|
||||
user_tier=user_tier,
|
||||
limit=limit,
|
||||
current_usage=current_usage,
|
||||
reset_time=reset_time
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
"Rate limit check passed",
|
||||
user_id=user_id,
|
||||
user_tier=user_tier.value,
|
||||
current_usage=current_usage,
|
||||
limit=limit
|
||||
)
|
||||
|
||||
def increment_usage(self, user_id: str) -> int:
|
||||
"""
|
||||
Increment the daily usage counter for a user.
|
||||
|
||||
This method should be called after successfully processing an AI request.
|
||||
The counter will automatically expire at midnight UTC.
|
||||
|
||||
Args:
|
||||
user_id: The user ID to increment
|
||||
|
||||
Returns:
|
||||
The new usage count after incrementing
|
||||
|
||||
Raises:
|
||||
RedisServiceError: If Redis operation fails
|
||||
"""
|
||||
key = self._get_daily_key(user_id)
|
||||
|
||||
# Increment the counter
|
||||
new_count = self.redis.incr(key)
|
||||
|
||||
# Set expiration if this is the first increment (new_count == 1)
|
||||
# This ensures the key expires at midnight UTC
|
||||
if new_count == 1:
|
||||
ttl = self._get_seconds_until_midnight_utc()
|
||||
self.redis.expire(key, ttl)
|
||||
|
||||
logger.debug(
|
||||
"Set expiration on new rate limit key",
|
||||
user_id=user_id,
|
||||
ttl=ttl
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Incremented usage counter",
|
||||
user_id=user_id,
|
||||
new_count=new_count
|
||||
)
|
||||
|
||||
return new_count
|
||||
|
||||
def get_remaining_turns(self, user_id: str, user_tier: UserTier) -> int:
|
||||
"""
|
||||
Get the number of remaining turns for a user today.
|
||||
|
||||
Args:
|
||||
user_id: The user ID to check
|
||||
user_tier: The user's subscription tier
|
||||
|
||||
Returns:
|
||||
Number of turns remaining (0 if limit reached)
|
||||
"""
|
||||
current_usage = self.get_current_usage(user_id)
|
||||
limit = self.get_limit_for_tier(user_tier)
|
||||
|
||||
remaining = max(0, limit - current_usage)
|
||||
|
||||
logger.debug(
|
||||
"Calculated remaining turns",
|
||||
user_id=user_id,
|
||||
user_tier=user_tier.value,
|
||||
current_usage=current_usage,
|
||||
limit=limit,
|
||||
remaining=remaining
|
||||
)
|
||||
|
||||
return remaining
|
||||
|
||||
def get_usage_info(self, user_id: str, user_tier: UserTier) -> dict:
|
||||
"""
|
||||
Get comprehensive usage information for a user.
|
||||
|
||||
Args:
|
||||
user_id: The user ID to check
|
||||
user_tier: The user's subscription tier
|
||||
|
||||
Returns:
|
||||
Dictionary with usage info:
|
||||
- user_id: User identifier
|
||||
- user_tier: Subscription tier
|
||||
- current_usage: Current daily usage
|
||||
- daily_limit: Daily limit for tier
|
||||
- remaining: Remaining turns
|
||||
- reset_time: ISO format UTC reset time
|
||||
- is_limited: Whether limit has been reached
|
||||
"""
|
||||
current_usage = self.get_current_usage(user_id)
|
||||
limit = self.get_limit_for_tier(user_tier)
|
||||
remaining = max(0, limit - current_usage)
|
||||
reset_time = self._get_reset_time()
|
||||
|
||||
info = {
|
||||
"user_id": user_id,
|
||||
"user_tier": user_tier.value,
|
||||
"current_usage": current_usage,
|
||||
"daily_limit": limit,
|
||||
"remaining": remaining,
|
||||
"reset_time": reset_time.isoformat(),
|
||||
"is_limited": current_usage >= limit
|
||||
}
|
||||
|
||||
logger.debug("Retrieved usage info", **info)
|
||||
|
||||
return info
|
||||
|
||||
def reset_usage(self, user_id: str) -> bool:
|
||||
"""
|
||||
Reset the daily usage counter for a user.
|
||||
|
||||
This is primarily for admin/testing purposes.
|
||||
|
||||
Args:
|
||||
user_id: The user ID to reset
|
||||
|
||||
Returns:
|
||||
True if the counter was deleted, False if it didn't exist
|
||||
|
||||
Raises:
|
||||
RedisServiceError: If Redis operation fails
|
||||
"""
|
||||
key = self._get_daily_key(user_id)
|
||||
deleted = self.redis.delete(key)
|
||||
|
||||
logger.info(
|
||||
"Reset usage counter",
|
||||
user_id=user_id,
|
||||
deleted=deleted > 0
|
||||
)
|
||||
|
||||
return deleted > 0
|
||||
|
||||
# ===== DM QUESTION RATE LIMITING =====
|
||||
|
||||
def _get_dm_question_key(self, user_id: str, day: Optional[date] = None) -> str:
|
||||
"""
|
||||
Generate the Redis key for a user's daily DM question counter.
|
||||
|
||||
Args:
|
||||
user_id: The user ID
|
||||
day: The date (defaults to today UTC)
|
||||
|
||||
Returns:
|
||||
Redis key in format "rate_limit:dm_questions:user_id:YYYY-MM-DD"
|
||||
"""
|
||||
if day is None:
|
||||
day = datetime.now(timezone.utc).date()
|
||||
|
||||
return f"{self.DM_QUESTION_PREFIX}{user_id}:{day.isoformat()}"
|
||||
|
||||
def get_dm_question_limit_for_tier(self, user_tier: UserTier) -> int:
|
||||
"""
|
||||
Get the daily DM question limit for a specific tier.
|
||||
|
||||
Args:
|
||||
user_tier: The user's subscription tier
|
||||
|
||||
Returns:
|
||||
Daily DM question limit for the tier (-1 for unlimited)
|
||||
"""
|
||||
return self.DM_QUESTION_LIMITS.get(user_tier, self.DM_QUESTION_LIMITS[UserTier.FREE])
|
||||
|
||||
def get_current_dm_usage(self, user_id: str) -> int:
|
||||
"""
|
||||
Get the current daily DM question usage count for a user.
|
||||
|
||||
Args:
|
||||
user_id: The user ID to check
|
||||
|
||||
Returns:
|
||||
Current DM question usage count (0 if no usage today)
|
||||
|
||||
Raises:
|
||||
RedisServiceError: If Redis operation fails
|
||||
"""
|
||||
key = self._get_dm_question_key(user_id)
|
||||
|
||||
try:
|
||||
value = self.redis.get(key)
|
||||
usage = int(value) if value else 0
|
||||
|
||||
logger.debug(
|
||||
"Retrieved current DM question usage",
|
||||
user_id=user_id,
|
||||
usage=usage
|
||||
)
|
||||
|
||||
return usage
|
||||
|
||||
except (ValueError, TypeError) as e:
|
||||
logger.error(
|
||||
"Invalid DM question usage value in Redis",
|
||||
user_id=user_id,
|
||||
error=str(e)
|
||||
)
|
||||
return 0
|
||||
|
||||
def check_dm_question_limit(self, user_id: str, user_tier: UserTier) -> None:
|
||||
"""
|
||||
Check if a user has exceeded their daily DM question limit.
|
||||
|
||||
Args:
|
||||
user_id: The user ID to check
|
||||
user_tier: The user's subscription tier
|
||||
|
||||
Raises:
|
||||
RateLimitExceeded: If the user has reached their daily DM question limit
|
||||
RedisServiceError: If Redis operation fails
|
||||
"""
|
||||
limit = self.get_dm_question_limit_for_tier(user_tier)
|
||||
|
||||
# -1 means unlimited
|
||||
if limit == -1:
|
||||
logger.debug(
|
||||
"DM question limit check passed (unlimited)",
|
||||
user_id=user_id,
|
||||
user_tier=user_tier.value
|
||||
)
|
||||
return
|
||||
|
||||
current_usage = self.get_current_dm_usage(user_id)
|
||||
|
||||
if current_usage >= limit:
|
||||
reset_time = self._get_reset_time()
|
||||
|
||||
logger.warning(
|
||||
"DM question limit exceeded",
|
||||
user_id=user_id,
|
||||
user_tier=user_tier.value,
|
||||
current_usage=current_usage,
|
||||
limit=limit,
|
||||
reset_time=reset_time.isoformat()
|
||||
)
|
||||
|
||||
raise RateLimitExceeded(
|
||||
user_id=user_id,
|
||||
user_tier=user_tier,
|
||||
limit=limit,
|
||||
current_usage=current_usage,
|
||||
reset_time=reset_time
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
"DM question limit check passed",
|
||||
user_id=user_id,
|
||||
user_tier=user_tier.value,
|
||||
current_usage=current_usage,
|
||||
limit=limit
|
||||
)
|
||||
|
||||
def increment_dm_usage(self, user_id: str) -> int:
|
||||
"""
|
||||
Increment the daily DM question counter for a user.
|
||||
|
||||
Args:
|
||||
user_id: The user ID to increment
|
||||
|
||||
Returns:
|
||||
The new DM question usage count after incrementing
|
||||
|
||||
Raises:
|
||||
RedisServiceError: If Redis operation fails
|
||||
"""
|
||||
key = self._get_dm_question_key(user_id)
|
||||
|
||||
# Increment the counter
|
||||
new_count = self.redis.incr(key)
|
||||
|
||||
# Set expiration if this is the first increment
|
||||
if new_count == 1:
|
||||
ttl = self._get_seconds_until_midnight_utc()
|
||||
self.redis.expire(key, ttl)
|
||||
|
||||
logger.debug(
|
||||
"Set expiration on new DM question key",
|
||||
user_id=user_id,
|
||||
ttl=ttl
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Incremented DM question counter",
|
||||
user_id=user_id,
|
||||
new_count=new_count
|
||||
)
|
||||
|
||||
return new_count
|
||||
|
||||
def get_remaining_dm_questions(self, user_id: str, user_tier: UserTier) -> int:
|
||||
"""
|
||||
Get the number of remaining DM questions for a user today.
|
||||
|
||||
Args:
|
||||
user_id: The user ID to check
|
||||
user_tier: The user's subscription tier
|
||||
|
||||
Returns:
|
||||
Number of DM questions remaining (-1 if unlimited, 0 if limit reached)
|
||||
"""
|
||||
limit = self.get_dm_question_limit_for_tier(user_tier)
|
||||
|
||||
# -1 means unlimited
|
||||
if limit == -1:
|
||||
return -1
|
||||
|
||||
current_usage = self.get_current_dm_usage(user_id)
|
||||
remaining = max(0, limit - current_usage)
|
||||
|
||||
logger.debug(
|
||||
"Calculated remaining DM questions",
|
||||
user_id=user_id,
|
||||
user_tier=user_tier.value,
|
||||
current_usage=current_usage,
|
||||
limit=limit,
|
||||
remaining=remaining
|
||||
)
|
||||
|
||||
return remaining
|
||||
|
||||
def reset_dm_usage(self, user_id: str) -> bool:
|
||||
"""
|
||||
Reset the daily DM question counter for a user.
|
||||
|
||||
This is primarily for admin/testing purposes.
|
||||
|
||||
Args:
|
||||
user_id: The user ID to reset
|
||||
|
||||
Returns:
|
||||
True if the counter was deleted, False if it didn't exist
|
||||
|
||||
Raises:
|
||||
RedisServiceError: If Redis operation fails
|
||||
"""
|
||||
key = self._get_dm_question_key(user_id)
|
||||
deleted = self.redis.delete(key)
|
||||
|
||||
logger.info(
|
||||
"Reset DM question counter",
|
||||
user_id=user_id,
|
||||
deleted=deleted > 0
|
||||
)
|
||||
|
||||
return deleted > 0
|
||||
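The module docstring shows the turn flow; the DM-question flow is symmetric. A minimal sketch, assuming the same check-then-increment ordering:

from app.services.rate_limiter_service import RateLimiterService, RateLimitExceeded
from app.ai.model_selector import UserTier

limiter = RateLimiterService()
try:
    limiter.check_dm_question_limit("user_123", UserTier.BASIC)
    limiter.increment_dm_usage("user_123")
except RateLimitExceeded as e:
    print(f"Out of DM questions: {e}")

remaining = limiter.get_remaining_dm_questions("user_123", UserTier.BASIC)  # -1 == unlimited

Note that check and increment are separate Redis operations, so two concurrent requests can both pass the check; the limits here act as soft caps rather than hard guarantees.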
505
api/app/services/redis_service.py
Normal file
@@ -0,0 +1,505 @@
"""
|
||||
Redis Service Wrapper
|
||||
|
||||
This module provides a wrapper around the redis-py client for handling caching,
|
||||
job queue data, and temporary storage. It provides connection pooling, automatic
|
||||
reconnection, and a clean interface for common Redis operations.
|
||||
|
||||
Usage:
|
||||
from app.services.redis_service import RedisService
|
||||
|
||||
# Initialize service
|
||||
redis = RedisService()
|
||||
|
||||
# Basic operations
|
||||
redis.set("key", "value", ttl=3600) # Set with 1 hour TTL
|
||||
value = redis.get("key")
|
||||
redis.delete("key")
|
||||
|
||||
# Health check
|
||||
if redis.health_check():
|
||||
print("Redis is healthy")
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
from typing import Optional, Any, Union
|
||||
|
||||
import redis
|
||||
from redis.exceptions import RedisError, ConnectionError as RedisConnectionError
|
||||
|
||||
from app.utils.logging import get_logger
|
||||
|
||||
|
||||
# Initialize logger
|
||||
logger = get_logger(__file__)
|
||||
|
||||
|
||||
class RedisServiceError(Exception):
|
||||
"""Base exception for Redis service errors."""
|
||||
pass
|
||||
|
||||
|
||||
class RedisConnectionFailed(RedisServiceError):
|
||||
"""Raised when Redis connection cannot be established."""
|
||||
pass
|
||||
|
||||
|
||||
class RedisService:
|
||||
"""
|
||||
Service class for interacting with Redis.
|
||||
|
||||
This class provides:
|
||||
- Connection pooling for efficient connection management
|
||||
- Basic operations: get, set, delete, exists
|
||||
- TTL support for caching
|
||||
- Health check for monitoring
|
||||
- Automatic JSON serialization for complex objects
|
||||
|
||||
Attributes:
|
||||
pool: Redis connection pool
|
||||
client: Redis client instance
|
||||
"""
|
||||
|
||||
def __init__(self, redis_url: Optional[str] = None):
|
||||
"""
|
||||
Initialize the Redis service.
|
||||
|
||||
Reads configuration from environment variables if not provided:
|
||||
- REDIS_URL: Full Redis URL (e.g., redis://localhost:6379/0)
|
||||
|
||||
Args:
|
||||
redis_url: Optional Redis URL to override environment variable
|
||||
|
||||
Raises:
|
||||
RedisConnectionFailed: If connection to Redis fails
|
||||
"""
|
||||
self.redis_url = redis_url or os.getenv('REDIS_URL', 'redis://localhost:6379/0')
|
||||
|
||||
if not self.redis_url:
|
||||
logger.error("Missing Redis URL configuration")
|
||||
raise ValueError("Redis URL not configured. Set REDIS_URL environment variable.")
|
||||
|
||||
try:
|
||||
# Create connection pool for efficient connection management
|
||||
# Connection pooling allows multiple operations to share connections
|
||||
# and automatically manages connection lifecycle
|
||||
self.pool = redis.ConnectionPool.from_url(
|
||||
self.redis_url,
|
||||
max_connections=10,
|
||||
decode_responses=True, # Return strings instead of bytes
|
||||
socket_connect_timeout=5, # Connection timeout in seconds
|
||||
socket_timeout=5, # Operation timeout in seconds
|
||||
retry_on_timeout=True, # Retry on timeout
|
||||
)
|
||||
|
||||
# Create client using the connection pool
|
||||
self.client = redis.Redis(connection_pool=self.pool)
|
||||
|
||||
# Test connection
|
||||
self.client.ping()
|
||||
|
||||
logger.info("Redis service initialized", redis_url=self._sanitize_url(self.redis_url))
|
||||
|
||||
except RedisConnectionError as e:
|
||||
logger.error("Failed to connect to Redis", redis_url=self._sanitize_url(self.redis_url), error=str(e))
|
||||
raise RedisConnectionFailed(f"Could not connect to Redis at {self._sanitize_url(self.redis_url)}: {e}")
|
||||
except RedisError as e:
|
||||
logger.error("Redis initialization error", error=str(e))
|
||||
raise RedisServiceError(f"Redis initialization failed: {e}")
|
||||
|
||||
def get(self, key: str) -> Optional[str]:
|
||||
"""
|
||||
Get a value from Redis by key.
|
||||
|
||||
Args:
|
||||
key: The key to retrieve
|
||||
|
||||
Returns:
|
||||
The value as string if found, None if key doesn't exist
|
||||
|
||||
Raises:
|
||||
RedisServiceError: If the operation fails
|
||||
"""
|
||||
try:
|
||||
value = self.client.get(key)
|
||||
|
||||
if value is not None:
|
||||
logger.debug("Redis GET", key=key, found=True)
|
||||
else:
|
||||
logger.debug("Redis GET", key=key, found=False)
|
||||
|
||||
return value
|
||||
|
||||
except RedisError as e:
|
||||
logger.error("Redis GET failed", key=key, error=str(e))
|
||||
raise RedisServiceError(f"Failed to get key '{key}': {e}")
|
||||
|
||||
def get_json(self, key: str) -> Optional[Any]:
|
||||
"""
|
||||
Get a value from Redis and deserialize it from JSON.
|
||||
|
||||
Args:
|
||||
key: The key to retrieve
|
||||
|
||||
Returns:
|
||||
The deserialized value if found, None if key doesn't exist
|
||||
|
||||
Raises:
|
||||
RedisServiceError: If the operation fails or JSON is invalid
|
||||
"""
|
||||
value = self.get(key)
|
||||
|
||||
if value is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
return json.loads(value)
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error("Failed to decode JSON from Redis", key=key, error=str(e))
|
||||
raise RedisServiceError(f"Failed to decode JSON for key '{key}': {e}")
|
||||
|
||||
def set(
|
||||
self,
|
||||
key: str,
|
||||
value: str,
|
||||
ttl: Optional[int] = None,
|
||||
nx: bool = False,
|
||||
xx: bool = False
|
||||
) -> bool:
|
||||
"""
|
||||
Set a value in Redis.
|
||||
|
||||
Args:
|
||||
key: The key to set
|
||||
value: The value to store (must be string)
|
||||
ttl: Time to live in seconds (None for no expiration)
|
||||
nx: Only set if key does not exist (for locking)
|
||||
xx: Only set if key already exists
|
||||
|
||||
Returns:
|
||||
True if the key was set, False if not set (due to nx/xx conditions)
|
||||
|
||||
Raises:
|
||||
RedisServiceError: If the operation fails
|
||||
"""
|
||||
try:
|
||||
result = self.client.set(
|
||||
key,
|
||||
value,
|
||||
ex=ttl, # Expiration in seconds
|
||||
nx=nx, # Only set if not exists
|
||||
xx=xx # Only set if exists
|
||||
)
|
||||
|
||||
# set() returns True if set, None if not set due to nx/xx
|
||||
success = result is True or result == 1
|
||||
|
||||
logger.debug("Redis SET", key=key, ttl=ttl, nx=nx, xx=xx, success=success)
|
||||
|
||||
return success
|
||||
|
||||
except RedisError as e:
|
||||
logger.error("Redis SET failed", key=key, error=str(e))
|
||||
raise RedisServiceError(f"Failed to set key '{key}': {e}")
|
||||
|
||||
def set_json(
|
||||
self,
|
||||
key: str,
|
||||
value: Any,
|
||||
ttl: Optional[int] = None,
|
||||
nx: bool = False,
|
||||
xx: bool = False
|
||||
) -> bool:
|
||||
"""
|
||||
Serialize a value to JSON and store it in Redis.
|
||||
|
||||
Args:
|
||||
key: The key to set
|
||||
value: The value to serialize and store (must be JSON-serializable)
|
||||
ttl: Time to live in seconds (None for no expiration)
|
||||
nx: Only set if key does not exist
|
||||
xx: Only set if key already exists
|
||||
|
||||
Returns:
|
||||
True if the key was set, False if not set (due to nx/xx conditions)
|
||||
|
||||
Raises:
|
||||
RedisServiceError: If the operation fails or value is not JSON-serializable
|
||||
"""
|
||||
try:
|
||||
json_value = json.dumps(value)
|
||||
except (TypeError, ValueError) as e:
|
||||
logger.error("Failed to serialize value to JSON", key=key, error=str(e))
|
||||
raise RedisServiceError(f"Failed to serialize value for key '{key}': {e}")
|
||||
|
||||
return self.set(key, json_value, ttl=ttl, nx=nx, xx=xx)
|
||||
|
||||
def delete(self, *keys: str) -> int:
|
||||
"""
|
||||
Delete one or more keys from Redis.
|
||||
|
||||
Args:
|
||||
*keys: One or more keys to delete
|
||||
|
||||
Returns:
|
||||
The number of keys that were deleted
|
||||
|
||||
Raises:
|
||||
RedisServiceError: If the operation fails
|
||||
"""
|
||||
if not keys:
|
||||
return 0
|
||||
|
||||
try:
|
||||
deleted_count = self.client.delete(*keys)
|
||||
|
||||
logger.debug("Redis DELETE", keys=keys, deleted_count=deleted_count)
|
||||
|
||||
return deleted_count
|
||||
|
||||
except RedisError as e:
|
||||
logger.error("Redis DELETE failed", keys=keys, error=str(e))
|
||||
raise RedisServiceError(f"Failed to delete keys {keys}: {e}")
|
||||
|
||||
def exists(self, *keys: str) -> int:
|
||||
"""
|
||||
Check if one or more keys exist in Redis.
|
||||
|
||||
Args:
|
||||
*keys: One or more keys to check
|
||||
|
||||
Returns:
|
||||
The number of keys that exist
|
||||
|
||||
Raises:
|
||||
RedisServiceError: If the operation fails
|
||||
"""
|
||||
if not keys:
|
||||
return 0
|
||||
|
||||
try:
|
||||
exists_count = self.client.exists(*keys)
|
||||
|
||||
logger.debug("Redis EXISTS", keys=keys, exists_count=exists_count)
|
||||
|
||||
return exists_count
|
||||
|
||||
except RedisError as e:
|
||||
logger.error("Redis EXISTS failed", keys=keys, error=str(e))
|
||||
raise RedisServiceError(f"Failed to check existence of keys {keys}: {e}")
|
||||
|
||||
def expire(self, key: str, ttl: int) -> bool:
|
||||
"""
|
||||
Set a TTL (time to live) on an existing key.
|
||||
|
||||
Args:
|
||||
key: The key to set expiration on
|
||||
ttl: Time to live in seconds
|
||||
|
||||
Returns:
|
||||
True if the timeout was set, False if key doesn't exist
|
||||
|
||||
Raises:
|
||||
RedisServiceError: If the operation fails
|
||||
"""
|
||||
try:
|
||||
result = self.client.expire(key, ttl)
|
||||
|
||||
logger.debug("Redis EXPIRE", key=key, ttl=ttl, success=result)
|
||||
|
||||
return result
|
||||
|
||||
except RedisError as e:
|
||||
logger.error("Redis EXPIRE failed", key=key, ttl=ttl, error=str(e))
|
||||
raise RedisServiceError(f"Failed to set expiration for key '{key}': {e}")
|
||||
|
||||
def ttl(self, key: str) -> int:
|
||||
"""
|
||||
Get the remaining TTL (time to live) for a key.
|
||||
|
||||
Args:
|
||||
key: The key to check
|
||||
|
||||
Returns:
|
||||
TTL in seconds, -1 if key exists but has no expiry, -2 if key doesn't exist
|
||||
|
||||
Raises:
|
||||
RedisServiceError: If the operation fails
|
||||
"""
|
||||
try:
|
||||
remaining = self.client.ttl(key)
|
||||
|
||||
logger.debug("Redis TTL", key=key, remaining=remaining)
|
||||
|
||||
return remaining
|
||||
|
||||
except RedisError as e:
|
||||
logger.error("Redis TTL failed", key=key, error=str(e))
|
||||
raise RedisServiceError(f"Failed to get TTL for key '{key}': {e}")
|
||||
|
||||
def incr(self, key: str, amount: int = 1) -> int:
|
||||
"""
|
||||
Increment a key's value by the given amount.
|
||||
|
||||
If the key doesn't exist, it will be created with the increment value.
|
||||
|
||||
Args:
|
||||
key: The key to increment
|
||||
amount: Amount to increment by (default 1)
|
||||
|
||||
Returns:
|
||||
The new value after incrementing
|
||||
|
||||
Raises:
|
||||
RedisServiceError: If the operation fails or value is not an integer
|
||||
"""
|
||||
try:
|
||||
new_value = self.client.incrby(key, amount)
|
||||
|
||||
logger.debug("Redis INCR", key=key, amount=amount, new_value=new_value)
|
||||
|
||||
return new_value
|
||||
|
||||
except RedisError as e:
|
||||
logger.error("Redis INCR failed", key=key, amount=amount, error=str(e))
|
||||
raise RedisServiceError(f"Failed to increment key '{key}': {e}")
|
||||
|
||||
def decr(self, key: str, amount: int = 1) -> int:
|
||||
"""
|
||||
Decrement a key's value by the given amount.
|
||||
|
||||
If the key doesn't exist, it will be created with the negative increment value.
|
||||
|
||||
Args:
|
||||
key: The key to decrement
|
||||
amount: Amount to decrement by (default 1)
|
||||
|
||||
Returns:
|
||||
The new value after decrementing
|
||||
|
||||
Raises:
|
||||
RedisServiceError: If the operation fails or value is not an integer
|
||||
"""
|
||||
try:
|
||||
new_value = self.client.decrby(key, amount)
|
||||
|
||||
logger.debug("Redis DECR", key=key, amount=amount, new_value=new_value)
|
||||
|
||||
return new_value
|
||||
|
||||
except RedisError as e:
|
||||
logger.error("Redis DECR failed", key=key, amount=amount, error=str(e))
|
||||
raise RedisServiceError(f"Failed to decrement key '{key}': {e}")
|
||||
|
||||
def health_check(self) -> bool:
|
||||
"""
|
||||
Check if Redis connection is healthy.
|
||||
|
||||
This performs a PING command to verify the connection is working.
|
||||
|
||||
Returns:
|
||||
True if Redis is healthy and responding, False otherwise
|
||||
"""
|
||||
try:
|
||||
response = self.client.ping()
|
||||
|
||||
if response:
|
||||
logger.debug("Redis health check passed")
|
||||
return True
|
||||
else:
|
||||
logger.warning("Redis health check failed - unexpected response", response=response)
|
||||
return False
|
||||
|
||||
except RedisError as e:
|
||||
logger.error("Redis health check failed", error=str(e))
|
||||
return False
|
||||
|
||||
def info(self) -> dict:
|
||||
"""
|
||||
Get Redis server information.
|
||||
|
||||
Returns:
|
||||
Dictionary containing server info (version, memory, clients, etc.)
|
||||
|
||||
Raises:
|
||||
RedisServiceError: If the operation fails
|
||||
"""
|
||||
try:
|
||||
info = self.client.info()
|
||||
|
||||
logger.debug("Redis INFO retrieved", redis_version=info.get('redis_version'))
|
||||
|
||||
return info
|
||||
|
||||
except RedisError as e:
|
||||
logger.error("Redis INFO failed", error=str(e))
|
||||
raise RedisServiceError(f"Failed to get Redis info: {e}")
|
||||
|
||||
def flush_db(self) -> bool:
|
||||
"""
|
||||
Delete all keys in the current database.
|
||||
|
||||
WARNING: This is a destructive operation. Use with caution.
|
||||
|
||||
Returns:
|
||||
True if successful
|
||||
|
||||
Raises:
|
||||
RedisServiceError: If the operation fails
|
||||
"""
|
||||
try:
|
||||
self.client.flushdb()
|
||||
|
||||
logger.warning("Redis database flushed")
|
||||
|
||||
return True
|
||||
|
||||
except RedisError as e:
|
||||
logger.error("Redis FLUSHDB failed", error=str(e))
|
||||
raise RedisServiceError(f"Failed to flush database: {e}")
|
||||
|
||||
def close(self) -> None:
|
||||
"""
|
||||
Close all connections in the pool.
|
||||
|
||||
Call this when shutting down the application to cleanly release connections.
|
||||
"""
|
||||
try:
|
||||
self.pool.disconnect()
|
||||
logger.info("Redis connection pool closed")
|
||||
except Exception as e:
|
||||
logger.error("Error closing Redis connection pool", error=str(e))
|
||||
|
||||
def _sanitize_url(self, url: str) -> str:
|
||||
"""
|
||||
Remove password from Redis URL for safe logging.
|
||||
|
||||
Args:
|
||||
url: Redis URL that may contain password
|
||||
|
||||
Returns:
|
||||
URL with password masked
|
||||
"""
|
||||
# Simple sanitization - mask password if present
|
||||
# Format: redis://user:password@host:port/db
|
||||
if '@' in url:
|
||||
# Split on @ and mask everything before it except the protocol
|
||||
parts = url.split('@')
|
||||
protocol_and_creds = parts[0]
|
||||
host_and_rest = parts[1]
|
||||
|
||||
if '://' in protocol_and_creds:
|
||||
protocol = protocol_and_creds.split('://')[0]
|
||||
return f"{protocol}://***@{host_and_rest}"
|
||||
|
||||
return url
|
||||
|
||||
def __enter__(self):
|
||||
"""Context manager entry."""
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
"""Context manager exit - close connections."""
|
||||
self.close()
|
||||
return False
|
||||
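A minimal sketch of the JSON helpers and of the nx flag used as the lightweight lock the set() docstring mentions. Key names and TTL values are illustrative assumptions:

from app.services.redis_service import RedisService

with RedisService() as cache:
    # Cache a JSON-serializable object for five minutes.
    cache.set_json("cache:profile:user_123", {"name": "Ada", "tier": "free"}, ttl=300)
    profile = cache.get_json("cache:profile:user_123")

    # SET NX acts as a best-effort lock: only one caller can create the key.
    if cache.set("lock:daily_report", "worker_1", ttl=30, nx=True):
        try:
            ...  # exclusive work
        finally:
            cache.delete("lock:daily_report")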
705
api/app/services/session_service.py
Normal file
@@ -0,0 +1,705 @@
"""
|
||||
Session Service - CRUD operations for game sessions.
|
||||
|
||||
This service handles creating, reading, updating, and managing game sessions,
|
||||
with support for both solo and multiplayer sessions.
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import List, Optional
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from appwrite.query import Query
|
||||
from appwrite.id import ID
|
||||
|
||||
from app.models.session import GameSession, GameState, ConversationEntry, SessionConfig
|
||||
from app.models.enums import SessionStatus, SessionType
|
||||
from app.models.action_prompt import LocationType
|
||||
from app.services.database_service import get_database_service
|
||||
from app.services.appwrite_service import AppwriteService
|
||||
from app.services.character_service import get_character_service, CharacterNotFound
|
||||
from app.services.location_loader import get_location_loader
|
||||
from app.utils.logging import get_logger
|
||||
|
||||
logger = get_logger(__file__)
|
||||
|
||||
|
||||
# Session limits per user
|
||||
MAX_ACTIVE_SESSIONS = 5
|
||||
|
||||
|
||||
class SessionNotFound(Exception):
|
||||
"""Raised when session ID doesn't exist or user doesn't own it."""
|
||||
pass
|
||||
|
||||
|
||||
class SessionLimitExceeded(Exception):
|
||||
"""Raised when user tries to create more sessions than allowed."""
|
||||
pass
|
||||
|
||||
|
||||
class SessionValidationError(Exception):
|
||||
"""Raised when session validation fails."""
|
||||
pass
|
||||
|
||||
|
||||
class SessionService:
|
||||
"""
|
||||
Service for managing game sessions.
|
||||
|
||||
This service provides:
|
||||
- Session creation (solo and multiplayer)
|
||||
- Session retrieval and listing
|
||||
- Session state updates
|
||||
- Conversation history management
|
||||
- Game state tracking
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the session service with dependencies."""
|
||||
self.db = get_database_service()
|
||||
self.appwrite = AppwriteService()
|
||||
self.character_service = get_character_service()
|
||||
self.collection_id = "game_sessions"
|
||||
|
||||
logger.info("SessionService initialized")
|
||||
|
||||
def create_solo_session(
|
||||
self,
|
||||
user_id: str,
|
||||
character_id: str,
|
||||
starting_location: Optional[str] = None,
|
||||
starting_location_type: Optional[LocationType] = None
|
||||
) -> GameSession:
|
||||
"""
|
||||
Create a new solo game session.
|
||||
|
||||
This method:
|
||||
1. Validates user owns the character
|
||||
2. Validates user hasn't exceeded session limit
|
||||
3. Determines starting location from location data
|
||||
4. Creates session with initial game state
|
||||
5. Stores in Appwrite database
|
||||
|
||||
Args:
|
||||
user_id: Owner's user ID
|
||||
character_id: Character ID for this session
|
||||
starting_location: Initial location ID (optional, uses default starting location)
|
||||
starting_location_type: Initial location type (optional, derived from location data)
|
||||
|
||||
Returns:
|
||||
Created GameSession instance
|
||||
|
||||
Raises:
|
||||
CharacterNotFound: If character doesn't exist or user doesn't own it
|
||||
SessionLimitExceeded: If user has reached active session limit
|
||||
"""
|
||||
try:
|
||||
logger.info("Creating solo session",
|
||||
user_id=user_id,
|
||||
character_id=character_id)
|
||||
|
||||
# Validate user owns the character
|
||||
character = self.character_service.get_character(character_id, user_id)
|
||||
if not character:
|
||||
raise CharacterNotFound(f"Character not found: {character_id}")
|
||||
|
||||
# Determine starting location from location data if not provided
|
||||
if not starting_location:
|
||||
location_loader = get_location_loader()
|
||||
starting_locations = location_loader.get_starting_locations()
|
||||
|
||||
if starting_locations:
|
||||
# Use first starting location (usually crossville_village)
|
||||
start_loc = starting_locations[0]
|
||||
starting_location = start_loc.location_id
|
||||
# Convert from enums.LocationType to action_prompt.LocationType via string value
|
||||
starting_location_type = LocationType(start_loc.location_type.value)
|
||||
logger.info("Using starting location from data",
|
||||
location_id=starting_location,
|
||||
location_type=starting_location_type.value)
|
||||
else:
|
||||
# Fallback to crossville_village
|
||||
starting_location = "crossville_village"
|
||||
starting_location_type = LocationType.TOWN
|
||||
logger.warning("No starting locations found, using fallback",
|
||||
location_id=starting_location)
|
||||
|
||||
# Ensure location type is set
|
||||
if not starting_location_type:
|
||||
starting_location_type = LocationType.TOWN
|
||||
|
||||
# Check session limit
|
||||
active_count = self.count_user_sessions(user_id, active_only=True)
|
||||
if active_count >= MAX_ACTIVE_SESSIONS:
|
||||
logger.warning("Session limit exceeded",
|
||||
user_id=user_id,
|
||||
current=active_count,
|
||||
limit=MAX_ACTIVE_SESSIONS)
|
||||
raise SessionLimitExceeded(
|
||||
f"Maximum active sessions reached ({active_count}/{MAX_ACTIVE_SESSIONS}). "
|
||||
f"Please end an existing session to start a new one."
|
||||
)
|
||||
|
||||
# Generate unique session ID
|
||||
session_id = ID.unique()
|
||||
|
||||
# Create game state with starting location
|
||||
game_state = GameState(
|
||||
current_location=starting_location,
|
||||
location_type=starting_location_type,
|
||||
discovered_locations=[starting_location],
|
||||
active_quests=[],
|
||||
world_events=[]
|
||||
)
|
||||
|
||||
# Create session instance
|
||||
session = GameSession(
|
||||
session_id=session_id,
|
||||
session_type=SessionType.SOLO,
|
||||
solo_character_id=character_id,
|
||||
user_id=user_id,
|
||||
party_member_ids=[],
|
||||
config=SessionConfig(),
|
||||
game_state=game_state,
|
||||
turn_order=[character_id],
|
||||
current_turn=0,
|
||||
turn_number=0,
|
||||
status=SessionStatus.ACTIVE
|
||||
)
|
||||
|
||||
# Serialize and store
|
||||
session_dict = session.to_dict()
|
||||
session_json = json.dumps(session_dict)
|
||||
|
||||
document_data = {
|
||||
'userId': user_id,
|
||||
'characterId': character_id,
|
||||
'sessionData': session_json,
|
||||
'status': SessionStatus.ACTIVE.value,
|
||||
'sessionType': SessionType.SOLO.value
|
||||
}
|
||||
|
||||
self.db.create_document(
|
||||
collection_id=self.collection_id,
|
||||
data=document_data,
|
||||
document_id=session_id
|
||||
)
|
||||
|
||||
logger.info("Solo session created successfully",
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
character_id=character_id)
|
||||
|
||||
return session
|
||||
|
||||
except (CharacterNotFound, SessionLimitExceeded):
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Failed to create solo session",
|
||||
user_id=user_id,
|
||||
character_id=character_id,
|
||||
error=str(e))
|
||||
raise
|
||||
|
||||
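    # Usage sketch (hypothetical caller, not part of this file): create a
    # solo session, then read back its starting state.
    #
    #     service = SessionService()
    #     session = service.create_solo_session(user_id="user_123",
    #                                           character_id="char_abc")
    #     session.game_state.current_location  # e.g. "crossville_village"
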
def get_session(self, session_id: str, user_id: Optional[str] = None) -> GameSession:
|
||||
"""
|
||||
Get a session by ID.
|
||||
|
||||
Args:
|
||||
session_id: Session ID
|
||||
user_id: Optional user ID for ownership validation
|
||||
|
||||
Returns:
|
||||
GameSession instance
|
||||
|
||||
Raises:
|
||||
SessionNotFound: If session doesn't exist or user doesn't own it
|
||||
"""
|
||||
try:
|
||||
logger.debug("Fetching session", session_id=session_id)
|
||||
|
||||
# Get document from database
|
||||
document = self.db.get_row(self.collection_id, session_id)
|
||||
|
||||
if not document:
|
||||
logger.warning("Session not found", session_id=session_id)
|
||||
raise SessionNotFound(f"Session not found: {session_id}")
|
||||
|
||||
# Verify ownership if user_id provided
|
||||
if user_id and document.data.get('userId') != user_id:
|
||||
logger.warning("Session ownership mismatch",
|
||||
session_id=session_id,
|
||||
expected_user=user_id,
|
||||
actual_user=document.data.get('userId'))
|
||||
raise SessionNotFound(f"Session not found: {session_id}")
|
||||
|
||||
# Parse session data
|
||||
session_json = document.data.get('sessionData')
|
||||
session_dict = json.loads(session_json)
|
||||
session = GameSession.from_dict(session_dict)
|
||||
|
||||
logger.debug("Session fetched successfully", session_id=session_id)
|
||||
return session
|
||||
|
||||
except SessionNotFound:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Failed to fetch session",
|
||||
session_id=session_id,
|
||||
error=str(e))
|
||||
raise
|
||||
|
||||
def update_session(self, session: GameSession) -> GameSession:
|
||||
"""
|
||||
Update a session in the database.
|
||||
|
||||
Args:
|
||||
session: GameSession instance with updated data
|
||||
|
||||
Returns:
|
||||
Updated GameSession instance
|
||||
"""
|
||||
try:
|
||||
logger.debug("Updating session", session_id=session.session_id)
|
||||
|
||||
# Serialize session
|
||||
session_dict = session.to_dict()
|
||||
session_json = json.dumps(session_dict)
|
||||
|
||||
# Update in database
|
||||
self.db.update_document(
|
||||
collection_id=self.collection_id,
|
||||
document_id=session.session_id,
|
||||
data={
|
||||
'sessionData': session_json,
|
||||
'status': session.status.value
|
||||
}
|
||||
)
|
||||
|
||||
logger.debug("Session updated successfully", session_id=session.session_id)
|
||||
return session
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to update session",
|
||||
session_id=session.session_id,
|
||||
error=str(e))
|
||||
raise
|
||||
|
||||
def get_user_sessions(
|
||||
self,
|
||||
user_id: str,
|
||||
active_only: bool = True,
|
||||
limit: int = 25
|
||||
) -> List[GameSession]:
|
||||
"""
|
||||
Get all sessions for a user.
|
||||
|
||||
Args:
|
||||
user_id: User ID
|
||||
active_only: If True, only return active sessions
|
||||
limit: Maximum number of sessions to return
|
||||
|
||||
Returns:
|
||||
List of GameSession instances
|
||||
"""
|
||||
try:
|
||||
logger.debug("Fetching user sessions",
|
||||
user_id=user_id,
|
||||
active_only=active_only)
|
||||
|
||||
# Build query
|
||||
queries = [Query.equal('userId', user_id)]
|
||||
if active_only:
|
||||
queries.append(Query.equal('status', SessionStatus.ACTIVE.value))
|
||||
|
||||
documents = self.db.list_rows(
|
||||
table_id=self.collection_id,
|
||||
queries=queries,
|
||||
limit=limit
|
||||
)
|
||||
|
||||
# Parse all sessions
|
||||
sessions = []
|
||||
for document in documents:
|
||||
try:
|
||||
session_json = document.data.get('sessionData')
|
||||
session_dict = json.loads(session_json)
|
||||
session = GameSession.from_dict(session_dict)
|
||||
sessions.append(session)
|
||||
except Exception as e:
|
||||
logger.error("Failed to parse session",
|
||||
document_id=document.id,
|
||||
error=str(e))
|
||||
continue
|
||||
|
||||
logger.debug("User sessions fetched",
|
||||
user_id=user_id,
|
||||
count=len(sessions))
|
||||
|
||||
return sessions
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to fetch user sessions",
|
||||
user_id=user_id,
|
||||
error=str(e))
|
||||
raise
|
||||
|
||||
def count_user_sessions(self, user_id: str, active_only: bool = True) -> int:
|
||||
"""
|
||||
Count sessions for a user.
|
||||
|
||||
Args:
|
||||
user_id: User ID
|
||||
active_only: If True, only count active sessions
|
||||
|
||||
Returns:
|
||||
Number of sessions
|
||||
"""
|
||||
try:
|
||||
queries = [Query.equal('userId', user_id)]
|
||||
if active_only:
|
||||
queries.append(Query.equal('status', SessionStatus.ACTIVE.value))
|
||||
|
||||
count = self.db.count_documents(
|
||||
collection_id=self.collection_id,
|
||||
queries=queries
|
||||
)
|
||||
|
||||
logger.debug("Session count",
|
||||
user_id=user_id,
|
||||
active_only=active_only,
|
||||
count=count)
|
||||
return count
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to count sessions",
|
||||
user_id=user_id,
|
||||
error=str(e))
|
||||
return 0
|
||||
|
||||
def end_session(self, session_id: str, user_id: str) -> GameSession:
|
||||
"""
|
||||
End a session by marking it as completed.
|
||||
|
||||
Args:
|
||||
session_id: Session ID
|
||||
user_id: User ID for ownership validation
|
||||
|
||||
Returns:
|
||||
Updated GameSession instance
|
||||
|
||||
Raises:
|
||||
SessionNotFound: If session doesn't exist or user doesn't own it
|
||||
"""
|
||||
try:
|
||||
logger.info("Ending session", session_id=session_id, user_id=user_id)
|
||||
|
||||
session = self.get_session(session_id, user_id)
|
||||
session.status = SessionStatus.COMPLETED
|
||||
session.update_activity()
|
||||
|
||||
return self.update_session(session)
|
||||
|
||||
except SessionNotFound:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Failed to end session",
|
||||
session_id=session_id,
|
||||
error=str(e))
|
||||
raise
|
||||
|
    def add_conversation_entry(
        self,
        session_id: str,
        character_id: str,
        character_name: str,
        action: str,
        dm_response: str,
        combat_log: Optional[List] = None,
        quest_offered: Optional[dict] = None
    ) -> GameSession:
        """
        Add an entry to the conversation history.

        This method automatically:
        - Increments turn number
        - Adds timestamp
        - Updates last activity

        Args:
            session_id: Session ID
            character_id: Acting character's ID
            character_name: Acting character's name
            action: Player's action text
            dm_response: AI DM's response
            combat_log: Optional combat actions
            quest_offered: Optional quest offering info

        Returns:
            Updated GameSession instance
        """
        try:
            logger.debug("Adding conversation entry",
                         session_id=session_id,
                         character_id=character_id)

            session = self.get_session(session_id)

            # Create conversation entry
            entry = ConversationEntry(
                turn=session.turn_number + 1,
                character_id=character_id,
                character_name=character_name,
                action=action,
                dm_response=dm_response,
                combat_log=combat_log or [],
                quest_offered=quest_offered
            )

            # Add entry and increment turn
            session.conversation_history.append(entry)
            session.turn_number += 1
            session.update_activity()

            # Save to database
            return self.update_session(session)

        except Exception as e:
            logger.error("Failed to add conversation entry",
                         session_id=session_id,
                         error=str(e))
            raise

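    # An illustrative call, assuming a `service` obtained from
    # get_session_service() below; all identifiers are placeholder values:
    #
    #     service.add_conversation_entry(
    #         session_id="sess_123",
    #         character_id="char_456",
    #         character_name="Elara",
    #         action="I search the ruined shrine for hidden doors.",
    #         dm_response="Behind the altar you find a narrow stone stair...",
    #     )
    #
    # Each call appends one turn, bumps turn_number, and persists the session.
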
    def get_conversation_history(
        self,
        session_id: str,
        limit: Optional[int] = None,
        offset: int = 0
    ) -> List[ConversationEntry]:
        """
        Get conversation history for a session.

        Args:
            session_id: Session ID
            limit: Maximum entries to return (None for all)
            offset: Number of entries to skip, counted from the end

        Returns:
            List of ConversationEntry instances
        """
        try:
            session = self.get_session(session_id)
            history = session.conversation_history

            # Apply offset (drop the `offset` most recent entries)
            if offset > 0:
                history = history[:-offset] if offset < len(history) else []

            # Apply limit (keep the most recent `limit` entries)
            if limit and len(history) > limit:
                history = history[-limit:]

            return history

        except Exception as e:
            logger.error("Failed to get conversation history",
                         session_id=session_id,
                         error=str(e))
            raise

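    # Worked example of the offset/limit semantics above, assuming a session
    # whose history holds 10 entries [e1 .. e10]:
    #
    #     get_conversation_history(sid)                     -> [e1 .. e10]
    #     get_conversation_history(sid, limit=3)            -> [e8, e9, e10]
    #     get_conversation_history(sid, limit=3, offset=2)  -> [e6, e7, e8]
    #
    # The offset drops the `offset` most recent entries first, then the limit
    # keeps the most recent `limit` of what remains.
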
    def get_recent_history(self, session_id: str, num_turns: int = 3) -> List[ConversationEntry]:
        """
        Get the most recent conversation entries for AI context.

        Args:
            session_id: Session ID
            num_turns: Number of recent turns to return

        Returns:
            List of most recent ConversationEntry instances
        """
        return self.get_conversation_history(session_id, limit=num_turns)

    def update_location(
        self,
        session_id: str,
        new_location: str,
        location_type: LocationType
    ) -> GameSession:
        """
        Update the current location in the session.

        Also adds location to discovered_locations if not already there.

        Args:
            session_id: Session ID
            new_location: New location name
            location_type: New location type

        Returns:
            Updated GameSession instance
        """
        try:
            logger.debug("Updating location",
                         session_id=session_id,
                         new_location=new_location)

            session = self.get_session(session_id)
            session.game_state.current_location = new_location
            session.game_state.location_type = location_type

            # Track discovered locations
            if new_location not in session.game_state.discovered_locations:
                session.game_state.discovered_locations.append(new_location)

            session.update_activity()
            return self.update_session(session)

        except Exception as e:
            logger.error("Failed to update location",
                         session_id=session_id,
                         error=str(e))
            raise

    def add_discovered_location(self, session_id: str, location: str) -> GameSession:
        """
        Add a location to the discovered locations list.

        Args:
            session_id: Session ID
            location: Location name to add

        Returns:
            Updated GameSession instance
        """
        try:
            session = self.get_session(session_id)

            if location not in session.game_state.discovered_locations:
                session.game_state.discovered_locations.append(location)
                session.update_activity()
                return self.update_session(session)

            return session

        except Exception as e:
            logger.error("Failed to add discovered location",
                         session_id=session_id,
                         error=str(e))
            raise

    def add_active_quest(self, session_id: str, quest_id: str) -> GameSession:
        """
        Add a quest to the active quests list.

        Enforces the limit of 2 active quests.

        Args:
            session_id: Session ID
            quest_id: Quest ID to add

        Returns:
            Updated GameSession instance

        Raises:
            SessionValidationError: If max quests limit exceeded
        """
        try:
            session = self.get_session(session_id)

            # Check max active quests (2)
            if len(session.game_state.active_quests) >= 2:
                raise SessionValidationError(
                    "Maximum active quests reached (2/2). "
                    "Complete or abandon a quest to accept a new one."
                )

            if quest_id not in session.game_state.active_quests:
                session.game_state.active_quests.append(quest_id)
                session.update_activity()
                return self.update_session(session)

            return session

        except SessionValidationError:
            raise
        except Exception as e:
            logger.error("Failed to add active quest",
                         session_id=session_id,
                         quest_id=quest_id,
                         error=str(e))
            raise

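    # Callers should be prepared for the 2-quest cap; a hedged sketch of the
    # intended handling (the `...` branch is application-specific):
    #
    #     try:
    #         service.add_active_quest(session_id, quest_id)
    #     except SessionValidationError:
    #         # Surface the "2/2 active quests" message to the player
    #         # instead of treating it as a server error.
    #         ...
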
    def remove_active_quest(self, session_id: str, quest_id: str) -> GameSession:
        """
        Remove a quest from the active quests list.

        Args:
            session_id: Session ID
            quest_id: Quest ID to remove

        Returns:
            Updated GameSession instance
        """
        try:
            session = self.get_session(session_id)

            if quest_id in session.game_state.active_quests:
                session.game_state.active_quests.remove(quest_id)
                session.update_activity()
                return self.update_session(session)

            return session

        except Exception as e:
            logger.error("Failed to remove active quest",
                         session_id=session_id,
                         quest_id=quest_id,
                         error=str(e))
            raise

    def add_world_event(self, session_id: str, event: dict) -> GameSession:
        """
        Add a world event to the session.

        Args:
            session_id: Session ID
            event: Event dictionary with type, description, etc.

        Returns:
            Updated GameSession instance
        """
        try:
            session = self.get_session(session_id)

            # Add timestamp if not present
            if 'timestamp' not in event:
                event['timestamp'] = datetime.now(timezone.utc).isoformat().replace("+00:00", "Z")

            session.game_state.world_events.append(event)
            session.update_activity()
            return self.update_session(session)

        except Exception as e:
            logger.error("Failed to add world event",
                         session_id=session_id,
                         error=str(e))
            raise


# Global instance for convenience
_service_instance: Optional[SessionService] = None


def get_session_service() -> SessionService:
    """
    Get the global SessionService instance.

    Returns:
        Singleton SessionService instance
    """
    global _service_instance
    if _service_instance is None:
        _service_instance = SessionService()
    return _service_instance
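
# A minimal usage sketch (identifiers like "sess_123" are illustrative):
#
#     service = get_session_service()
#     session = service.get_session("sess_123")
#     service.update_location("sess_123", "Riverside Inn", LocationType.TOWN)
#
# Because get_session_service() lazily constructs a single shared instance,
# repeated calls return the same object and reuse its database client.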
528
api/app/services/usage_tracking_service.py
Normal file
@@ -0,0 +1,528 @@
"""
Usage Tracking Service for AI cost and usage monitoring.

This service tracks all AI usage events, calculates costs, and provides
analytics for monitoring and rate limiting purposes.

Usage:
    from app.services.usage_tracking_service import UsageTrackingService

    tracker = UsageTrackingService()

    # Log a usage event
    tracker.log_usage(
        user_id="user_123",
        model="anthropic/claude-3.5-sonnet",
        tokens_input=100,
        tokens_output=350,
        task_type=TaskType.STORY_PROGRESSION
    )

    # Get daily usage
    usage = tracker.get_daily_usage("user_123", date.today())
    print(f"Total requests: {usage.total_requests}")
    print(f"Estimated cost: ${usage.estimated_cost:.4f}")
"""

import os
from datetime import datetime, timezone, date, timedelta
from typing import Dict, Any, List, Optional
from uuid import uuid4

from appwrite.client import Client
from appwrite.services.tables_db import TablesDB
from appwrite.exception import AppwriteException
from appwrite.id import ID
from appwrite.query import Query

from app.utils.logging import get_logger
from app.models.ai_usage import (
    AIUsageLog,
    DailyUsageSummary,
    MonthlyUsageSummary,
    TaskType
)

logger = get_logger(__file__)


# Cost per 1000 tokens by model (in USD)
# These are estimates based on Replicate pricing
MODEL_COSTS = {
    # Llama models (via Replicate) - very cheap
    "meta/meta-llama-3-8b-instruct": {
        "input": 0.0001,   # $0.0001 per 1K input tokens
        "output": 0.0001,  # $0.0001 per 1K output tokens
    },
    "meta/meta-llama-3-70b-instruct": {
        "input": 0.0006,
        "output": 0.0006,
    },
    # Claude models (via Replicate)
    "anthropic/claude-3.5-haiku": {
        "input": 0.001,   # $0.001 per 1K input tokens
        "output": 0.005,  # $0.005 per 1K output tokens
    },
    "anthropic/claude-3-haiku": {
        "input": 0.00025,
        "output": 0.00125,
    },
    "anthropic/claude-3.5-sonnet": {
        "input": 0.003,   # $0.003 per 1K input tokens
        "output": 0.015,  # $0.015 per 1K output tokens
    },
    "anthropic/claude-4.5-sonnet": {
        "input": 0.003,
        "output": 0.015,
    },
    "anthropic/claude-3-opus": {
        "input": 0.015,   # $0.015 per 1K input tokens
        "output": 0.075,  # $0.075 per 1K output tokens
    },
}

# Default cost for unknown models
DEFAULT_COST = {"input": 0.001, "output": 0.005}

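# Worked example of how this table is applied (using the estimated rates
# above, not authoritative Replicate pricing): a request to
# "anthropic/claude-3.5-sonnet" with 1,200 input and 400 output tokens costs
#   (1200 / 1000) * 0.003 + (400 / 1000) * 0.015 = 0.0036 + 0.0060 = 0.0096 USD.
# Models missing from MODEL_COSTS fall back to DEFAULT_COST.
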
class UsageTrackingService:
    """
    Service for tracking AI usage and calculating costs.

    This service provides:
    - Logging individual AI usage events to Appwrite
    - Calculating estimated costs based on model pricing
    - Retrieving daily and monthly usage summaries
    - Analytics for monitoring and rate limiting

    The service stores usage logs in an Appwrite collection named 'ai_usage_logs'.
    """

    # Collection ID for usage logs
    COLLECTION_ID = "ai_usage_logs"

    def __init__(self):
        """
        Initialize the usage tracking service.

        Reads configuration from environment variables:
        - APPWRITE_ENDPOINT: Appwrite API endpoint
        - APPWRITE_PROJECT_ID: Appwrite project ID
        - APPWRITE_API_KEY: Appwrite API key
        - APPWRITE_DATABASE_ID: Appwrite database ID

        Raises:
            ValueError: If required environment variables are missing
        """
        self.endpoint = os.getenv('APPWRITE_ENDPOINT')
        self.project_id = os.getenv('APPWRITE_PROJECT_ID')
        self.api_key = os.getenv('APPWRITE_API_KEY')
        self.database_id = os.getenv('APPWRITE_DATABASE_ID', 'main')

        if not all([self.endpoint, self.project_id, self.api_key]):
            logger.error("Missing Appwrite configuration in environment variables")
            raise ValueError("Appwrite configuration incomplete. Check APPWRITE_* environment variables.")

        # Initialize Appwrite client
        self.client = Client()
        self.client.set_endpoint(self.endpoint)
        self.client.set_project(self.project_id)
        self.client.set_key(self.api_key)

        # Initialize TablesDB service
        self.tables_db = TablesDB(self.client)

        logger.info("UsageTrackingService initialized", database_id=self.database_id)

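    # A typical environment setup for local development might look like the
    # following (values are placeholders; only the variable names come from
    # the constructor above):
    #
    #   export APPWRITE_ENDPOINT=https://cloud.appwrite.io/v1
    #   export APPWRITE_PROJECT_ID=your-project-id
    #   export APPWRITE_API_KEY=your-server-api-key
    #   export APPWRITE_DATABASE_ID=main
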
    def log_usage(
        self,
        user_id: str,
        model: str,
        tokens_input: int,
        tokens_output: int,
        task_type: TaskType,
        session_id: Optional[str] = None,
        character_id: Optional[str] = None,
        request_duration_ms: int = 0,
        success: bool = True,
        error_message: Optional[str] = None
    ) -> AIUsageLog:
        """
        Log an AI usage event.

        This method creates a new usage log entry in Appwrite with all
        relevant information about the AI request, including the calculated
        estimated cost.

        Args:
            user_id: User who made the request
            model: Model identifier (e.g., "anthropic/claude-3.5-sonnet")
            tokens_input: Number of input tokens (prompt)
            tokens_output: Number of output tokens (response)
            task_type: Type of task (story, combat, quest, npc)
            session_id: Optional game session ID
            character_id: Optional character ID
            request_duration_ms: Request duration in milliseconds
            success: Whether the request succeeded
            error_message: Error message if failed

        Returns:
            AIUsageLog with the logged data

        Raises:
            AppwriteException: If storage fails
        """
        # Calculate total tokens
        tokens_total = tokens_input + tokens_output

        # Calculate estimated cost
        estimated_cost = self._calculate_cost(model, tokens_input, tokens_output)

        # Generate log ID
        log_id = str(uuid4())

        # Create usage log
        usage_log = AIUsageLog(
            log_id=log_id,
            user_id=user_id,
            timestamp=datetime.now(timezone.utc),
            model=model,
            tokens_input=tokens_input,
            tokens_output=tokens_output,
            tokens_total=tokens_total,
            estimated_cost=estimated_cost,
            task_type=task_type,
            session_id=session_id,
            character_id=character_id,
            request_duration_ms=request_duration_ms,
            success=success,
            error_message=error_message,
        )

        try:
            # Store in Appwrite
            result = self.tables_db.create_row(
                database_id=self.database_id,
                table_id=self.COLLECTION_ID,
                row_id=log_id,
                data=usage_log.to_dict()
            )

            logger.info(
                "AI usage logged",
                log_id=log_id,
                user_id=user_id,
                model=model,
                tokens_total=tokens_total,
                estimated_cost=estimated_cost,
                task_type=task_type.value,
                success=success
            )

            return usage_log

        except AppwriteException as e:
            logger.error(
                "Failed to log AI usage",
                user_id=user_id,
                model=model,
                error=str(e),
                code=e.code
            )
            raise

    def get_daily_usage(self, user_id: str, target_date: date) -> DailyUsageSummary:
        """
        Get AI usage summary for a specific day.

        Args:
            user_id: User ID to get usage for
            target_date: Date to get usage for

        Returns:
            DailyUsageSummary with aggregated usage data

        Raises:
            AppwriteException: If query fails
        """
        try:
            # Build date range for the target day (UTC)
            start_of_day = datetime.combine(target_date, datetime.min.time()).replace(tzinfo=timezone.utc)
            end_of_day = datetime.combine(target_date, datetime.max.time()).replace(tzinfo=timezone.utc)

            # Query usage logs for this user and date
            result = self.tables_db.list_rows(
                database_id=self.database_id,
                table_id=self.COLLECTION_ID,
                queries=[
                    Query.equal("user_id", user_id),
                    Query.greater_than_equal("timestamp", start_of_day.isoformat()),
                    Query.less_than_equal("timestamp", end_of_day.isoformat()),
                    Query.limit(1000)  # Cap at 1000 entries per day
                ]
            )

            # Aggregate the data
            total_requests = 0
            total_tokens = 0
            total_input_tokens = 0
            total_output_tokens = 0
            total_cost = 0.0
            requests_by_task: Dict[str, int] = {}

            for doc in result['rows']:
                total_requests += 1
                total_tokens += doc.get('tokens_total', 0)
                total_input_tokens += doc.get('tokens_input', 0)
                total_output_tokens += doc.get('tokens_output', 0)
                total_cost += doc.get('estimated_cost', 0.0)

                task_type = doc.get('task_type', 'general')
                requests_by_task[task_type] = requests_by_task.get(task_type, 0) + 1

            summary = DailyUsageSummary(
                date=target_date,
                user_id=user_id,
                total_requests=total_requests,
                total_tokens=total_tokens,
                total_input_tokens=total_input_tokens,
                total_output_tokens=total_output_tokens,
                estimated_cost=total_cost,
                requests_by_task=requests_by_task
            )

            logger.debug(
                "Daily usage retrieved",
                user_id=user_id,
                date=target_date.isoformat(),
                total_requests=total_requests,
                estimated_cost=total_cost
            )

            return summary

        except AppwriteException as e:
            logger.error(
                "Failed to get daily usage",
                user_id=user_id,
                date=target_date.isoformat(),
                error=str(e),
                code=e.code
            )
            raise

    def get_monthly_cost(self, user_id: str, year: int, month: int) -> MonthlyUsageSummary:
        """
        Get AI usage cost summary for a specific month.

        Args:
            user_id: User ID to get cost for
            year: Year (e.g., 2025)
            month: Month (1-12)

        Returns:
            MonthlyUsageSummary with aggregated cost data

        Raises:
            AppwriteException: If query fails
            ValueError: If month is invalid
        """
        if not 1 <= month <= 12:
            raise ValueError(f"Invalid month: {month}. Must be 1-12.")

        try:
            # Build date range for the month
            start_of_month = datetime(year, month, 1, 0, 0, 0, tzinfo=timezone.utc)

            # Calculate end of month
            if month == 12:
                end_of_month = datetime(year + 1, 1, 1, 0, 0, 0, tzinfo=timezone.utc) - timedelta(seconds=1)
            else:
                end_of_month = datetime(year, month + 1, 1, 0, 0, 0, tzinfo=timezone.utc) - timedelta(seconds=1)

            # Query usage logs for this user and month
            result = self.tables_db.list_rows(
                database_id=self.database_id,
                table_id=self.COLLECTION_ID,
                queries=[
                    Query.equal("user_id", user_id),
                    Query.greater_than_equal("timestamp", start_of_month.isoformat()),
                    Query.less_than_equal("timestamp", end_of_month.isoformat()),
                    Query.limit(5000)  # Cap at 5000 entries per month
                ]
            )

            # Aggregate the data
            total_requests = 0
            total_tokens = 0
            total_cost = 0.0

            for doc in result['rows']:
                total_requests += 1
                total_tokens += doc.get('tokens_total', 0)
                total_cost += doc.get('estimated_cost', 0.0)

            summary = MonthlyUsageSummary(
                year=year,
                month=month,
                user_id=user_id,
                total_requests=total_requests,
                total_tokens=total_tokens,
                estimated_cost=total_cost
            )

            logger.debug(
                "Monthly cost retrieved",
                user_id=user_id,
                year=year,
                month=month,
                total_requests=total_requests,
                estimated_cost=total_cost
            )

            return summary

        except AppwriteException as e:
            logger.error(
                "Failed to get monthly cost",
                user_id=user_id,
                year=year,
                month=month,
                error=str(e),
                code=e.code
            )
            raise

    def get_total_daily_cost(self, target_date: date) -> float:
        """
        Get the total AI cost across all users for a specific day.

        Used for admin monitoring and alerting.

        Args:
            target_date: Date to get cost for

        Returns:
            Total estimated cost in USD

        Raises:
            AppwriteException: If query fails
        """
        try:
            # Build date range for the target day
            start_of_day = datetime.combine(target_date, datetime.min.time()).replace(tzinfo=timezone.utc)
            end_of_day = datetime.combine(target_date, datetime.max.time()).replace(tzinfo=timezone.utc)

            # Query all usage logs for this date
            result = self.tables_db.list_rows(
                database_id=self.database_id,
                table_id=self.COLLECTION_ID,
                queries=[
                    Query.greater_than_equal("timestamp", start_of_day.isoformat()),
                    Query.less_than_equal("timestamp", end_of_day.isoformat()),
                    Query.limit(10000)
                ]
            )

            # Sum up costs
            total_cost = sum(doc.get('estimated_cost', 0.0) for doc in result['rows'])

            logger.debug(
                "Total daily cost retrieved",
                date=target_date.isoformat(),
                total_cost=total_cost,
                total_documents=len(result['rows'])
            )

            return total_cost

        except AppwriteException as e:
            logger.error(
                "Failed to get total daily cost",
                date=target_date.isoformat(),
                error=str(e),
                code=e.code
            )
            raise

    def get_user_request_count_today(self, user_id: str) -> int:
        """
        Get the number of AI requests a user has made today.

        Used for rate limiting checks.

        Args:
            user_id: User ID to check

        Returns:
            Number of requests made today, or 0 if the query fails
        """
        try:
            summary = self.get_daily_usage(user_id, date.today())
            return summary.total_requests

        except AppwriteException:
            # On error, return 0 so users are not locked out (fail open)
            logger.warning(
                "Failed to get user request count, returning 0",
                user_id=user_id
            )
            return 0

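    # A hedged sketch of how a caller might use this for rate limiting;
    # DAILY_REQUEST_LIMIT and RateLimitExceeded are hypothetical names, not
    # defined in this module:
    #
    #     if tracker.get_user_request_count_today(user_id) >= DAILY_REQUEST_LIMIT:
    #         raise RateLimitExceeded(user_id)
    #
    # The fail-open behavior above means a storage outage will not block
    # users, at the cost of briefly under-counting requests.
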
    def _calculate_cost(self, model: str, tokens_input: int, tokens_output: int) -> float:
        """
        Calculate the estimated cost for an AI request.

        Args:
            model: Model identifier
            tokens_input: Number of input tokens
            tokens_output: Number of output tokens

        Returns:
            Estimated cost in USD
        """
        # Get cost per 1K tokens for this model
        model_cost = MODEL_COSTS.get(model, DEFAULT_COST)

        # Calculate cost (costs are per 1K tokens)
        input_cost = (tokens_input / 1000) * model_cost["input"]
        output_cost = (tokens_output / 1000) * model_cost["output"]
        total_cost = input_cost + output_cost

        return round(total_cost, 6)  # Round to 6 decimal places

    @staticmethod
    def estimate_cost_for_model(model: str, tokens_input: int, tokens_output: int) -> float:
        """
        Static method to estimate cost without needing a service instance.

        Useful for pre-calculation and UI display.

        Args:
            model: Model identifier
            tokens_input: Number of input tokens
            tokens_output: Number of output tokens

        Returns:
            Estimated cost in USD
        """
        model_cost = MODEL_COSTS.get(model, DEFAULT_COST)
        input_cost = (tokens_input / 1000) * model_cost["input"]
        output_cost = (tokens_output / 1000) * model_cost["output"]
        return round(input_cost + output_cost, 6)

    @staticmethod
    def get_model_cost_info(model: str) -> Dict[str, float]:
        """
        Get cost information for a model.

        Args:
            model: Model identifier

        Returns:
            Dictionary with 'input' and 'output' cost per 1K tokens
        """
        return MODEL_COSTS.get(model, DEFAULT_COST)
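
# A minimal, self-contained smoke test for the pure cost arithmetic above.
# It deliberately avoids constructing UsageTrackingService (which requires
# APPWRITE_* environment variables) and only exercises the static helper.
if __name__ == "__main__":
    cost = UsageTrackingService.estimate_cost_for_model(
        "anthropic/claude-3.5-sonnet",
        tokens_input=100,
        tokens_output=350,
    )
    # (100 / 1000) * 0.003 + (350 / 1000) * 0.015 = 0.0003 + 0.00525 = 0.00555
    print(f"Estimated cost: ${cost:.6f}")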