first commit
This commit is contained in:
318
api/app/ai/prompt_templates.py
Normal file
318
api/app/ai/prompt_templates.py
Normal file
@@ -0,0 +1,318 @@
|
||||
"""
|
||||
Jinja2 prompt template system for AI generation.
|
||||
|
||||
This module provides a templating system for building AI prompts
|
||||
with consistent structure and context injection.
|
||||
"""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import structlog
|
||||
from jinja2 import Environment, FileSystemLoader, select_autoescape
|
||||
|
||||
logger = structlog.get_logger(__name__)
|
||||
|
||||
|
||||
class PromptTemplateError(Exception):
|
||||
"""Error in prompt template processing."""
|
||||
pass
|
||||
|
||||
|
||||
class PromptTemplates:
|
||||
"""
|
||||
Manages Jinja2 templates for AI prompt generation.
|
||||
|
||||
Provides caching, helper functions, and consistent template rendering
|
||||
for all AI prompt types.
|
||||
"""
|
||||
|
||||
# Template directory relative to this module
|
||||
TEMPLATE_DIR = Path(__file__).parent / "templates"
|
||||
|
||||
def __init__(self, template_dir: Path | str | None = None):
|
||||
"""
|
||||
Initialize the prompt template system.
|
||||
|
||||
Args:
|
||||
template_dir: Optional custom template directory path.
|
||||
"""
|
||||
self.template_dir = Path(template_dir) if template_dir else self.TEMPLATE_DIR
|
||||
|
||||
# Ensure template directory exists
|
||||
if not self.template_dir.exists():
|
||||
self.template_dir.mkdir(parents=True, exist_ok=True)
|
||||
logger.warning(
|
||||
"Template directory created",
|
||||
path=str(self.template_dir)
|
||||
)
|
||||
|
||||
# Set up Jinja2 environment with caching
|
||||
self.env = Environment(
|
||||
loader=FileSystemLoader(str(self.template_dir)),
|
||||
autoescape=select_autoescape(['html', 'xml']),
|
||||
trim_blocks=True,
|
||||
lstrip_blocks=True,
|
||||
)
|
||||
|
||||
# Register custom filters
|
||||
self._register_filters()
|
||||
|
||||
# Register custom globals
|
||||
self._register_globals()
|
||||
|
||||
logger.info(
|
||||
"PromptTemplates initialized",
|
||||
template_dir=str(self.template_dir)
|
||||
)
|
||||
|
||||
def _register_filters(self):
|
||||
"""Register custom Jinja2 filters."""
|
||||
self.env.filters['format_inventory'] = self._format_inventory
|
||||
self.env.filters['format_stats'] = self._format_stats
|
||||
self.env.filters['format_skills'] = self._format_skills
|
||||
self.env.filters['format_effects'] = self._format_effects
|
||||
self.env.filters['truncate_text'] = self._truncate_text
|
||||
self.env.filters['format_gold'] = self._format_gold
|
||||
|
||||
def _register_globals(self):
|
||||
"""Register global functions available in templates."""
|
||||
self.env.globals['len'] = len
|
||||
self.env.globals['min'] = min
|
||||
self.env.globals['max'] = max
|
||||
self.env.globals['enumerate'] = enumerate
|
||||
|
||||
# Custom filters
|
||||
@staticmethod
|
||||
def _format_inventory(items: list[dict], max_items: int = 10) -> str:
|
||||
"""
|
||||
Format inventory items for prompt context.
|
||||
|
||||
Args:
|
||||
items: List of item dictionaries with 'name' and 'quantity'.
|
||||
max_items: Maximum number of items to display.
|
||||
|
||||
Returns:
|
||||
Formatted inventory string.
|
||||
"""
|
||||
if not items:
|
||||
return "Empty inventory"
|
||||
|
||||
formatted = []
|
||||
for item in items[:max_items]:
|
||||
name = item.get('name', 'Unknown')
|
||||
qty = item.get('quantity', 1)
|
||||
if qty > 1:
|
||||
formatted.append(f"{name} (x{qty})")
|
||||
else:
|
||||
formatted.append(name)
|
||||
|
||||
result = ", ".join(formatted)
|
||||
if len(items) > max_items:
|
||||
result += f", and {len(items) - max_items} more items"
|
||||
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def _format_stats(stats: dict) -> str:
|
||||
"""
|
||||
Format character stats for prompt context.
|
||||
|
||||
Args:
|
||||
stats: Dictionary of stat names to values.
|
||||
|
||||
Returns:
|
||||
Formatted stats string.
|
||||
"""
|
||||
if not stats:
|
||||
return "No stats available"
|
||||
|
||||
formatted = []
|
||||
for stat, value in stats.items():
|
||||
# Convert snake_case to Title Case
|
||||
display_name = stat.replace('_', ' ').title()
|
||||
formatted.append(f"{display_name}: {value}")
|
||||
|
||||
return ", ".join(formatted)
|
||||
|
||||
@staticmethod
|
||||
def _format_skills(skills: list[dict], max_skills: int = 5) -> str:
|
||||
"""
|
||||
Format character skills for prompt context.
|
||||
|
||||
Args:
|
||||
skills: List of skill dictionaries with 'name' and 'level'.
|
||||
max_skills: Maximum number of skills to display.
|
||||
|
||||
Returns:
|
||||
Formatted skills string.
|
||||
"""
|
||||
if not skills:
|
||||
return "No skills"
|
||||
|
||||
formatted = []
|
||||
for skill in skills[:max_skills]:
|
||||
name = skill.get('name', 'Unknown')
|
||||
level = skill.get('level', 1)
|
||||
formatted.append(f"{name} (Lv.{level})")
|
||||
|
||||
result = ", ".join(formatted)
|
||||
if len(skills) > max_skills:
|
||||
result += f", and {len(skills) - max_skills} more skills"
|
||||
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def _format_effects(effects: list[dict]) -> str:
|
||||
"""
|
||||
Format active effects/buffs/debuffs for prompt context.
|
||||
|
||||
Args:
|
||||
effects: List of effect dictionaries.
|
||||
|
||||
Returns:
|
||||
Formatted effects string.
|
||||
"""
|
||||
if not effects:
|
||||
return "No active effects"
|
||||
|
||||
formatted = []
|
||||
for effect in effects:
|
||||
name = effect.get('name', 'Unknown')
|
||||
duration = effect.get('remaining_turns')
|
||||
if duration:
|
||||
formatted.append(f"{name} ({duration} turns)")
|
||||
else:
|
||||
formatted.append(name)
|
||||
|
||||
return ", ".join(formatted)
|
||||
|
||||
@staticmethod
|
||||
def _truncate_text(text: str, max_length: int = 100) -> str:
|
||||
"""
|
||||
Truncate text to maximum length with ellipsis.
|
||||
|
||||
Args:
|
||||
text: Text to truncate.
|
||||
max_length: Maximum character length.
|
||||
|
||||
Returns:
|
||||
Truncated text with ellipsis if needed.
|
||||
"""
|
||||
if len(text) <= max_length:
|
||||
return text
|
||||
return text[:max_length - 3] + "..."
|
||||
|
||||
@staticmethod
|
||||
def _format_gold(amount: int) -> str:
|
||||
"""
|
||||
Format gold amount with commas.
|
||||
|
||||
Args:
|
||||
amount: Gold amount.
|
||||
|
||||
Returns:
|
||||
Formatted gold string.
|
||||
"""
|
||||
return f"{amount:,} gold"
|
||||
|
||||
def render(self, template_name: str, **context: Any) -> str:
|
||||
"""
|
||||
Render a template with the given context.
|
||||
|
||||
Args:
|
||||
template_name: Name of the template file (e.g., 'story_action.j2').
|
||||
**context: Variables to pass to the template.
|
||||
|
||||
Returns:
|
||||
Rendered template string.
|
||||
|
||||
Raises:
|
||||
PromptTemplateError: If template not found or rendering fails.
|
||||
"""
|
||||
try:
|
||||
template = self.env.get_template(template_name)
|
||||
rendered = template.render(**context)
|
||||
|
||||
logger.debug(
|
||||
"Template rendered",
|
||||
template=template_name,
|
||||
context_keys=list(context.keys()),
|
||||
output_length=len(rendered)
|
||||
)
|
||||
|
||||
return rendered.strip()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Template rendering failed",
|
||||
template=template_name,
|
||||
error=str(e)
|
||||
)
|
||||
raise PromptTemplateError(f"Failed to render {template_name}: {e}")
|
||||
|
||||
def render_string(self, template_string: str, **context: Any) -> str:
|
||||
"""
|
||||
Render a template string directly.
|
||||
|
||||
Args:
|
||||
template_string: Jinja2 template string.
|
||||
**context: Variables to pass to the template.
|
||||
|
||||
Returns:
|
||||
Rendered string.
|
||||
|
||||
Raises:
|
||||
PromptTemplateError: If rendering fails.
|
||||
"""
|
||||
try:
|
||||
template = self.env.from_string(template_string)
|
||||
rendered = template.render(**context)
|
||||
return rendered.strip()
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"String template rendering failed",
|
||||
error=str(e)
|
||||
)
|
||||
raise PromptTemplateError(f"Failed to render template string: {e}")
|
||||
|
||||
def get_template_names(self) -> list[str]:
|
||||
"""
|
||||
Get list of available template names.
|
||||
|
||||
Returns:
|
||||
List of template file names.
|
||||
"""
|
||||
return self.env.list_templates(extensions=['j2'])
|
||||
|
||||
|
||||
# Global instance for convenience
|
||||
_templates: PromptTemplates | None = None
|
||||
|
||||
|
||||
def get_prompt_templates() -> PromptTemplates:
|
||||
"""
|
||||
Get the global PromptTemplates instance.
|
||||
|
||||
Returns:
|
||||
Singleton PromptTemplates instance.
|
||||
"""
|
||||
global _templates
|
||||
if _templates is None:
|
||||
_templates = PromptTemplates()
|
||||
return _templates
|
||||
|
||||
|
||||
def render_prompt(template_name: str, **context: Any) -> str:
|
||||
"""
|
||||
Convenience function to render a prompt template.
|
||||
|
||||
Args:
|
||||
template_name: Name of the template file.
|
||||
**context: Variables to pass to the template.
|
||||
|
||||
Returns:
|
||||
Rendered template string.
|
||||
"""
|
||||
return get_prompt_templates().render(template_name, **context)
|
||||
Reference in New Issue
Block a user