"""
Narrative generator wrapper for AI content generation.

This module provides a high-level API for generating narrative content
using the appropriate AI models based on user tier and context.
"""
|
|
|
|
from dataclasses import dataclass
from typing import Any

import structlog

from app.ai.replicate_client import (
    ReplicateClient,
    ReplicateResponse,
    ReplicateClientError,
)
from app.ai.model_selector import (
    ModelSelector,
    ModelConfig,
    UserTier,
    ContextType,
)
from app.ai.prompt_templates import (
    PromptTemplates,
    PromptTemplateError,
    get_prompt_templates,
)

# Module-level structured logger bound to this module's dotted name.
# NOTE(review): ReplicateResponse is imported but not referenced in this
# module's visible code — presumably re-exported for callers; confirm
# before removing.
logger = structlog.get_logger(__name__)
|
|
|
|
|
|
@dataclass
class NarrativeResponse:
    """Response from narrative generation.

    Carries the generated text plus the generation metadata reported by
    the underlying AI client, so callers can track usage and latency.
    """

    # The generated narrative text.
    narrative: str
    # Token count reported by the client (presumably input + output — see
    # ReplicateResponse for the exact definition).
    tokens_used: int
    # Tokens in the prompt sent to the model.
    tokens_input: int
    # Tokens produced by the model.
    tokens_output: int
    # Identifier of the model that produced this response.
    model: str
    # String value of the ContextType this was generated for
    # (e.g. ContextType.STORY_PROGRESSION.value).
    context_type: str
    # Generation wall-clock time as reported by the client, in seconds.
    generation_time: float
|
|
|
|
|
|
class NarrativeGeneratorError(Exception):
    """Base exception for all narrative generator failures.

    Raised for prompt-template rendering errors, AI client failures,
    and invalid inputs (e.g. an empty quest list).
    """
|
|
|
|
|
|
class NarrativeGenerator:
    """
    High-level wrapper for AI narrative generation.

    This class coordinates between the model selector, prompt templates,
    and AI clients to generate narrative content for the game.

    It provides specialized methods for different narrative contexts:
    - Story progression responses
    - Combat narration
    - Quest selection
    - NPC dialogue
    """

    # System prompts for different contexts. Keyed by ContextType; each
    # generate_* method looks up its prompt here before calling the client.
    # These strings are sent verbatim to the model — edit with care.
    SYSTEM_PROMPTS = {
        ContextType.STORY_PROGRESSION: (
            "You are an expert Dungeon Master running a solo D&D-style adventure. "
            "Create immersive, engaging narratives that respond to player actions. "
            "Be descriptive but concise. Always end with a clear opportunity for the player to act. "
            "CRITICAL: NEVER give the player items, gold, equipment, or any rewards unless the action "
            "instructions explicitly state they should receive them. Only narrate what the template "
            "describes - do not improvise rewards or discoveries."
        ),
        ContextType.COMBAT_NARRATION: (
            "You are a combat narrator for a fantasy RPG. "
            "Describe actions with visceral, cinematic detail. "
            "Keep narration punchy and exciting. Never include game mechanics in prose."
        ),
        ContextType.QUEST_SELECTION: (
            "You are a quest selection system. "
            "Analyze the context and select the most narratively appropriate quest. "
            "Respond only with the quest_id - no explanation."
        ),
        ContextType.NPC_DIALOGUE: (
            "You are a skilled voice actor portraying NPCs in a fantasy world. "
            "Stay in character at all times. Give each NPC a distinct voice and personality. "
            "Provide useful information while maintaining immersion."
        ),
    }
