321 lines
8.9 KiB
Python
321 lines
8.9 KiB
Python
"""
|
|
Action Prompt Loader Service
|
|
|
|
This module provides a service for loading and filtering action prompts from YAML.
|
|
It implements a singleton pattern to cache loaded prompts in memory.
|
|
|
|
Usage:
|
|
from app.services.action_prompt_loader import ActionPromptLoader
|
|
|
|
loader = ActionPromptLoader()
|
|
loader.load_from_yaml("app/data/action_prompts.yaml")
|
|
|
|
# Get available actions for a user at a location
|
|
actions = loader.get_available_actions(
|
|
user_tier=UserTier.FREE,
|
|
location_type=LocationType.TOWN
|
|
)
|
|
|
|
# Get specific action
|
|
action = loader.get_action_by_id("ask_locals")
|
|
"""
|
|
|
|
import os
|
|
from typing import List, Optional, Dict
|
|
|
|
import yaml
|
|
|
|
from app.models.action_prompt import ActionPrompt, LocationType
|
|
from app.ai.model_selector import UserTier
|
|
from app.utils.logging import get_logger
|
|
|
|
|
|
# Initialize logger
|
|
# Module-level logger shared by every class and function in this file.
# NOTE(review): this passes __file__ (the module's filesystem path) rather than
# the more conventional __name__ -- confirm that get_logger expects a path.
logger = get_logger(__file__)
|
|
|
|
|
|
class ActionPromptLoaderError(Exception):
    """
    Root of the exception hierarchy raised by the action prompt loader.

    Raised directly when a prompt file is missing, unreadable, or malformed;
    more specific loader errors derive from it.
    """
|
|
|
|
|
|
class ActionPromptNotFoundError(ActionPromptLoaderError):
    """
    Signals that a lookup asked for an action prompt that is not loaded.

    Derives from ActionPromptLoaderError, so handlers for the base loader
    error also catch this one.
    """
