first commit
This commit is contained in:
540
api/app/ai/narrative_generator.py
Normal file
540
api/app/ai/narrative_generator.py
Normal file
@@ -0,0 +1,540 @@
|
||||
"""
|
||||
Narrative generator wrapper for AI content generation.
|
||||
|
||||
This module provides a high-level API for generating narrative content
|
||||
using the appropriate AI models based on user tier and context.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import structlog
|
||||
|
||||
from app.ai.replicate_client import (
|
||||
ReplicateClient,
|
||||
ReplicateResponse,
|
||||
ReplicateClientError,
|
||||
)
|
||||
from app.ai.model_selector import (
|
||||
ModelSelector,
|
||||
ModelConfig,
|
||||
UserTier,
|
||||
ContextType,
|
||||
)
|
||||
from app.ai.prompt_templates import (
|
||||
PromptTemplates,
|
||||
PromptTemplateError,
|
||||
get_prompt_templates,
|
||||
)
|
||||
|
||||
logger = structlog.get_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class NarrativeResponse:
|
||||
"""Response from narrative generation."""
|
||||
narrative: str
|
||||
tokens_used: int
|
||||
tokens_input: int
|
||||
tokens_output: int
|
||||
model: str
|
||||
context_type: str
|
||||
generation_time: float
|
||||
|
||||
|
||||
class NarrativeGeneratorError(Exception):
|
||||
"""Base exception for narrative generator errors."""
|
||||
pass
|
||||
|
||||
|
||||
class NarrativeGenerator:
|
||||
"""
|
||||
High-level wrapper for AI narrative generation.
|
||||
|
||||
This class coordinates between the model selector, prompt templates,
|
||||
and AI clients to generate narrative content for the game.
|
||||
|
||||
It provides specialized methods for different narrative contexts:
|
||||
- Story progression responses
|
||||
- Combat narration
|
||||
- Quest selection
|
||||
- NPC dialogue
|
||||
"""
|
||||
|
||||
# System prompts for different contexts
|
||||
SYSTEM_PROMPTS = {
|
||||
ContextType.STORY_PROGRESSION: (
|
||||
"You are an expert Dungeon Master running a solo D&D-style adventure. "
|
||||
"Create immersive, engaging narratives that respond to player actions. "
|
||||
"Be descriptive but concise. Always end with a clear opportunity for the player to act. "
|
||||
"CRITICAL: NEVER give the player items, gold, equipment, or any rewards unless the action "
|
||||
"instructions explicitly state they should receive them. Only narrate what the template "
|
||||
"describes - do not improvise rewards or discoveries."
|
||||
),
|
||||
ContextType.COMBAT_NARRATION: (
|
||||
"You are a combat narrator for a fantasy RPG. "
|
||||
"Describe actions with visceral, cinematic detail. "
|
||||
"Keep narration punchy and exciting. Never include game mechanics in prose."
|
||||
),
|
||||
ContextType.QUEST_SELECTION: (
|
||||
"You are a quest selection system. "
|
||||
"Analyze the context and select the most narratively appropriate quest. "
|
||||
"Respond only with the quest_id - no explanation."
|
||||
),
|
||||
ContextType.NPC_DIALOGUE: (
|
||||
"You are a skilled voice actor portraying NPCs in a fantasy world. "
|
||||
"Stay in character at all times. Give each NPC a distinct voice and personality. "
|
||||
"Provide useful information while maintaining immersion."
|
||||
),
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_selector: ModelSelector | None = None,
|
||||
replicate_client: ReplicateClient | None = None,
|
||||
prompt_templates: PromptTemplates | None = None
|
||||
):
|
||||
"""
|
||||
Initialize the narrative generator.
|
||||
|
||||
Args:
|
||||
model_selector: Optional custom model selector.
|
||||
replicate_client: Optional custom Replicate client.
|
||||
prompt_templates: Optional custom prompt templates.
|
||||
"""
|
||||
self.model_selector = model_selector or ModelSelector()
|
||||
self.replicate_client = replicate_client
|
||||
self.prompt_templates = prompt_templates or get_prompt_templates()
|
||||
|
||||
logger.info("NarrativeGenerator initialized")
|
||||
|
||||
def _get_client(self, model_config: ModelConfig) -> ReplicateClient:
|
||||
"""
|
||||
Get or create a Replicate client for the given model configuration.
|
||||
|
||||
Args:
|
||||
model_config: The model configuration to use.
|
||||
|
||||
Returns:
|
||||
ReplicateClient configured for the specified model.
|
||||
"""
|
||||
# If a client was provided at init, use it
|
||||
if self.replicate_client:
|
||||
return self.replicate_client
|
||||
|
||||
# Otherwise create a new client with the specified model
|
||||
return ReplicateClient(model=model_config.model_type)
|
||||
|
||||
def generate_story_response(
|
||||
self,
|
||||
character: dict[str, Any],
|
||||
action: str,
|
||||
game_state: dict[str, Any],
|
||||
user_tier: UserTier,
|
||||
conversation_history: list[dict[str, Any]] | None = None,
|
||||
world_context: str | None = None,
|
||||
action_instructions: str | None = None
|
||||
) -> NarrativeResponse:
|
||||
"""
|
||||
Generate a DM response to a player's story action.
|
||||
|
||||
Args:
|
||||
character: Character data dictionary with name, level, player_class, stats, etc.
|
||||
action: The action the player wants to take.
|
||||
game_state: Current game state with location, quests, etc.
|
||||
user_tier: The user's subscription tier.
|
||||
conversation_history: Optional list of recent conversation entries.
|
||||
world_context: Optional additional world information.
|
||||
action_instructions: Optional action-specific instructions for the AI from
|
||||
the dm_prompt_template field in action_prompts.yaml.
|
||||
|
||||
Returns:
|
||||
NarrativeResponse with the generated narrative and metadata.
|
||||
|
||||
Raises:
|
||||
NarrativeGeneratorError: If generation fails.
|
||||
|
||||
Example:
|
||||
>>> generator = NarrativeGenerator()
|
||||
>>> response = generator.generate_story_response(
|
||||
... character={"name": "Aldric", "level": 3, "player_class": "Fighter", ...},
|
||||
... action="I search the room for hidden doors",
|
||||
... game_state={"current_location": "Ancient Library", ...},
|
||||
... user_tier=UserTier.PREMIUM
|
||||
... )
|
||||
>>> print(response.narrative)
|
||||
"""
|
||||
context_type = ContextType.STORY_PROGRESSION
|
||||
|
||||
logger.info(
|
||||
"Generating story response",
|
||||
character_name=character.get("name"),
|
||||
action=action[:50],
|
||||
user_tier=user_tier.value,
|
||||
location=game_state.get("current_location")
|
||||
)
|
||||
|
||||
# Get model configuration for this tier and context
|
||||
model_config = self.model_selector.select_model(user_tier, context_type)
|
||||
|
||||
# Build the prompt from template
|
||||
try:
|
||||
prompt = self.prompt_templates.render(
|
||||
"story_action.j2",
|
||||
character=character,
|
||||
action=action,
|
||||
game_state=game_state,
|
||||
conversation_history=conversation_history or [],
|
||||
world_context=world_context,
|
||||
max_tokens=model_config.max_tokens,
|
||||
action_instructions=action_instructions
|
||||
)
|
||||
except PromptTemplateError as e:
|
||||
logger.error("Failed to render story prompt", error=str(e))
|
||||
raise NarrativeGeneratorError(f"Prompt template error: {e}")
|
||||
|
||||
# Debug: Log the full prompt being sent
|
||||
logger.debug(
|
||||
"Full prompt being sent to AI",
|
||||
prompt_length=len(prompt),
|
||||
conversation_history_count=len(conversation_history) if conversation_history else 0,
|
||||
prompt_preview=prompt[:500] + "..." if len(prompt) > 500 else prompt
|
||||
)
|
||||
# For detailed debugging, uncomment the line below:
|
||||
print(f"\n{'='*60}\nFULL PROMPT:\n{'='*60}\n{prompt}\n{'='*60}\n")
|
||||
|
||||
# Get system prompt
|
||||
system_prompt = self.SYSTEM_PROMPTS[context_type]
|
||||
|
||||
# Generate response
|
||||
try:
|
||||
client = self._get_client(model_config)
|
||||
response = client.generate(
|
||||
prompt=prompt,
|
||||
system_prompt=system_prompt,
|
||||
max_tokens=model_config.max_tokens,
|
||||
temperature=model_config.temperature,
|
||||
model=model_config.model_type
|
||||
)
|
||||
except ReplicateClientError as e:
|
||||
logger.error(
|
||||
"AI generation failed",
|
||||
error=str(e),
|
||||
context_type=context_type.value
|
||||
)
|
||||
raise NarrativeGeneratorError(f"AI generation failed: {e}")
|
||||
|
||||
logger.info(
|
||||
"Story response generated",
|
||||
tokens_used=response.tokens_used,
|
||||
model=response.model,
|
||||
generation_time=f"{response.generation_time:.2f}s"
|
||||
)
|
||||
|
||||
return NarrativeResponse(
|
||||
narrative=response.text,
|
||||
tokens_used=response.tokens_used,
|
||||
tokens_input=response.tokens_input,
|
||||
tokens_output=response.tokens_output,
|
||||
model=response.model,
|
||||
context_type=context_type.value,
|
||||
generation_time=response.generation_time
|
||||
)
|
||||
|
||||
def generate_combat_narration(
|
||||
self,
|
||||
character: dict[str, Any],
|
||||
combat_state: dict[str, Any],
|
||||
action: str,
|
||||
action_result: dict[str, Any],
|
||||
user_tier: UserTier,
|
||||
is_critical: bool = False,
|
||||
is_finishing_blow: bool = False
|
||||
) -> NarrativeResponse:
|
||||
"""
|
||||
Generate narration for a combat action.
|
||||
|
||||
Args:
|
||||
character: Character data dictionary.
|
||||
combat_state: Current combat state with enemies, round number, etc.
|
||||
action: Description of the combat action taken.
|
||||
action_result: Result of the action (hit, damage, effects, etc.).
|
||||
user_tier: The user's subscription tier.
|
||||
is_critical: Whether this was a critical hit/miss.
|
||||
is_finishing_blow: Whether this defeats the enemy.
|
||||
|
||||
Returns:
|
||||
NarrativeResponse with combat narration.
|
||||
|
||||
Raises:
|
||||
NarrativeGeneratorError: If generation fails.
|
||||
|
||||
Example:
|
||||
>>> response = generator.generate_combat_narration(
|
||||
... character={"name": "Aldric", ...},
|
||||
... combat_state={"round_number": 3, "enemies": [...], ...},
|
||||
... action="swings their sword at the goblin",
|
||||
... action_result={"hit": True, "damage": 12, ...},
|
||||
... user_tier=UserTier.BASIC
|
||||
... )
|
||||
"""
|
||||
context_type = ContextType.COMBAT_NARRATION
|
||||
|
||||
logger.info(
|
||||
"Generating combat narration",
|
||||
character_name=character.get("name"),
|
||||
action=action[:50],
|
||||
is_critical=is_critical,
|
||||
is_finishing_blow=is_finishing_blow
|
||||
)
|
||||
|
||||
# Get model configuration
|
||||
model_config = self.model_selector.select_model(user_tier, context_type)
|
||||
|
||||
# Build the prompt
|
||||
try:
|
||||
prompt = self.prompt_templates.render(
|
||||
"combat_action.j2",
|
||||
character=character,
|
||||
combat_state=combat_state,
|
||||
action=action,
|
||||
action_result=action_result,
|
||||
is_critical=is_critical,
|
||||
is_finishing_blow=is_finishing_blow,
|
||||
max_tokens=model_config.max_tokens
|
||||
)
|
||||
except PromptTemplateError as e:
|
||||
logger.error("Failed to render combat prompt", error=str(e))
|
||||
raise NarrativeGeneratorError(f"Prompt template error: {e}")
|
||||
|
||||
# Generate response
|
||||
system_prompt = self.SYSTEM_PROMPTS[context_type]
|
||||
|
||||
try:
|
||||
client = self._get_client(model_config)
|
||||
response = client.generate(
|
||||
prompt=prompt,
|
||||
system_prompt=system_prompt,
|
||||
max_tokens=model_config.max_tokens,
|
||||
temperature=model_config.temperature,
|
||||
model=model_config.model_type
|
||||
)
|
||||
except ReplicateClientError as e:
|
||||
logger.error("Combat narration generation failed", error=str(e))
|
||||
raise NarrativeGeneratorError(f"AI generation failed: {e}")
|
||||
|
||||
logger.info(
|
||||
"Combat narration generated",
|
||||
tokens_used=response.tokens_used,
|
||||
generation_time=f"{response.generation_time:.2f}s"
|
||||
)
|
||||
|
||||
return NarrativeResponse(
|
||||
narrative=response.text,
|
||||
tokens_used=response.tokens_used,
|
||||
tokens_input=response.tokens_input,
|
||||
tokens_output=response.tokens_output,
|
||||
model=response.model,
|
||||
context_type=context_type.value,
|
||||
generation_time=response.generation_time
|
||||
)
|
||||
|
||||
def generate_quest_selection(
|
||||
self,
|
||||
character: dict[str, Any],
|
||||
eligible_quests: list[dict[str, Any]],
|
||||
game_context: dict[str, Any],
|
||||
user_tier: UserTier,
|
||||
recent_actions: list[str] | None = None
|
||||
) -> str:
|
||||
"""
|
||||
Use AI to select the most contextually appropriate quest.
|
||||
|
||||
Args:
|
||||
character: Character data dictionary.
|
||||
eligible_quests: List of quest data dictionaries that can be offered.
|
||||
game_context: Current game context (location, events, etc.).
|
||||
user_tier: The user's subscription tier.
|
||||
recent_actions: Optional list of recent player actions.
|
||||
|
||||
Returns:
|
||||
The quest_id of the selected quest.
|
||||
|
||||
Raises:
|
||||
NarrativeGeneratorError: If generation fails or response is invalid.
|
||||
|
||||
Example:
|
||||
>>> quest_id = generator.generate_quest_selection(
|
||||
... character={"name": "Aldric", "level": 3, ...},
|
||||
... eligible_quests=[{"quest_id": "goblin_cave", ...}, ...],
|
||||
... game_context={"current_location": "Tavern", ...},
|
||||
... user_tier=UserTier.FREE
|
||||
... )
|
||||
>>> print(quest_id) # "goblin_cave"
|
||||
"""
|
||||
context_type = ContextType.QUEST_SELECTION
|
||||
|
||||
logger.info(
|
||||
"Generating quest selection",
|
||||
character_name=character.get("name"),
|
||||
num_eligible_quests=len(eligible_quests),
|
||||
location=game_context.get("current_location")
|
||||
)
|
||||
|
||||
if not eligible_quests:
|
||||
raise NarrativeGeneratorError("No eligible quests provided")
|
||||
|
||||
# Get model configuration
|
||||
model_config = self.model_selector.select_model(user_tier, context_type)
|
||||
|
||||
# Build the prompt
|
||||
try:
|
||||
prompt = self.prompt_templates.render(
|
||||
"quest_offering.j2",
|
||||
character=character,
|
||||
eligible_quests=eligible_quests,
|
||||
game_context=game_context,
|
||||
recent_actions=recent_actions or []
|
||||
)
|
||||
except PromptTemplateError as e:
|
||||
logger.error("Failed to render quest selection prompt", error=str(e))
|
||||
raise NarrativeGeneratorError(f"Prompt template error: {e}")
|
||||
|
||||
# Generate response
|
||||
system_prompt = self.SYSTEM_PROMPTS[context_type]
|
||||
|
||||
try:
|
||||
client = self._get_client(model_config)
|
||||
response = client.generate(
|
||||
prompt=prompt,
|
||||
system_prompt=system_prompt,
|
||||
max_tokens=model_config.max_tokens,
|
||||
temperature=model_config.temperature,
|
||||
model=model_config.model_type
|
||||
)
|
||||
except ReplicateClientError as e:
|
||||
logger.error("Quest selection generation failed", error=str(e))
|
||||
raise NarrativeGeneratorError(f"AI generation failed: {e}")
|
||||
|
||||
# Parse the response to get quest_id
|
||||
quest_id = response.text.strip().lower()
|
||||
|
||||
# Validate the response is a valid quest_id
|
||||
valid_quest_ids = {q.get("quest_id", "").lower() for q in eligible_quests}
|
||||
if quest_id not in valid_quest_ids:
|
||||
logger.warning(
|
||||
"AI returned invalid quest_id, using first eligible quest",
|
||||
returned_id=quest_id,
|
||||
valid_ids=list(valid_quest_ids)
|
||||
)
|
||||
quest_id = eligible_quests[0].get("quest_id", "")
|
||||
|
||||
logger.info(
|
||||
"Quest selected",
|
||||
quest_id=quest_id,
|
||||
tokens_used=response.tokens_used,
|
||||
generation_time=f"{response.generation_time:.2f}s"
|
||||
)
|
||||
|
||||
return quest_id
|
||||
|
||||
def generate_npc_dialogue(
|
||||
self,
|
||||
character: dict[str, Any],
|
||||
npc: dict[str, Any],
|
||||
conversation_topic: str,
|
||||
game_state: dict[str, Any],
|
||||
user_tier: UserTier,
|
||||
npc_relationship: str | None = None,
|
||||
previous_dialogue: list[dict[str, Any]] | None = None,
|
||||
npc_knowledge: list[str] | None = None
|
||||
) -> NarrativeResponse:
|
||||
"""
|
||||
Generate NPC dialogue in response to player conversation.
|
||||
|
||||
Args:
|
||||
character: Character data dictionary.
|
||||
npc: NPC data with name, role, personality, etc.
|
||||
conversation_topic: What the player said or wants to discuss.
|
||||
game_state: Current game state.
|
||||
user_tier: The user's subscription tier.
|
||||
npc_relationship: Optional description of relationship with NPC.
|
||||
previous_dialogue: Optional list of previous exchanges.
|
||||
npc_knowledge: Optional list of things this NPC knows about.
|
||||
|
||||
Returns:
|
||||
NarrativeResponse with NPC dialogue.
|
||||
|
||||
Raises:
|
||||
NarrativeGeneratorError: If generation fails.
|
||||
|
||||
Example:
|
||||
>>> response = generator.generate_npc_dialogue(
|
||||
... character={"name": "Aldric", ...},
|
||||
... npc={"name": "Old Barkeep", "role": "Tavern Owner", ...},
|
||||
... conversation_topic="What rumors have you heard lately?",
|
||||
... game_state={"current_location": "The Rusty Anchor", ...},
|
||||
... user_tier=UserTier.PREMIUM
|
||||
... )
|
||||
"""
|
||||
context_type = ContextType.NPC_DIALOGUE
|
||||
|
||||
logger.info(
|
||||
"Generating NPC dialogue",
|
||||
character_name=character.get("name"),
|
||||
npc_name=npc.get("name"),
|
||||
topic=conversation_topic[:50]
|
||||
)
|
||||
|
||||
# Get model configuration
|
||||
model_config = self.model_selector.select_model(user_tier, context_type)
|
||||
|
||||
# Build the prompt
|
||||
try:
|
||||
prompt = self.prompt_templates.render(
|
||||
"npc_dialogue.j2",
|
||||
character=character,
|
||||
npc=npc,
|
||||
conversation_topic=conversation_topic,
|
||||
game_state=game_state,
|
||||
npc_relationship=npc_relationship,
|
||||
previous_dialogue=previous_dialogue or [],
|
||||
npc_knowledge=npc_knowledge or [],
|
||||
max_tokens=model_config.max_tokens
|
||||
)
|
||||
except PromptTemplateError as e:
|
||||
logger.error("Failed to render NPC dialogue prompt", error=str(e))
|
||||
raise NarrativeGeneratorError(f"Prompt template error: {e}")
|
||||
|
||||
# Generate response
|
||||
system_prompt = self.SYSTEM_PROMPTS[context_type]
|
||||
|
||||
try:
|
||||
client = self._get_client(model_config)
|
||||
response = client.generate(
|
||||
prompt=prompt,
|
||||
system_prompt=system_prompt,
|
||||
max_tokens=model_config.max_tokens,
|
||||
temperature=model_config.temperature,
|
||||
model=model_config.model_type
|
||||
)
|
||||
except ReplicateClientError as e:
|
||||
logger.error("NPC dialogue generation failed", error=str(e))
|
||||
raise NarrativeGeneratorError(f"AI generation failed: {e}")
|
||||
|
||||
logger.info(
|
||||
"NPC dialogue generated",
|
||||
npc_name=npc.get("name"),
|
||||
tokens_used=response.tokens_used,
|
||||
generation_time=f"{response.generation_time:.2f}s"
|
||||
)
|
||||
|
||||
return NarrativeResponse(
|
||||
narrative=response.text,
|
||||
tokens_used=response.tokens_used,
|
||||
tokens_input=response.tokens_input,
|
||||
tokens_output=response.tokens_output,
|
||||
model=response.model,
|
||||
context_type=context_type.value,
|
||||
generation_time=response.generation_time
|
||||
)
|
||||
Reference in New Issue
Block a user