|
|
|
|
def __init__(
|
|
self,
|
|
model_selector: ModelSelector | None = None,
|
|
replicate_client: ReplicateClient | None = None,
|
|
prompt_templates: PromptTemplates | None = None
|
|
):
|
|
"""
|
|
Initialize the narrative generator.
|
|
|
|
Args:
|
|
model_selector: Optional custom model selector.
|
|
replicate_client: Optional custom Replicate client.
|
|
prompt_templates: Optional custom prompt templates.
|
|
"""
|
|
self.model_selector = model_selector or ModelSelector()
|
|
self.replicate_client = replicate_client
|
|
self.prompt_templates = prompt_templates or get_prompt_templates()
|
|
|
|
logger.info("NarrativeGenerator initialized")
|
|
|
|
def _get_client(self, model_config: ModelConfig) -> ReplicateClient:
|
|
"""
|
|
Get or create a Replicate client for the given model configuration.
|
|
|
|
Args:
|
|
model_config: The model configuration to use.
|
|
|
|
Returns:
|
|
ReplicateClient configured for the specified model.
|
|
"""
|
|
# If a client was provided at init, use it
|
|
if self.replicate_client:
|
|
return self.replicate_client
|
|
|
|
# Otherwise create a new client with the specified model
|
|
return ReplicateClient(model=model_config.model_type)
|
|
|
|
def generate_story_response(
|
|
self,
|
|
character: dict[str, Any],
|
|
action: str,
|
|
game_state: dict[str, Any],
|
|
user_tier: UserTier,
|
|
conversation_history: list[dict[str, Any]] | None = None,
|
|
world_context: str | None = None,
|
|
action_instructions: str | None = None
|
|
) -> NarrativeResponse:
|
|
"""
|
|
Generate a DM response to a player's story action.
|
|
|
|
Args:
|
|
character: Character data dictionary with name, level, player_class, stats, etc.
|
|
action: The action the player wants to take.
|
|
game_state: Current game state with location, quests, etc.
|
|
user_tier: The user's subscription tier.
|
|
conversation_history: Optional list of recent conversation entries.
|
|
world_context: Optional additional world information.
|
|
action_instructions: Optional action-specific instructions for the AI from
|
|
the dm_prompt_template field in action_prompts.yaml.
|
|
|
|
Returns:
|
|
NarrativeResponse with the generated narrative and metadata.
|
|
|
|
Raises:
|
|
NarrativeGeneratorError: If generation fails.
|
|
|
|
Example:
|
|
>>> generator = NarrativeGenerator()
|
|
>>> response = generator.generate_story_response(
|
|
... character={"name": "Aldric", "level": 3, "player_class": "Fighter", ...},
|
|
... action="I search the room for hidden doors",
|
|
... game_state={"current_location": "Ancient Library", ...},
|
|
... user_tier=UserTier.PREMIUM
|
|
... )
|
|
>>> print(response.narrative)
|
|
"""
|
|
context_type = ContextType.STORY_PROGRESSION
|
|
|
|
logger.info(
|
|
"Generating story response",
|
|
character_name=character.get("name"),
|
|
action=action[:50],
|
|
user_tier=user_tier.value,
|
|
location=game_state.get("current_location")
|
|
)
|
|
|
|
# Get model configuration for this tier and context
|
|
model_config = self.model_selector.select_model(user_tier, context_type)
|
|
|
|
# Build the prompt from template
|
|
try:
|
|
prompt = self.prompt_templates.render(
|
|
"story_action.j2",
|
|
character=character,
|
|
action=action,
|
|
game_state=game_state,
|
|
conversation_history=conversation_history or [],
|
|
world_context=world_context,
|
|
max_tokens=model_config.max_tokens,
|
|
action_instructions=action_instructions
|
|
)
|
|
except PromptTemplateError as e:
|
|
logger.error("Failed to render story prompt", error=str(e))
|
|
raise NarrativeGeneratorError(f"Prompt template error: {e}")
|
|
|
|
# Debug: Log the full prompt being sent
|
|
logger.debug(
|
|
"Full prompt being sent to AI",
|
|
prompt_length=len(prompt),
|
|
conversation_history_count=len(conversation_history) if conversation_history else 0,
|
|
prompt_preview=prompt[:500] + "..." if len(prompt) > 500 else prompt
|
|
)
|
|
# For detailed debugging, uncomment the line below:
|
|
print(f"\n{'='*60}\nFULL PROMPT:\n{'='*60}\n{prompt}\n{'='*60}\n")
|
|
|
|
# Get system prompt
|
|
system_prompt = self.SYSTEM_PROMPTS[context_type]
|
|
|
|
# Generate response
|
|
try:
|
|
client = self._get_client(model_config)
|
|
response = client.generate(
|
|
prompt=prompt,
|
|
system_prompt=system_prompt,
|
|
max_tokens=model_config.max_tokens,
|
|
temperature=model_config.temperature,
|
|
model=model_config.model_type
|
|
)
|
|
except ReplicateClientError as e:
|
|
logger.error(
|
|
"AI generation failed",
|
|
error=str(e),
|
|
context_type=context_type.value
|
|
)
|
|
raise NarrativeGeneratorError(f"AI generation failed: {e}")
|
|
|
|
logger.info(
|
|
"Story response generated",
|
|
tokens_used=response.tokens_used,
|
|
model=response.model,
|
|
generation_time=f"{response.generation_time:.2f}s"
|
|
)
|
|
|
|
return NarrativeResponse(
|
|
narrative=response.text,
|
|
tokens_used=response.tokens_used,
|
|
tokens_input=response.tokens_input,
|
|
tokens_output=response.tokens_output,
|
|
model=response.model,
|
|
context_type=context_type.value,
|
|
generation_time=response.generation_time
|
|
)
|
|
|
|
def generate_combat_narration(
|
|
self,
|
|
character: dict[str, Any],
|
|
combat_state: dict[str, Any],
|
|
action: str,
|
|
action_result: dict[str, Any],
|
|
user_tier: UserTier,
|
|
is_critical: bool = False,
|
|
is_finishing_blow: bool = False
|
|
) -> NarrativeResponse:
|
|
"""
|
|
Generate narration for a combat action.
|
|
|
|
Args:
|
|
character: Character data dictionary.
|
|
combat_state: Current combat state with enemies, round number, etc.
|
|
action: Description of the combat action taken.
|
|
action_result: Result of the action (hit, damage, effects, etc.).
|
|
user_tier: The user's subscription tier.
|
|
is_critical: Whether this was a critical hit/miss.
|
|
is_finishing_blow: Whether this defeats the enemy.
|
|
|
|
Returns:
|
|
NarrativeResponse with combat narration.
|
|
|
|
Raises:
|
|
NarrativeGeneratorError: If generation fails.
|
|
|
|
Example:
|
|
>>> response = generator.generate_combat_narration(
|
|
... character={"name": "Aldric", ...},
|
|
... combat_state={"round_number": 3, "enemies": [...], ...},
|
|
... action="swings their sword at the goblin",
|
|
... action_result={"hit": True, "damage": 12, ...},
|
|
... user_tier=UserTier.BASIC
|
|
... )
|
|
"""
|
|
context_type = ContextType.COMBAT_NARRATION
|
|
|
|
logger.info(
|
|
"Generating combat narration",
|
|
character_name=character.get("name"),
|
|
action=action[:50],
|
|
is_critical=is_critical,
|
|
is_finishing_blow=is_finishing_blow
|
|
)
|
|
|
|
# Get model configuration
|
|
model_config = self.model_selector.select_model(user_tier, context_type)
|
|
|
|
# Build the prompt
|
|
try:
|
|
prompt = self.prompt_templates.render(
|
|
"combat_action.j2",
|
|
character=character,
|
|
combat_state=combat_state,
|
|
action=action,
|
|
action_result=action_result,
|
|
is_critical=is_critical,
|
|
is_finishing_blow=is_finishing_blow,
|
|
max_tokens=model_config.max_tokens
|
|
)
|
|
except PromptTemplateError as e:
|
|
logger.error("Failed to render combat prompt", error=str(e))
|
|
raise NarrativeGeneratorError(f"Prompt template error: {e}")
|
|
|
|
# Generate response
|
|
system_prompt = self.SYSTEM_PROMPTS[context_type]
|
|
|
|
try:
|
|
client = self._get_client(model_config)
|
|
response = client.generate(
|
|
prompt=prompt,
|
|
system_prompt=system_prompt,
|
|
max_tokens=model_config.max_tokens,
|
|
temperature=model_config.temperature,
|
|
model=model_config.model_type
|
|
)
|
|
except ReplicateClientError as e:
|
|
logger.error("Combat narration generation failed", error=str(e))
|
|
raise NarrativeGeneratorError(f"AI generation failed: {e}")
|
|
|
|
logger.info(
|
|
"Combat narration generated",
|
|
tokens_used=response.tokens_used,
|
|
generation_time=f"{response.generation_time:.2f}s"
|
|
)
|
|
|
|
return NarrativeResponse(
|
|
narrative=response.text,
|
|
tokens_used=response.tokens_used,
|
|
tokens_input=response.tokens_input,
|
|
tokens_output=response.tokens_output,
|
|
model=response.model,
|
|
context_type=context_type.value,
|
|
generation_time=response.generation_time
|
|
)
|
|
|
|
def generate_quest_selection(
|
|
self,
|
|
character: dict[str, Any],
|
|
eligible_quests: list[dict[str, Any]],
|
|
game_context: dict[str, Any],
|
|
user_tier: UserTier,
|
|
recent_actions: list[str] | None = None
|
|
) -> str:
|
|
"""
|
|
Use AI to select the most contextually appropriate quest.
|
|
|
|
Args:
|
|
character: Character data dictionary.
|
|
eligible_quests: List of quest data dictionaries that can be offered.
|
|
game_context: Current game context (location, events, etc.).
|
|
user_tier: The user's subscription tier.
|
|
recent_actions: Optional list of recent player actions.
|
|
|
|
Returns:
|
|
The quest_id of the selected quest.
|
|
|
|
Raises:
|
|
NarrativeGeneratorError: If generation fails or response is invalid.
|
|
|
|
Example:
|
|
>>> quest_id = generator.generate_quest_selection(
|
|
... character={"name": "Aldric", "level": 3, ...},
|
|
... eligible_quests=[{"quest_id": "goblin_cave", ...}, ...],
|
|
... game_context={"current_location": "Tavern", ...},
|
|
... user_tier=UserTier.FREE
|
|
... )
|
|
>>> print(quest_id) # "goblin_cave"
|
|
"""
|
|
context_type = ContextType.QUEST_SELECTION
|
|
|
|
logger.info(
|
|
"Generating quest selection",
|
|
character_name=character.get("name"),
|
|
num_eligible_quests=len(eligible_quests),
|
|
location=game_context.get("current_location")
|
|
)
|
|
|
|
if not eligible_quests:
|
|
raise NarrativeGeneratorError("No eligible quests provided")
|
|
|
|
# Get model configuration
|
|
model_config = self.model_selector.select_model(user_tier, context_type)
|
|
|
|
# Build the prompt
|
|
try:
|
|
prompt = self.prompt_templates.render(
|
|
"quest_offering.j2",
|
|
character=character,
|
|
eligible_quests=eligible_quests,
|
|
game_context=game_context,
|
|
recent_actions=recent_actions or []
|
|
)
|
|
except PromptTemplateError as e:
|
|
logger.error("Failed to render quest selection prompt", error=str(e))
|
|
raise NarrativeGeneratorError(f"Prompt template error: {e}")
|
|
|
|
# Generate response
|
|
system_prompt = self.SYSTEM_PROMPTS[context_type]
|
|
|
|
try:
|
|
client = self._get_client(model_config)
|
|
response = client.generate(
|
|
prompt=prompt,
|
|
system_prompt=system_prompt,
|
|
max_tokens=model_config.max_tokens,
|
|
temperature=model_config.temperature,
|
|
model=model_config.model_type
|
|
)
|
|
except ReplicateClientError as e:
|
|
logger.error("Quest selection generation failed", error=str(e))
|
|
raise NarrativeGeneratorError(f"AI generation failed: {e}")
|
|
|
|
# Parse the response to get quest_id
|
|
quest_id = response.text.strip().lower()
|
|
|
|
# Validate the response is a valid quest_id
|
|
valid_quest_ids = {q.get("quest_id", "").lower() for q in eligible_quests}
|
|
if quest_id not in valid_quest_ids:
|
|
logger.warning(
|
|
"AI returned invalid quest_id, using first eligible quest",
|
|
returned_id=quest_id,
|
|
valid_ids=list(valid_quest_ids)
|
|
)
|
|
quest_id = eligible_quests[0].get("quest_id", "")
|
|
|
|
logger.info(
|
|
"Quest selected",
|
|
quest_id=quest_id,
|
|
tokens_used=response.tokens_used,
|
|
generation_time=f"{response.generation_time:.2f}s"
|
|
)
|
|
|
|
return quest_id
|
|
|
|
def generate_npc_dialogue(
|
|
self,
|
|
character: dict[str, Any],
|
|
npc: dict[str, Any],
|
|
conversation_topic: str,
|
|
game_state: dict[str, Any],
|
|
user_tier: UserTier,
|
|
npc_relationship: str | None = None,
|
|
previous_dialogue: list[dict[str, Any]] | None = None,
|
|
npc_knowledge: list[str] | None = None
|
|
) -> NarrativeResponse:
|
|
"""
|
|
Generate NPC dialogue in response to player conversation.
|
|
|
|
Args:
|
|
character: Character data dictionary.
|
|
npc: NPC data with name, role, personality, etc.
|
|
conversation_topic: What the player said or wants to discuss.
|
|
game_state: Current game state.
|
|
user_tier: The user's subscription tier.
|
|
npc_relationship: Optional description of relationship with NPC.
|
|
previous_dialogue: Optional list of previous exchanges.
|
|
npc_knowledge: Optional list of things this NPC knows about.
|
|
|
|
Returns:
|
|
NarrativeResponse with NPC dialogue.
|
|
|
|
Raises:
|
|
NarrativeGeneratorError: If generation fails.
|
|
|
|
Example:
|
|
>>> response = generator.generate_npc_dialogue(
|
|
... character={"name": "Aldric", ...},
|
|
... npc={"name": "Old Barkeep", "role": "Tavern Owner", ...},
|
|
... conversation_topic="What rumors have you heard lately?",
|
|
... game_state={"current_location": "The Rusty Anchor", ...},
|
|
... user_tier=UserTier.PREMIUM
|
|
... )
|
|
"""
|
|
context_type = ContextType.NPC_DIALOGUE
|
|
|
|
logger.info(
|
|
"Generating NPC dialogue",
|
|
character_name=character.get("name"),
|
|
npc_name=npc.get("name"),
|
|
topic=conversation_topic[:50]
|
|
)
|
|
|
|
# Get model configuration
|
|
model_config = self.model_selector.select_model(user_tier, context_type)
|
|
|
|
# Build the prompt
|
|
try:
|
|
prompt = self.prompt_templates.render(
|
|
"npc_dialogue.j2",
|
|
character=character,
|
|
npc=npc,
|
|
conversation_topic=conversation_topic,
|
|
game_state=game_state,
|
|
npc_relationship=npc_relationship,
|
|
previous_dialogue=previous_dialogue or [],
|
|
npc_knowledge=npc_knowledge or [],
|
|
max_tokens=model_config.max_tokens
|
|
)
|
|
except PromptTemplateError as e:
|
|
logger.error("Failed to render NPC dialogue prompt", error=str(e))
|
|
raise NarrativeGeneratorError(f"Prompt template error: {e}")
|
|
|
|
# Generate response
|
|
system_prompt = self.SYSTEM_PROMPTS[context_type]
|
|
|
|
try:
|
|
client = self._get_client(model_config)
|
|
response = client.generate(
|
|
prompt=prompt,
|
|
system_prompt=system_prompt,
|
|
max_tokens=model_config.max_tokens,
|
|
temperature=model_config.temperature,
|
|
model=model_config.model_type
|
|
)
|
|
except ReplicateClientError as e:
|
|
logger.error("NPC dialogue generation failed", error=str(e))
|
|
raise NarrativeGeneratorError(f"AI generation failed: {e}")
|
|
|
|
logger.info(
|
|
"NPC dialogue generated",
|
|
npc_name=npc.get("name"),
|
|
tokens_used=response.tokens_used,
|
|
generation_time=f"{response.generation_time:.2f}s"
|
|
)
|
|
|
|
return NarrativeResponse(
|
|
narrative=response.text,
|
|
tokens_used=response.tokens_used,
|
|
tokens_input=response.tokens_input,
|
|
tokens_output=response.tokens_output,
|
|
model=response.model,
|
|
context_type=context_type.value,
|
|
generation_time=response.generation_time
|
|
)
|