|
|
|
|
|
|
class ActionPromptLoader:
    """
    Service for loading and filtering action prompts.

    This class loads action prompts from YAML files and provides methods
    to filter them based on user tier and location type.

    Uses singleton pattern to cache loaded prompts in memory: every call to
    ``ActionPromptLoader()`` returns the same object until ``reset_instance()``
    is called.

    Attributes:
        _instance: The shared singleton instance, or None before first use
        _prompts: Dictionary of loaded action prompts keyed by prompt_id
        _loaded: Flag indicating if prompts have been loaded
    """

    _instance = None
    _prompts: Dict[str, ActionPrompt] = {}
    _loaded: bool = False

    def __new__(cls):
        """Implement singleton pattern: create the instance once, then reuse it."""
        if cls._instance is None:
            instance = super().__new__(cls)
            # Give the instance its own state so the class-level defaults
            # above are never mutated.
            instance._prompts = {}
            instance._loaded = False
            cls._instance = instance
        return cls._instance

    def load_from_yaml(self, filepath: str) -> int:
        """
        Load action prompts from a YAML file.

        The file must contain a mapping with an ``action_prompts`` key whose
        value is a list of prompt definitions. Entries that fail to parse are
        logged and skipped; the remaining entries are still loaded.

        The previously loaded prompts are replaced only after the whole file
        has been processed, so a failure never leaves the loader with a
        partially populated set.

        Args:
            filepath: Path to the YAML file

        Returns:
            Number of prompts loaded

        Raises:
            ActionPromptLoaderError: If file cannot be read or parsed, or if
                it does not have the expected top-level structure
        """
        if not os.path.exists(filepath):
            logger.error("Action prompts file not found", filepath=filepath)
            raise ActionPromptLoaderError(f"File not found: {filepath}")

        try:
            with open(filepath, 'r', encoding='utf-8') as f:
                data = yaml.safe_load(f)
        except yaml.YAMLError as e:
            logger.error("Failed to parse YAML", filepath=filepath, error=str(e))
            raise ActionPromptLoaderError(f"Invalid YAML in {filepath}: {e}") from e
        except (IOError, UnicodeDecodeError) as e:
            # UnicodeDecodeError is a ValueError, not an IOError, so it has to
            # be listed explicitly to honour the documented contract.
            logger.error("Failed to read file", filepath=filepath, error=str(e))
            raise ActionPromptLoaderError(f"Cannot read {filepath}: {e}") from e

        # An empty file yields None and a scalar/list root is not a mapping;
        # all of those are treated as "key missing".
        if not isinstance(data, dict) or 'action_prompts' not in data:
            logger.error("No action_prompts key in YAML", filepath=filepath)
            raise ActionPromptLoaderError(f"Missing 'action_prompts' key in {filepath}")

        prompts_data = data['action_prompts']
        if not isinstance(prompts_data, list):
            logger.error(
                "action_prompts is not a list",
                filepath=filepath,
                actual_type=type(prompts_data).__name__
            )
            raise ActionPromptLoaderError(
                f"'action_prompts' must be a list in {filepath}"
            )

        # Parse into a local dict first; the live cache is swapped in one step
        # at the end so readers never observe a half-loaded state.
        parsed: Dict[str, ActionPrompt] = {}
        errors: List[str] = []

        for i, prompt_data in enumerate(prompts_data):
            try:
                prompt = ActionPrompt.from_dict(prompt_data)
            except (ValueError, KeyError, TypeError, AttributeError) as e:
                # TypeError/AttributeError cover entries that are not mappings
                # at all (e.g. a bare string or null list item).
                errors.append(f"Prompt {i}: {e}")
                logger.warning(
                    "Failed to parse action prompt",
                    index=i,
                    error=str(e)
                )
                continue

            if prompt.prompt_id in parsed:
                # Later entries win, but a duplicate is worth surfacing.
                logger.warning(
                    "Duplicate action prompt id, overriding earlier entry",
                    index=i,
                    prompt_id=prompt.prompt_id
                )
            parsed[prompt.prompt_id] = prompt

        if errors:
            logger.warning(
                "Some action prompts failed to load",
                error_count=len(errors),
                errors=errors
            )

        self._prompts = parsed
        self._loaded = True
        loaded_count = len(parsed)

        logger.info(
            "Action prompts loaded",
            filepath=filepath,
            count=loaded_count,
            errors=len(errors)
        )

        return loaded_count

    def get_all_actions(self) -> List[ActionPrompt]:
        """
        Get all loaded action prompts.

        Returns:
            List of all action prompts (a new list on every call)

        Raises:
            ActionPromptLoaderError: If prompts are not loaded and cannot be
                auto-loaded from the default path
        """
        self._ensure_loaded()
        return list(self._prompts.values())

    def get_action_by_id(self, prompt_id: str) -> ActionPrompt:
        """
        Get a specific action prompt by ID.

        Args:
            prompt_id: The unique identifier of the action

        Returns:
            The ActionPrompt object

        Raises:
            ActionPromptNotFoundError: If action not found
            ActionPromptLoaderError: If prompts are not loaded and cannot be
                auto-loaded from the default path
        """
        self._ensure_loaded()

        try:
            return self._prompts[prompt_id]
        except KeyError:
            logger.warning("Action prompt not found", prompt_id=prompt_id)
            # The KeyError adds nothing for callers, so suppress it.
            raise ActionPromptNotFoundError(
                f"Action prompt '{prompt_id}' not found"
            ) from None

    def get_available_actions(
        self,
        user_tier: UserTier,
        location_type: LocationType
    ) -> List[ActionPrompt]:
        """
        Get actions available to a user at a specific location.

        Availability is decided by each prompt's ``is_available`` method.

        Args:
            user_tier: The user's subscription tier
            location_type: The current location type

        Returns:
            List of available action prompts

        Raises:
            ActionPromptLoaderError: If prompts are not loaded and cannot be
                auto-loaded from the default path
        """
        self._ensure_loaded()

        available = [
            prompt for prompt in self._prompts.values()
            if prompt.is_available(user_tier, location_type)
        ]

        logger.debug(
            "Filtered available actions",
            user_tier=user_tier.value,
            location_type=location_type.value,
            count=len(available)
        )

        return available

    def get_actions_by_tier(self, user_tier: UserTier) -> List[ActionPrompt]:
        """
        Get all actions available to a user tier (ignoring location).

        Args:
            user_tier: The user's subscription tier

        Returns:
            List of action prompts available to the tier

        Raises:
            ActionPromptLoaderError: If prompts are not loaded and cannot be
                auto-loaded from the default path
        """
        self._ensure_loaded()

        return [
            prompt for prompt in self._prompts.values()
            if prompt._tier_meets_requirement(user_tier)
        ]

    def get_actions_by_category(self, category: str) -> List[ActionPrompt]:
        """
        Get all actions in a specific category.

        Args:
            category: The action category (e.g., "ask_question", "explore"),
                compared against each prompt's ``category.value``

        Returns:
            List of action prompts in the category

        Raises:
            ActionPromptLoaderError: If prompts are not loaded and cannot be
                auto-loaded from the default path
        """
        self._ensure_loaded()

        return [
            prompt for prompt in self._prompts.values()
            if prompt.category.value == category
        ]

    def get_locked_actions(
        self,
        user_tier: UserTier,
        location_type: LocationType
    ) -> List[ActionPrompt]:
        """
        Get actions that are locked due to tier restrictions.

        Used to show locked actions with upgrade prompts in UI.

        Args:
            user_tier: The user's subscription tier
            location_type: The current location type

        Returns:
            List of locked action prompts

        Raises:
            ActionPromptLoaderError: If prompts are not loaded and cannot be
                auto-loaded from the default path
        """
        self._ensure_loaded()

        # Must match location but be tier-locked
        return [
            prompt for prompt in self._prompts.values()
            if prompt._location_matches_filter(location_type)
            and prompt.is_locked(user_tier)
        ]

    def reload(self, filepath: str) -> int:
        """
        Force reload prompts from YAML file.

        If the reload fails, the previously loaded prompts stay in place and
        the loader remains marked as loaded; the error is propagated.

        Args:
            filepath: Path to the YAML file

        Returns:
            Number of prompts loaded

        Raises:
            ActionPromptLoaderError: If file cannot be read or parsed
        """
        # load_from_yaml replaces the cache only on success, so the loaded
        # flag does not need to be cleared first.
        return self.load_from_yaml(filepath)

    def is_loaded(self) -> bool:
        """Check if prompts have been loaded."""
        return self._loaded

    def get_prompt_count(self) -> int:
        """Get the number of loaded prompts."""
        return len(self._prompts)

    def _ensure_loaded(self) -> None:
        """
        Ensure prompts are loaded, auto-load from default path if not.

        The default path is ``../data/action_prompts.yaml`` relative to the
        directory containing this module.

        Raises:
            ActionPromptLoaderError: If prompts cannot be loaded
        """
        if self._loaded:
            return

        # Try default path
        default_path = os.path.normpath(
            os.path.join(
                os.path.dirname(__file__),
                '..', 'data', 'action_prompts.yaml'
            )
        )

        if not os.path.exists(default_path):
            raise ActionPromptLoaderError(
                "Action prompts not loaded. Call load_from_yaml() first."
            )

        self.load_from_yaml(default_path)

    @classmethod
    def reset_instance(cls) -> None:
        """
        Reset the singleton instance.

        The next ``ActionPromptLoader()`` call creates a fresh, empty loader.

        Primarily for testing purposes.
        """
        cls._instance = None
|