"""Session state and conversation history manager.""" from datetime import UTC, datetime from app.models.config import AppConfig from app.models.message import Message from app.utils.token_counter import TokenCounter class SessionContext: """In-memory conversation state manager. Tracks conversation history, token usage estimates, and session metadata. """ def __init__(self, config: AppConfig) -> None: """Initialize session context. Args: config: Application configuration. """ self._config = config self._history: list[Message] = [] self._token_counter = TokenCounter(config.agent.max_conversation_tokens) self._start_time = datetime.now(UTC) self._message_count: int = 0 def add_message(self, role: str, content: str | None = None, **kwargs: object) -> Message: """Create and append a message to conversation history. Args: role: Message role (system, user, assistant, tool). content: Text content of the message. **kwargs: Additional Message fields (tool_calls, tool_call_id, name). Returns: The created Message instance. """ message = Message(role=role, content=content, **kwargs) # type: ignore[arg-type] self._history.append(message) self._message_count += 1 return message def pop_last_message(self) -> Message | None: """Remove and return the last message, or None if history is empty.""" if self._history: self._message_count -= 1 return self._history.pop() return None def get_history(self) -> list[Message]: """Return a shallow copy of the conversation history.""" return list(self._history) def clear_history(self) -> None: """Clear conversation history and reset counters.""" self._history.clear() self._message_count = 0 self._token_counter = TokenCounter(self._config.agent.max_conversation_tokens) @property def estimated_tokens(self) -> int: """Estimated token count for the current conversation history.""" return self._token_counter.estimate_messages_tokens(self._history) @property def token_counter(self) -> TokenCounter: """The token counter instance.""" return self._token_counter @property def message_count(self) -> int: """Number of messages added to this session.""" return self._message_count @property def start_time(self) -> datetime: """Session start timestamp (UTC).""" return self._start_time def truncate_history(self, system_token_estimate: int = 0) -> int: """Drop oldest messages to bring token usage under budget. Preserves the first user message and the most recent N messages (configured by ``truncation_keep_recent``). Cleans up orphaned tool messages after truncation. Args: system_token_estimate: Estimated tokens used by the system prompt. Returns: Number of messages dropped. """ budget = self._token_counter.budget threshold = self._config.agent.truncation_threshold keep_recent = self._config.agent.truncation_keep_recent estimated = self._token_counter.estimate_messages_tokens(self._history) + system_token_estimate if estimated < threshold * budget: return 0 target = int(budget * 0.75) # headroom if len(self._history) <= keep_recent + 1: return 0 # Split: first user message | droppable middle | recent tail first_msg = self._history[0] if self._history and self._history[0].role == "user" else None start_idx = 1 if first_msg else 0 tail_start = max(start_idx, len(self._history) - keep_recent) dropped = 0 drop_indices: set[int] = set() for i in range(start_idx, tail_start): drop_indices.add(i) dropped += 1 # Recalculate with remaining messages remaining = [m for j, m in enumerate(self._history) if j not in drop_indices] est = self._token_counter.estimate_messages_tokens(remaining) + system_token_estimate if est < target: break if dropped == 0: return 0 self._history = [m for j, m in enumerate(self._history) if j not in drop_indices] # Clean up orphaned tool messages self._cleanup_orphaned_tool_messages() return dropped def _cleanup_orphaned_tool_messages(self) -> None: """Remove tool messages whose tool_call_id doesn't match any assistant tool_call.""" # Collect all tool_call IDs from assistant messages valid_tc_ids: set[str] = set() for msg in self._history: if msg.role == "assistant" and msg.tool_calls: for tc in msg.tool_calls: valid_tc_ids.add(tc.id) # Remove tool messages referencing missing tool calls self._history = [ msg for msg in self._history if msg.role != "tool" or (msg.tool_call_id and msg.tool_call_id in valid_tc_ids) ] def to_serializable(self) -> dict: """Export messages and token state for session persistence. Returns: Dict with messages and token usage data. """ return { "messages": [m.model_dump(exclude_none=True) for m in self._history], "token_usage": self._token_counter.cumulative_usage.model_dump(), } def restore_from(self, data: dict) -> None: """Clear and replay from serialized data. Args: data: Dict with messages and optional token_usage as produced by to_serializable(). """ self._history.clear() self._message_count = 0 for msg_data in data.get("messages", []): msg = Message(**msg_data) self._history.append(msg) self._message_count += 1 token_data = data.get("token_usage") if token_data: from app.utils.token_counter import TokenUsage usage = TokenUsage(**token_data) self._token_counter.count_usage(usage)