Files
SneakyCode/app/utils/token_counter.py
2026-03-11 07:21:21 -05:00

94 lines
3.0 KiB
Python

"""Approximate token counting for conversation budget management."""
from pydantic import BaseModel, Field
from app.models.message import Message
class TokenUsage(BaseModel):
    """Snapshot of token usage for a single LLM call."""

    # Every counter defaults to 0, so ``TokenUsage()`` is an all-zero record.
    prompt_tokens: int = Field(
        default=0,
        description="Tokens in the prompt",
    )
    completion_tokens: int = Field(
        default=0,
        description="Tokens in the completion",
    )
    # Stored as its own field; nothing in this model derives it from the
    # prompt/completion counts above.
    total_tokens: int = Field(
        default=0,
        description="Total tokens used",
    )
class TokenCounter:
    """Tracks cumulative token usage with character-based estimation.

    Text is sized with a simple heuristic of ~``CHARS_PER_TOKEN`` characters
    per token. This is intentionally approximate — accurate enough for
    budget tracking, not for billing.
    """

    # Heuristic divisor: number of characters assumed to make up one token.
    CHARS_PER_TOKEN: int = 4

    def __init__(self, budget: int = 32_000) -> None:
        """Initialize the token counter.

        Args:
            budget: Maximum token budget for the conversation.
        """
        self._budget = budget
        # Running totals start at zero and are updated in place by
        # ``count_usage``.
        self._cumulative = TokenUsage()

    @property
    def budget(self) -> int:
        """The configured token budget."""
        return self._budget

    @property
    def cumulative_usage(self) -> TokenUsage:
        """Cumulative token usage across all tracked calls.

        This is the counter's own running-total object, not a copy.
        """
        return self._cumulative

    @property
    def remaining_budget(self) -> int:
        """Estimated tokens remaining before hitting the budget.

        Clamped at 0, so the value never goes negative once the budget has
        been overshot.
        """
        return max(0, self._budget - self._cumulative.total_tokens)

    def estimate_tokens(self, text: str) -> int:
        """Estimate token count for a string using the character heuristic.

        Args:
            text: The text to estimate tokens for.

        Returns:
            0 for empty text; otherwise ``len(text) // CHARS_PER_TOKEN``
            with a floor of 1, so short non-empty text still costs a token.
        """
        if not text:
            # Nothing to send means nothing to count; without this guard the
            # 1-token floor below would charge for empty strings.
            return 0
        return max(1, len(text) // self.CHARS_PER_TOKEN)

    def estimate_messages_tokens(self, messages: list[Message]) -> int:
        """Estimate total tokens for a list of messages.

        Each message contributes its content, the name and arguments of
        every tool call it carries, and a flat per-message overhead.

        Args:
            messages: List of conversation messages.

        Returns:
            Estimated total token count.
        """
        total = 0
        for msg in messages:
            if msg.content:
                total += self.estimate_tokens(msg.content)
            if msg.tool_calls:
                for tc in msg.tool_calls:
                    total += self.estimate_tokens(tc.function.name)
                    total += self.estimate_tokens(tc.function.arguments)
            # Per-message overhead (role, formatting)
            total += 4
        return total

    def count_usage(self, usage: TokenUsage) -> None:
        """Record token usage from an LLM call.

        Args:
            usage: Token usage from a single call. When ``total_tokens`` is
                0 (the field's default), the total is taken to be
                ``prompt_tokens + completion_tokens`` so that a record
                carrying only the two parts still counts against the
                budget.
        """
        call_total = usage.total_tokens or (
            usage.prompt_tokens + usage.completion_tokens
        )
        self._cumulative.prompt_tokens += usage.prompt_tokens
        self._cumulative.completion_tokens += usage.completion_tokens
        self._cumulative.total_tokens += call_total

    def is_over_budget(self) -> bool:
        """Check if cumulative usage has reached or exceeded the budget.

        Returns True exactly when ``remaining_budget`` is 0.
        """
        return self._cumulative.total_tokens >= self._budget