94 lines
3.0 KiB
Python
94 lines
3.0 KiB
Python
"""Approximate token counting for conversation budget management."""
|
|
|
|
from pydantic import BaseModel, Field
|
|
|
|
from app.models.message import Message
|
|
|
|
|
|
class TokenUsage(BaseModel):
    """Token counts reported for one LLM call.

    Every count defaults to zero. ``total_tokens`` is stored exactly as
    supplied; it is not computed from the other two fields.
    """

    # Tokens attributed to the prompt side of the call.
    prompt_tokens: int = Field(0, description="Tokens in the prompt")
    # Tokens attributed to the generated completion.
    completion_tokens: int = Field(0, description="Tokens in the completion")
    # Overall count for the call, as given by whoever built this object.
    total_tokens: int = Field(0, description="Total tokens used")
|
|
|
|
|
|
class TokenCounter:
    """Tracks cumulative token usage with character-based estimation.

    Uses a simple heuristic of roughly ``CHARS_PER_TOKEN`` characters per
    token for estimation. This is intentionally approximate — accurate
    enough for budget tracking.
    """

    CHARS_PER_TOKEN: int = 4
    # Flat per-message allowance for role markers and formatting.
    MESSAGE_OVERHEAD_TOKENS: int = 4

    def __init__(self, budget: int = 32_000) -> None:
        """Initialize the token counter.

        Args:
            budget: Maximum token budget for the conversation.
        """
        self._budget = budget
        self._cumulative = TokenUsage()

    @property
    def budget(self) -> int:
        """The configured token budget."""
        return self._budget

    @property
    def cumulative_usage(self) -> TokenUsage:
        """Cumulative token usage across all tracked calls.

        Returns a copy, so callers cannot mutate the counter's internal
        totals. A value held by a caller also stays a stable snapshot after
        later ``count_usage`` calls.
        """
        return TokenUsage(
            prompt_tokens=self._cumulative.prompt_tokens,
            completion_tokens=self._cumulative.completion_tokens,
            total_tokens=self._cumulative.total_tokens,
        )

    @property
    def remaining_budget(self) -> int:
        """Estimated tokens remaining before hitting the budget (never negative)."""
        return max(0, self._budget - self._cumulative.total_tokens)

    def estimate_tokens(self, text: str) -> int:
        """Estimate token count for a string using the character heuristic.

        Args:
            text: The text to estimate tokens for.

        Returns:
            ``len(text) // CHARS_PER_TOKEN``, floored at 1. Even an empty or
            very short string counts as one token.
        """
        return max(1, len(text) // self.CHARS_PER_TOKEN)

    def estimate_messages_tokens(self, messages: list[Message]) -> int:
        """Estimate total tokens for a list of messages.

        The estimate for each message is the sum of:
        - its content, if any;
        - the name and arguments of each tool call;
        - a fixed ``MESSAGE_OVERHEAD_TOKENS`` allowance.

        Args:
            messages: List of conversation messages.

        Returns:
            Estimated total token count.
        """
        total = 0
        for msg in messages:
            if msg.content:
                total += self.estimate_tokens(msg.content)
            # ``tool_calls`` may be empty or unset; treat both as "no calls".
            for tc in msg.tool_calls or ():
                total += self.estimate_tokens(tc.function.name)
                total += self.estimate_tokens(tc.function.arguments)
            total += self.MESSAGE_OVERHEAD_TOKENS
        return total

    def count_usage(self, usage: TokenUsage) -> None:
        """Record token usage from an LLM call.

        If ``usage.total_tokens`` is 0, the total is taken as
        ``prompt_tokens + completion_tokens`` instead.

        Args:
            usage: Token usage from a single call.
        """
        call_total = usage.total_tokens
        if not call_total:
            # TokenUsage defaults total_tokens to 0 independently of its
            # parts. Without this fallback, a usage built only from
            # prompt/completion counts would never count against the budget.
            call_total = usage.prompt_tokens + usage.completion_tokens
        self._cumulative.prompt_tokens += usage.prompt_tokens
        self._cumulative.completion_tokens += usage.completion_tokens
        self._cumulative.total_tokens += call_total

    def is_over_budget(self) -> bool:
        """Check whether cumulative usage has reached or exceeded the budget."""
        return self._cumulative.total_tokens >= self._budget
|