"""
Jinja2 prompt template system for AI generation.

This module provides a templating system for building AI prompts with
consistent structure and context injection.
"""

import os
from pathlib import Path
from typing import Any

import structlog
from jinja2 import Environment, FileSystemLoader, select_autoescape

logger = structlog.get_logger(__name__)


class PromptTemplateError(Exception):
    """Error in prompt template processing."""

    pass


class PromptTemplates:
    """
    Manages Jinja2 templates for AI prompt generation.

    Wraps a Jinja2 ``Environment`` that loads templates from a directory,
    registers prompt-specific filters and globals, and exposes rendering
    helpers that wrap failures in ``PromptTemplateError``.
    """

    # Template directory relative to this module
    TEMPLATE_DIR = Path(__file__).parent / "templates"

    def __init__(self, template_dir: Path | str | None = None):
        """
        Initialize the prompt template system.

        Args:
            template_dir: Optional custom template directory path. When
                falsy, ``TEMPLATE_DIR`` is used. The directory is created
                (with parents) if it does not exist.
        """
        self.template_dir = Path(template_dir) if template_dir else self.TEMPLATE_DIR

        # Ensure template directory exists so the loader has a valid root.
        if not self.template_dir.exists():
            self.template_dir.mkdir(parents=True, exist_ok=True)
            logger.warning(
                "Template directory created",
                path=str(self.template_dir)
            )

        # Prompts are plain text: autoescaping is only enabled for .html/.xml
        # templates. default_for_string=False matters because
        # select_autoescape() otherwise escapes templates built with
        # from_string() (used by render_string), turning e.g. "'" into "&#39;"
        # and making string prompts differ from .j2 file prompts.
        self.env = Environment(
            loader=FileSystemLoader(str(self.template_dir)),
            autoescape=select_autoescape(
                enabled_extensions=('html', 'xml'),
                default_for_string=False,
            ),
            trim_blocks=True,
            lstrip_blocks=True,
        )

        # Register custom filters
        self._register_filters()

        # Register custom globals
        self._register_globals()

        logger.info(
            "PromptTemplates initialized",
            template_dir=str(self.template_dir)
        )

    def _register_filters(self):
        """Register custom Jinja2 filters."""
        self.env.filters['format_inventory'] = self._format_inventory
        self.env.filters['format_stats'] = self._format_stats
        self.env.filters['format_skills'] = self._format_skills
        self.env.filters['format_effects'] = self._format_effects
        self.env.filters['truncate_text'] = self._truncate_text
        self.env.filters['format_gold'] = self._format_gold

    def _register_globals(self):
        """Register global functions available in templates."""
        self.env.globals['len'] = len
        self.env.globals['min'] = min
        self.env.globals['max'] = max
        self.env.globals['enumerate'] = enumerate

    # Custom filters

    @staticmethod
    def _format_inventory(items: list[dict], max_items: int = 10) -> str:
        """
        Format inventory items for prompt context.

        Args:
            items: List of item dictionaries with 'name' and 'quantity'.
            max_items: Maximum number of items to display.

        Returns:
            Formatted inventory string, e.g. "Sword, Potion (x3)", with a
            trailing ", and N more items" when the list is truncated.
        """
        if not items:
            return "Empty inventory"

        formatted = []
        for item in items[:max_items]:
            name = item.get('name', 'Unknown')
            # A missing or explicit None quantity is treated as a single item
            # (comparing None > 1 would raise TypeError).
            qty = item.get('quantity')
            if qty is None:
                qty = 1
            if qty > 1:
                formatted.append(f"{name} (x{qty})")
            else:
                formatted.append(name)

        result = ", ".join(formatted)
        if len(items) > max_items:
            result += f", and {len(items) - max_items} more items"

        return result

    @staticmethod
    def _format_stats(stats: dict) -> str:
        """
        Format character stats for prompt context.

        Args:
            stats: Dictionary of stat names to values.

        Returns:
            Formatted stats string, e.g. "Max Hp: 10, Strength: 5".
        """
        if not stats:
            return "No stats available"

        formatted = []
        for stat, value in stats.items():
            # Convert snake_case to Title Case
            display_name = stat.replace('_', ' ').title()
            formatted.append(f"{display_name}: {value}")

        return ", ".join(formatted)

    @staticmethod
    def _format_skills(skills: list[dict], max_skills: int = 5) -> str:
        """
        Format character skills for prompt context.

        Args:
            skills: List of skill dictionaries with 'name' and 'level'.
            max_skills: Maximum number of skills to display.

        Returns:
            Formatted skills string, e.g. "Fireball (Lv.3)", with a trailing
            ", and N more skills" when the list is truncated.
        """
        if not skills:
            return "No skills"

        formatted = []
        for skill in skills[:max_skills]:
            name = skill.get('name', 'Unknown')
            level = skill.get('level', 1)
            formatted.append(f"{name} (Lv.{level})")

        result = ", ".join(formatted)
        if len(skills) > max_skills:
            result += f", and {len(skills) - max_skills} more skills"

        return result

    @staticmethod
    def _format_effects(effects: list[dict]) -> str:
        """
        Format active effects/buffs/debuffs for prompt context.

        Args:
            effects: List of effect dictionaries with 'name' and optionally
                'remaining_turns'.

        Returns:
            Formatted effects string; a falsy 'remaining_turns' omits the
            turn count.
        """
        if not effects:
            return "No active effects"

        formatted = []
        for effect in effects:
            name = effect.get('name', 'Unknown')
            duration = effect.get('remaining_turns')
            if duration:
                formatted.append(f"{name} ({duration} turns)")
            else:
                formatted.append(name)

        return ", ".join(formatted)

    @staticmethod
    def _truncate_text(text: str, max_length: int = 100) -> str:
        """
        Truncate text to maximum length with ellipsis.

        Args:
            text: Text to truncate.
            max_length: Maximum character length of the result.

        Returns:
            ``text`` unchanged if it fits; otherwise a string of at most
            ``max_length`` characters ending in "..." when there is room
            for the ellipsis, or a hard cut when ``max_length < 3``.
        """
        if len(text) <= max_length:
            return text
        if max_length < 3:
            # No room for "..."; a negative slice here would previously have
            # produced output longer than max_length.
            return text[:max(max_length, 0)]
        return text[:max_length - 3] + "..."

    @staticmethod
    def _format_gold(amount: int) -> str:
        """
        Format gold amount with commas.

        Args:
            amount: Gold amount.

        Returns:
            Formatted gold string, e.g. "1,234 gold".
        """
        return f"{amount:,} gold"

    def render(self, template_name: str, **context: Any) -> str:
        """
        Render a template with the given context.

        Args:
            template_name: Name of the template file (e.g., 'story_action.j2').
            **context: Variables to pass to the template.

        Returns:
            Rendered template string with surrounding whitespace stripped.

        Raises:
            PromptTemplateError: If template not found or rendering fails.
        """
        try:
            template = self.env.get_template(template_name)
            rendered = template.render(**context)
        except Exception as e:
            # Boundary: any loader/render/filter error is logged and wrapped,
            # keeping the original exception as the explicit cause.
            logger.error(
                "Template rendering failed",
                template=template_name,
                error=str(e)
            )
            raise PromptTemplateError(f"Failed to render {template_name}: {e}") from e

        logger.debug(
            "Template rendered",
            template=template_name,
            context_keys=list(context.keys()),
            output_length=len(rendered)
        )

        return rendered.strip()

    def render_string(self, template_string: str, **context: Any) -> str:
        """
        Render a template string directly.

        Args:
            template_string: Jinja2 template string.
            **context: Variables to pass to the template.

        Returns:
            Rendered string with surrounding whitespace stripped.

        Raises:
            PromptTemplateError: If rendering fails.
        """
        try:
            template = self.env.from_string(template_string)
            rendered = template.render(**context)
        except Exception as e:
            logger.error(
                "String template rendering failed",
                error=str(e)
            )
            raise PromptTemplateError(f"Failed to render template string: {e}") from e

        return rendered.strip()

    def get_template_names(self) -> list[str]:
        """
        Get list of available template names.

        Returns:
            List of template file names with the ``.j2`` extension.
        """
        return self.env.list_templates(extensions=['j2'])


# Global instance for convenience
_templates: PromptTemplates | None = None


def get_prompt_templates() -> PromptTemplates:
    """
    Get the global PromptTemplates instance.

    The instance is created lazily on first call using the default
    template directory.

    Returns:
        Singleton PromptTemplates instance.
    """
    global _templates
    if _templates is None:
        _templates = PromptTemplates()
    return _templates


def render_prompt(template_name: str, **context: Any) -> str:
    """
    Convenience function to render a prompt template.

    Args:
        template_name: Name of the template file.
        **context: Variables to pass to the template.

    Returns:
        Rendered template string.

    Raises:
        PromptTemplateError: If template not found or rendering fails.
    """
    return get_prompt_templates().render(template_name, **context)