"""Approximate token counting for conversation budget management.""" from pydantic import BaseModel, Field from app.models.message import Message class TokenUsage(BaseModel): """Snapshot of token usage for a single LLM call.""" prompt_tokens: int = Field(default=0, description="Tokens in the prompt") completion_tokens: int = Field(default=0, description="Tokens in the completion") total_tokens: int = Field(default=0, description="Total tokens used") class TokenCounter: """Tracks cumulative token usage with character-based estimation. Uses a simple heuristic of ~4 characters per token for estimation. This is intentionally approximate — accurate enough for budget tracking. """ CHARS_PER_TOKEN: int = 4 def __init__(self, budget: int = 32_000) -> None: """Initialize the token counter. Args: budget: Maximum token budget for the conversation. """ self._budget = budget self._cumulative = TokenUsage() @property def budget(self) -> int: """The configured token budget.""" return self._budget @property def cumulative_usage(self) -> TokenUsage: """Cumulative token usage across all tracked calls.""" return self._cumulative @property def remaining_budget(self) -> int: """Estimated tokens remaining before hitting the budget.""" return max(0, self._budget - self._cumulative.total_tokens) def estimate_tokens(self, text: str) -> int: """Estimate token count for a string using character heuristic. Args: text: The text to estimate tokens for. Returns: Estimated token count. """ return max(1, len(text) // self.CHARS_PER_TOKEN) def estimate_messages_tokens(self, messages: list[Message]) -> int: """Estimate total tokens for a list of messages. Args: messages: List of conversation messages. Returns: Estimated total token count. """ total = 0 for msg in messages: if msg.content: total += self.estimate_tokens(msg.content) if msg.tool_calls: for tc in msg.tool_calls: total += self.estimate_tokens(tc.function.name) total += self.estimate_tokens(tc.function.arguments) # Per-message overhead (role, formatting) total += 4 return total def count_usage(self, usage: TokenUsage) -> None: """Record token usage from an LLM call. Args: usage: Token usage from a single call. """ self._cumulative.prompt_tokens += usage.prompt_tokens self._cumulative.completion_tokens += usage.completion_tokens self._cumulative.total_tokens += usage.total_tokens def is_over_budget(self) -> bool: """Check if cumulative usage has exceeded the token budget.""" return self._cumulative.total_tokens >= self._budget