- Config extensions: retry backoff, truncation threshold, session persistence
- LLM retry with exponential backoff + jitter on transient errors (5xx, connection)
- Conversation truncation: drops oldest messages, preserving the first user message + the most recent N
- Session persistence: auto-save/restore with atomic writes, cleanup of old files
- Graceful shutdown: SIGTERM handler, cancel() on AgentLoop, save-on-exit
- Partial message recovery on mid-stream interruption
- New slash commands: /save, /session
- 18 new tests (5 retry, 5 truncation, 4 session, 4 integration workflows)
- README.md and docs/tools.md documentation

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
178 lines
6.2 KiB
Python
178 lines
6.2 KiB
Python
"""Session state and conversation history manager."""
|
|
|
|
from datetime import UTC, datetime
|
|
|
|
from app.models.config import AppConfig
|
|
from app.models.message import Message
|
|
from app.utils.token_counter import TokenCounter
|
|
|
|
|
|
class SessionContext:
|
|
"""In-memory conversation state manager.
|
|
|
|
Tracks conversation history, token usage estimates, and session metadata.
|
|
"""
|
|
|
|
def __init__(self, config: AppConfig) -> None:
|
|
"""Initialize session context.
|
|
|
|
Args:
|
|
config: Application configuration.
|
|
"""
|
|
self._config = config
|
|
self._history: list[Message] = []
|
|
self._token_counter = TokenCounter(config.agent.max_conversation_tokens)
|
|
self._start_time = datetime.now(UTC)
|
|
self._message_count: int = 0
|
|
|
|
def add_message(self, role: str, content: str | None = None, **kwargs: object) -> Message:
|
|
"""Create and append a message to conversation history.
|
|
|
|
Args:
|
|
role: Message role (system, user, assistant, tool).
|
|
content: Text content of the message.
|
|
**kwargs: Additional Message fields (tool_calls, tool_call_id, name).
|
|
|
|
Returns:
|
|
The created Message instance.
|
|
"""
|
|
message = Message(role=role, content=content, **kwargs) # type: ignore[arg-type]
|
|
self._history.append(message)
|
|
self._message_count += 1
|
|
return message
|
|
|
|
def pop_last_message(self) -> Message | None:
|
|
"""Remove and return the last message, or None if history is empty."""
|
|
if self._history:
|
|
self._message_count -= 1
|
|
return self._history.pop()
|
|
return None
|
|
|
|
def get_history(self) -> list[Message]:
|
|
"""Return a shallow copy of the conversation history."""
|
|
return list(self._history)
|
|
|
|
def clear_history(self) -> None:
|
|
"""Clear conversation history and reset counters."""
|
|
self._history.clear()
|
|
self._message_count = 0
|
|
self._token_counter = TokenCounter(self._config.agent.max_conversation_tokens)
|
|
|
|
@property
|
|
def estimated_tokens(self) -> int:
|
|
"""Estimated token count for the current conversation history."""
|
|
return self._token_counter.estimate_messages_tokens(self._history)
|
|
|
|
@property
|
|
def token_counter(self) -> TokenCounter:
|
|
"""The token counter instance."""
|
|
return self._token_counter
|
|
|
|
@property
|
|
def message_count(self) -> int:
|
|
"""Number of messages added to this session."""
|
|
return self._message_count
|
|
|
|
@property
|
|
def start_time(self) -> datetime:
|
|
"""Session start timestamp (UTC)."""
|
|
return self._start_time
|
|
|
|
def truncate_history(self, system_token_estimate: int = 0) -> int:
|
|
"""Drop oldest messages to bring token usage under budget.
|
|
|
|
Preserves the first user message and the most recent N messages
|
|
(configured by ``truncation_keep_recent``). Cleans up orphaned tool
|
|
messages after truncation.
|
|
|
|
Args:
|
|
system_token_estimate: Estimated tokens used by the system prompt.
|
|
|
|
Returns:
|
|
Number of messages dropped.
|
|
"""
|
|
budget = self._token_counter.budget
|
|
threshold = self._config.agent.truncation_threshold
|
|
keep_recent = self._config.agent.truncation_keep_recent
|
|
|
|
estimated = self._token_counter.estimate_messages_tokens(self._history) + system_token_estimate
|
|
if estimated < threshold * budget:
|
|
return 0
|
|
|
|
target = int(budget * 0.75) # headroom
|
|
if len(self._history) <= keep_recent + 1:
|
|
return 0
|
|
|
|
# Split: first user message | droppable middle | recent tail
|
|
first_msg = self._history[0] if self._history and self._history[0].role == "user" else None
|
|
start_idx = 1 if first_msg else 0
|
|
tail_start = max(start_idx, len(self._history) - keep_recent)
|
|
|
|
dropped = 0
|
|
drop_indices: set[int] = set()
|
|
|
|
for i in range(start_idx, tail_start):
|
|
drop_indices.add(i)
|
|
dropped += 1
|
|
# Recalculate with remaining messages
|
|
remaining = [m for j, m in enumerate(self._history) if j not in drop_indices]
|
|
est = self._token_counter.estimate_messages_tokens(remaining) + system_token_estimate
|
|
if est < target:
|
|
break
|
|
|
|
if dropped == 0:
|
|
return 0
|
|
|
|
self._history = [m for j, m in enumerate(self._history) if j not in drop_indices]
|
|
|
|
# Clean up orphaned tool messages
|
|
self._cleanup_orphaned_tool_messages()
|
|
|
|
return dropped
|
|
|
|
def _cleanup_orphaned_tool_messages(self) -> None:
|
|
"""Remove tool messages whose tool_call_id doesn't match any assistant tool_call."""
|
|
# Collect all tool_call IDs from assistant messages
|
|
valid_tc_ids: set[str] = set()
|
|
for msg in self._history:
|
|
if msg.role == "assistant" and msg.tool_calls:
|
|
for tc in msg.tool_calls:
|
|
valid_tc_ids.add(tc.id)
|
|
|
|
# Remove tool messages referencing missing tool calls
|
|
self._history = [
|
|
msg for msg in self._history
|
|
if msg.role != "tool" or (msg.tool_call_id and msg.tool_call_id in valid_tc_ids)
|
|
]
|
|
|
|
def to_serializable(self) -> dict:
|
|
"""Export messages and token state for session persistence.
|
|
|
|
Returns:
|
|
Dict with messages and token usage data.
|
|
"""
|
|
return {
|
|
"messages": [m.model_dump(exclude_none=True) for m in self._history],
|
|
"token_usage": self._token_counter.cumulative_usage.model_dump(),
|
|
}
|
|
|
|
def restore_from(self, data: dict) -> None:
|
|
"""Clear and replay from serialized data.
|
|
|
|
Args:
|
|
data: Dict with messages and optional token_usage as produced by to_serializable().
|
|
"""
|
|
self._history.clear()
|
|
self._message_count = 0
|
|
|
|
for msg_data in data.get("messages", []):
|
|
msg = Message(**msg_data)
|
|
self._history.append(msg)
|
|
self._message_count += 1
|
|
|
|
token_data = data.get("token_usage")
|
|
if token_data:
|
|
from app.utils.token_counter import TokenUsage
|
|
usage = TokenUsage(**token_data)
|
|
self._token_counter.count_usage(usage)
|