Files
Code_of_Conquest/api/app/ai/prompt_templates.py
2025-11-24 23:10:55 -06:00

319 lines
8.9 KiB
Python

"""
Jinja2 prompt template system for AI generation.
This module provides a templating system for building AI prompts
with consistent structure and context injection.
"""
import os
from pathlib import Path
from typing import Any
import structlog
from jinja2 import Environment, FileSystemLoader, select_autoescape
# Module-level structlog logger; PromptTemplates uses it for the
# initialization, template-directory-creation, render-debug and
# render-failure events.
logger = structlog.get_logger(__name__)
class PromptTemplateError(Exception):
    """Raised when a prompt template cannot be loaded or rendered.

    The rendering entry points raise this in place of whatever Jinja2 or
    filter error occurred, with that error's message included, so callers
    only need to handle this single exception type.
    """
class PromptTemplates:
"""
Manages Jinja2 templates for AI prompt generation.
Provides caching, helper functions, and consistent template rendering
for all AI prompt types.
"""
# Template directory relative to this module
TEMPLATE_DIR = Path(__file__).parent / "templates"
def __init__(self, template_dir: Path | str | None = None):
"""
Initialize the prompt template system.
Args:
template_dir: Optional custom template directory path.
"""
self.template_dir = Path(template_dir) if template_dir else self.TEMPLATE_DIR
# Ensure template directory exists
if not self.template_dir.exists():
self.template_dir.mkdir(parents=True, exist_ok=True)
logger.warning(
"Template directory created",
path=str(self.template_dir)
)
# Set up Jinja2 environment with caching
self.env = Environment(
loader=FileSystemLoader(str(self.template_dir)),
autoescape=select_autoescape(['html', 'xml']),
trim_blocks=True,
lstrip_blocks=True,
)
# Register custom filters
self._register_filters()
# Register custom globals
self._register_globals()
logger.info(
"PromptTemplates initialized",
template_dir=str(self.template_dir)
)
def _register_filters(self):
"""Register custom Jinja2 filters."""
self.env.filters['format_inventory'] = self._format_inventory
self.env.filters['format_stats'] = self._format_stats
self.env.filters['format_skills'] = self._format_skills
self.env.filters['format_effects'] = self._format_effects
self.env.filters['truncate_text'] = self._truncate_text
self.env.filters['format_gold'] = self._format_gold
def _register_globals(self):
"""Register global functions available in templates."""
self.env.globals['len'] = len
self.env.globals['min'] = min
self.env.globals['max'] = max
self.env.globals['enumerate'] = enumerate
# Custom filters
@staticmethod
def _format_inventory(items: list[dict], max_items: int = 10) -> str:
"""
Format inventory items for prompt context.
Args:
items: List of item dictionaries with 'name' and 'quantity'.
max_items: Maximum number of items to display.
Returns:
Formatted inventory string.
"""
if not items:
return "Empty inventory"
formatted = []
for item in items[:max_items]:
name = item.get('name', 'Unknown')
qty = item.get('quantity', 1)
if qty > 1:
formatted.append(f"{name} (x{qty})")
else:
formatted.append(name)
result = ", ".join(formatted)
if len(items) > max_items:
result += f", and {len(items) - max_items} more items"
return result
@staticmethod
def _format_stats(stats: dict) -> str:
"""
Format character stats for prompt context.
Args:
stats: Dictionary of stat names to values.
Returns:
Formatted stats string.
"""
if not stats:
return "No stats available"
formatted = []
for stat, value in stats.items():
# Convert snake_case to Title Case
display_name = stat.replace('_', ' ').title()
formatted.append(f"{display_name}: {value}")
return ", ".join(formatted)
@staticmethod
def _format_skills(skills: list[dict], max_skills: int = 5) -> str:
"""
Format character skills for prompt context.
Args:
skills: List of skill dictionaries with 'name' and 'level'.
max_skills: Maximum number of skills to display.
Returns:
Formatted skills string.
"""
if not skills:
return "No skills"
formatted = []
for skill in skills[:max_skills]:
name = skill.get('name', 'Unknown')
level = skill.get('level', 1)
formatted.append(f"{name} (Lv.{level})")
result = ", ".join(formatted)
if len(skills) > max_skills:
result += f", and {len(skills) - max_skills} more skills"
return result
@staticmethod
def _format_effects(effects: list[dict]) -> str:
"""
Format active effects/buffs/debuffs for prompt context.
Args:
effects: List of effect dictionaries.
Returns:
Formatted effects string.
"""
if not effects:
return "No active effects"
formatted = []
for effect in effects:
name = effect.get('name', 'Unknown')
duration = effect.get('remaining_turns')
if duration:
formatted.append(f"{name} ({duration} turns)")
else:
formatted.append(name)
return ", ".join(formatted)
@staticmethod
def _truncate_text(text: str, max_length: int = 100) -> str:
"""
Truncate text to maximum length with ellipsis.
Args:
text: Text to truncate.
max_length: Maximum character length.
Returns:
Truncated text with ellipsis if needed.
"""
if len(text) <= max_length:
return text
return text[:max_length - 3] + "..."
@staticmethod
def _format_gold(amount: int) -> str:
"""
Format gold amount with commas.
Args:
amount: Gold amount.
Returns:
Formatted gold string.
"""
return f"{amount:,} gold"
def render(self, template_name: str, **context: Any) -> str:
"""
Render a template with the given context.
Args:
template_name: Name of the template file (e.g., 'story_action.j2').
**context: Variables to pass to the template.
Returns:
Rendered template string.
Raises:
PromptTemplateError: If template not found or rendering fails.
"""
try:
template = self.env.get_template(template_name)
rendered = template.render(**context)
logger.debug(
"Template rendered",
template=template_name,
context_keys=list(context.keys()),
output_length=len(rendered)
)
return rendered.strip()
except Exception as e:
logger.error(
"Template rendering failed",
template=template_name,
error=str(e)
)
raise PromptTemplateError(f"Failed to render {template_name}: {e}")
def render_string(self, template_string: str, **context: Any) -> str:
"""
Render a template string directly.
Args:
template_string: Jinja2 template string.
**context: Variables to pass to the template.
Returns:
Rendered string.
Raises:
PromptTemplateError: If rendering fails.
"""
try:
template = self.env.from_string(template_string)
rendered = template.render(**context)
return rendered.strip()
except Exception as e:
logger.error(
"String template rendering failed",
error=str(e)
)
raise PromptTemplateError(f"Failed to render template string: {e}")
def get_template_names(self) -> list[str]:
"""
Get list of available template names.
Returns:
List of template file names.
"""
return self.env.list_templates(extensions=['j2'])
# Lazily-created module-wide PromptTemplates instance. It stays None until
# get_prompt_templates() (also used by render_prompt()) first builds it.
_templates: PromptTemplates | None = None
def get_prompt_templates() -> PromptTemplates:
    """
    Return the shared PromptTemplates instance, creating it on first call.

    The instance is built with the default template directory and cached in
    the module-level ``_templates`` variable. No locking is performed
    around the lazy creation.

    Returns:
        The module-wide PromptTemplates instance.
    """
    global _templates
    if _templates is not None:
        return _templates
    _templates = PromptTemplates()
    return _templates
def render_prompt(template_name: str, **context: Any) -> str:
    """
    Render a named template using the shared PromptTemplates instance.

    Args:
        template_name: File name of the template to render.
        **context: Variables exposed to the template.

    Returns:
        The rendered prompt text, as returned by PromptTemplates.render().
    """
    templates = get_prompt_templates()
    return templates.render(template_name, **context)