Files
Code_of_Conquest/api/app/services/rate_limiter_service.py
Phillip Tarrant 51f6041ee4 fix(api): remove reference to non-existent TIER_LIMITS attribute
The RateLimiterService.__init__ was logging self.TIER_LIMITS which doesn't
exist after refactoring to config-based tier limits. Changed to log the
existing DM_QUESTION_LIMITS attribute instead.
2025-11-26 10:07:35 -06:00

628 lines
18 KiB
Python

"""
Rate Limiter Service
This module implements tier-based rate limiting for AI requests using Redis
for distributed counting. Each user tier has a different daily limit for
AI-generated turns (read from config) and a separate daily limit for DM
questions. All counters reset at midnight UTC.
Usage:
from app.services.rate_limiter_service import RateLimiterService, RateLimitExceeded
from app.ai.model_selector import UserTier
# Initialize service
rate_limiter = RateLimiterService()
# Check and increment usage
try:
rate_limiter.check_rate_limit("user_123", UserTier.FREE)
rate_limiter.increment_usage("user_123")
except RateLimitExceeded as e:
print(f"Rate limit exceeded: {e}")
# Get remaining turns
remaining = rate_limiter.get_remaining_turns("user_123", UserTier.FREE)
"""
from datetime import date, datetime, timezone, timedelta
from typing import Optional
from app.services.redis_service import RedisService, RedisServiceError
from app.ai.model_selector import UserTier
from app.utils.logging import get_logger
from app.config import get_config
# Initialize logger
# Module-level logger used by RateLimiterService; call sites pass structured
# context as keyword arguments (e.g. user_id=..., limit=...).
logger = get_logger(__file__)
class RateLimitExceeded(Exception):
    """
    Signals that a user has used up a daily allowance.

    Raised by RateLimiterService when a counter has reached the limit for the
    user's tier. Every detail is exposed as an attribute so callers can build
    an error response without parsing the message text.

    Attributes:
        user_id: Identifier of the user who hit the limit
        user_tier: Subscription tier whose limit was applied
        limit: Daily allowance for that tier
        current_usage: Counter value observed when the check failed
        reset_time: UTC datetime at which the allowance resets
    """

    def __init__(
        self,
        user_id: str,
        user_tier: UserTier,
        limit: int,
        current_usage: int,
        reset_time: datetime
    ):
        # Keep the raw details on the instance for exception handlers.
        self.user_id, self.user_tier = user_id, user_tier
        self.limit, self.current_usage = limit, current_usage
        self.reset_time = reset_time

        who = f"Rate limit exceeded for user {user_id} ({user_tier.value} tier). "
        usage = f"Used {current_usage}/{limit} turns. Resets at {reset_time.isoformat()}"
        # NOTE: the wording says "turns" even when raised for the DM question counter.
        super().__init__(who + usage)
class RateLimiterService:
    """
    Service for managing tier-based rate limiting.

    This service uses Redis to keep two independent daily counters per user
    and enforces limits based on subscription tier:

    - AI turns: limits are loaded from config
      (rate_limiting.tiers.{tier}.ai_calls_per_day).
    - DM questions: limits come from DM_QUESTION_LIMITS.

    Counter keys embed the current UTC date and are given a TTL that expires
    them at the next midnight UTC, so usage resets daily. A limit of -1
    (UNLIMITED) means unlimited.

    Note: the check_* and increment_* methods are separate Redis round-trips,
    so two concurrent requests can both pass a check before either one
    increments the counter.

    Attributes:
        redis: RedisService instance for counter storage
    """

    # Limit value meaning "no limit" for a tier
    UNLIMITED = -1

    # AI-turn limit used when a tier is missing from config
    DEFAULT_AI_CALLS_PER_DAY = 50

    # Daily DM question limits per tier (-1 means unlimited)
    DM_QUESTION_LIMITS = {
        UserTier.FREE: 10,
        UserTier.BASIC: 20,
        UserTier.PREMIUM: 50,
        UserTier.ELITE: -1,
    }

    # Redis key prefixes for the daily counters
    KEY_PREFIX = "rate_limit:daily:"
    DM_QUESTION_PREFIX = "rate_limit:dm_questions:"

    def __init__(self, redis_service: Optional[RedisService] = None):
        """
        Initialize the rate limiter service.

        Args:
            redis_service: Optional RedisService instance. If not provided,
                a new instance will be created.
        """
        self.redis = redis_service or RedisService()
        logger.info(
            "RateLimiterService initialized",
            dm_question_limits=self.DM_QUESTION_LIMITS
        )

    # ===== TIME AND COUNTER HELPERS =====

    def _get_reset_time(self, now: Optional[datetime] = None) -> datetime:
        """
        Get the UTC datetime when the rate limit resets.

        Args:
            now: Reference instant (defaults to the current time). Expected
                to be timezone-aware; it is converted to UTC.

        Returns:
            Datetime of the first midnight UTC after ``now``
        """
        if now is None:
            now = datetime.now(timezone.utc)
        else:
            now = now.astimezone(timezone.utc)
        start_of_day = datetime(now.year, now.month, now.day, tzinfo=timezone.utc)
        return start_of_day + timedelta(days=1)

    def _get_seconds_until_midnight_utc(self, now: Optional[datetime] = None) -> int:
        """
        Calculate seconds remaining until midnight UTC, for use as a key TTL.

        Args:
            now: Reference instant (defaults to the current time).

        Returns:
            Whole seconds until the next UTC midnight, never less than 1
        """
        if now is None:
            now = datetime.now(timezone.utc)
        seconds = int((self._get_reset_time(now) - now).total_seconds())
        # Redis deletes a key immediately when EXPIRE receives a non-positive
        # TTL, which would silently drop an increment made in the last second
        # of the day, so clamp to at least one second.
        return max(1, seconds)

    def _read_counter(
        self,
        key: str,
        user_id: str,
        *,
        read_msg: str,
        invalid_msg: str
    ) -> int:
        """
        Read an integer counter from Redis.

        Args:
            key: Redis key of the counter
            user_id: User the counter belongs to (used for logging)
            read_msg: Debug log message emitted after a successful read
            invalid_msg: Error log message emitted when the stored value
                cannot be parsed as an integer

        Returns:
            The counter value; 0 if the key is missing/empty or unparseable

        Raises:
            RedisServiceError: If Redis operation fails
        """
        value = self.redis.get(key)
        try:
            usage = int(value) if value else 0
        except (ValueError, TypeError) as e:
            logger.error(invalid_msg, user_id=user_id, error=str(e))
            return 0
        logger.debug(read_msg, user_id=user_id, usage=usage)
        return usage

    def _increment_counter(
        self,
        key: str,
        now: datetime,
        user_id: str,
        *,
        expiry_msg: str,
        incr_msg: str
    ) -> int:
        """
        Increment a daily counter, setting its expiry on the first increment.

        Args:
            key: Redis key of the counter (should be dated from ``now``)
            now: Instant the key's date was derived from; the TTL is computed
                from the same instant so both agree across midnight
            user_id: User the counter belongs to (used for logging)
            expiry_msg: Debug log message emitted when the TTL is set
            incr_msg: Info log message emitted after incrementing

        Returns:
            The new counter value after incrementing

        Raises:
            RedisServiceError: If Redis operation fails
        """
        new_count = self.redis.incr(key)
        # Only the first increment of the day attaches the TTL. INCR and
        # EXPIRE are separate calls: if EXPIRE fails here the key keeps no
        # TTL, though its date suffix still scopes the count to one day.
        if new_count == 1:
            ttl = self._get_seconds_until_midnight_utc(now)
            self.redis.expire(key, ttl)
            logger.debug(expiry_msg, user_id=user_id, ttl=ttl)
        logger.info(incr_msg, user_id=user_id, new_count=new_count)
        return new_count

    # ===== AI TURN RATE LIMITING =====

    def _get_daily_key(self, user_id: str, day: Optional[date] = None) -> str:
        """
        Generate the Redis key for a user's daily counter.

        Args:
            user_id: The user ID
            day: The date (defaults to today UTC)

        Returns:
            Redis key in format "rate_limit:daily:user_id:YYYY-MM-DD"
        """
        if day is None:
            day = datetime.now(timezone.utc).date()
        return f"{self.KEY_PREFIX}{user_id}:{day.isoformat()}"

    def get_limit_for_tier(self, user_tier: UserTier) -> int:
        """
        Get the daily turn limit for a specific tier from config.

        Args:
            user_tier: The user's subscription tier

        Returns:
            Daily turn limit for the tier (-1 means unlimited). Falls back to
            DEFAULT_AI_CALLS_PER_DAY if the tier is missing from config.
        """
        config = get_config()
        tier_name = user_tier.value.lower()
        tier_config = config.rate_limiting.tiers.get(tier_name)
        if tier_config is not None:
            return tier_config.ai_calls_per_day

        logger.warning(
            "Tier not found in config, using default limit",
            tier=tier_name
        )
        return self.DEFAULT_AI_CALLS_PER_DAY

    def get_current_usage(self, user_id: str) -> int:
        """
        Get the current daily usage count for a user.

        Args:
            user_id: The user ID to check

        Returns:
            Current usage count (0 if no usage today)

        Raises:
            RedisServiceError: If Redis operation fails
        """
        return self._read_counter(
            self._get_daily_key(user_id),
            user_id,
            read_msg="Retrieved current usage",
            invalid_msg="Invalid usage value in Redis",
        )

    def check_rate_limit(self, user_id: str, user_tier: UserTier) -> None:
        """
        Check if a user has exceeded their daily rate limit.

        Args:
            user_id: The user ID to check
            user_tier: The user's subscription tier

        Raises:
            RateLimitExceeded: If the user has reached their daily limit
            RedisServiceError: If Redis operation fails
        """
        limit = self.get_limit_for_tier(user_tier)

        # Unlimited tiers skip the Redis lookup entirely.
        if limit == self.UNLIMITED:
            logger.debug(
                "Rate limit check passed (unlimited)",
                user_id=user_id,
                user_tier=user_tier.value
            )
            return

        current_usage = self.get_current_usage(user_id)
        if current_usage >= limit:
            reset_time = self._get_reset_time()
            logger.warning(
                "Rate limit exceeded",
                user_id=user_id,
                user_tier=user_tier.value,
                current_usage=current_usage,
                limit=limit,
                reset_time=reset_time.isoformat()
            )
            raise RateLimitExceeded(
                user_id=user_id,
                user_tier=user_tier,
                limit=limit,
                current_usage=current_usage,
                reset_time=reset_time
            )

        logger.debug(
            "Rate limit check passed",
            user_id=user_id,
            user_tier=user_tier.value,
            current_usage=current_usage,
            limit=limit
        )

    def increment_usage(self, user_id: str) -> int:
        """
        Increment the daily usage counter for a user.

        This method should be called after successfully processing an AI
        request. The counter expires at the next midnight UTC.

        Args:
            user_id: The user ID to increment

        Returns:
            The new usage count after incrementing

        Raises:
            RedisServiceError: If Redis operation fails
        """
        # Derive the key's date and its TTL from one instant so a call that
        # straddles midnight cannot give yesterday's key a full day's TTL.
        now = datetime.now(timezone.utc)
        return self._increment_counter(
            self._get_daily_key(user_id, now.date()),
            now,
            user_id,
            expiry_msg="Set expiration on new rate limit key",
            incr_msg="Incremented usage counter",
        )

    def get_remaining_turns(self, user_id: str, user_tier: UserTier) -> int:
        """
        Get the number of remaining turns for a user today.

        Args:
            user_id: The user ID to check
            user_tier: The user's subscription tier

        Returns:
            Number of turns remaining (-1 if unlimited, 0 if limit reached)
        """
        limit = self.get_limit_for_tier(user_tier)
        if limit == self.UNLIMITED:
            return self.UNLIMITED

        current_usage = self.get_current_usage(user_id)
        remaining = max(0, limit - current_usage)
        logger.debug(
            "Calculated remaining turns",
            user_id=user_id,
            user_tier=user_tier.value,
            current_usage=current_usage,
            limit=limit,
            remaining=remaining
        )
        return remaining

    def get_usage_info(self, user_id: str, user_tier: UserTier) -> dict:
        """
        Get comprehensive usage information for a user.

        Args:
            user_id: The user ID to check
            user_tier: The user's subscription tier

        Returns:
            Dictionary with usage info:
            - user_id: User identifier
            - user_tier: Subscription tier
            - current_usage: Current daily usage
            - daily_limit: Daily limit for tier (-1 means unlimited)
            - remaining: Remaining turns (-1 if unlimited)
            - reset_time: ISO format UTC reset time
            - is_limited: Whether limit has been reached (always False if unlimited)
            - is_unlimited: Whether user has unlimited turns
        """
        current_usage = self.get_current_usage(user_id)
        limit = self.get_limit_for_tier(user_tier)
        reset_time = self._get_reset_time()

        is_unlimited = limit == self.UNLIMITED
        if is_unlimited:
            remaining = self.UNLIMITED
            is_limited = False
        else:
            remaining = max(0, limit - current_usage)
            is_limited = current_usage >= limit

        info = {
            "user_id": user_id,
            "user_tier": user_tier.value,
            "current_usage": current_usage,
            "daily_limit": limit,
            "remaining": remaining,
            "reset_time": reset_time.isoformat(),
            "is_limited": is_limited,
            "is_unlimited": is_unlimited
        }
        logger.debug("Retrieved usage info", **info)
        return info

    def reset_usage(self, user_id: str) -> bool:
        """
        Reset today's usage counter for a user (admin/testing helper).

        Args:
            user_id: The user ID to reset

        Returns:
            True if the counter was deleted, False if it didn't exist

        Raises:
            RedisServiceError: If Redis operation fails
        """
        key = self._get_daily_key(user_id)
        deleted = self.redis.delete(key)
        logger.info(
            "Reset usage counter",
            user_id=user_id,
            deleted=deleted > 0
        )
        return deleted > 0

    # ===== DM QUESTION RATE LIMITING =====

    def _get_dm_question_key(self, user_id: str, day: Optional[date] = None) -> str:
        """
        Generate the Redis key for a user's daily DM question counter.

        Args:
            user_id: The user ID
            day: The date (defaults to today UTC)

        Returns:
            Redis key in format "rate_limit:dm_questions:user_id:YYYY-MM-DD"
        """
        if day is None:
            day = datetime.now(timezone.utc).date()
        return f"{self.DM_QUESTION_PREFIX}{user_id}:{day.isoformat()}"

    def get_dm_question_limit_for_tier(self, user_tier: UserTier) -> int:
        """
        Get the daily DM question limit for a specific tier.

        Args:
            user_tier: The user's subscription tier

        Returns:
            Daily DM question limit for the tier (-1 for unlimited); tiers
            missing from DM_QUESTION_LIMITS get the FREE tier's limit
        """
        return self.DM_QUESTION_LIMITS.get(user_tier, self.DM_QUESTION_LIMITS[UserTier.FREE])

    def get_current_dm_usage(self, user_id: str) -> int:
        """
        Get the current daily DM question usage count for a user.

        Args:
            user_id: The user ID to check

        Returns:
            Current DM question usage count (0 if no usage today)

        Raises:
            RedisServiceError: If Redis operation fails
        """
        return self._read_counter(
            self._get_dm_question_key(user_id),
            user_id,
            read_msg="Retrieved current DM question usage",
            invalid_msg="Invalid DM question usage value in Redis",
        )

    def check_dm_question_limit(self, user_id: str, user_tier: UserTier) -> None:
        """
        Check if a user has exceeded their daily DM question limit.

        Args:
            user_id: The user ID to check
            user_tier: The user's subscription tier

        Raises:
            RateLimitExceeded: If the user has reached their daily DM question limit
            RedisServiceError: If Redis operation fails
        """
        limit = self.get_dm_question_limit_for_tier(user_tier)

        # Unlimited tiers skip the Redis lookup entirely.
        if limit == self.UNLIMITED:
            logger.debug(
                "DM question limit check passed (unlimited)",
                user_id=user_id,
                user_tier=user_tier.value
            )
            return

        current_usage = self.get_current_dm_usage(user_id)
        if current_usage >= limit:
            reset_time = self._get_reset_time()
            logger.warning(
                "DM question limit exceeded",
                user_id=user_id,
                user_tier=user_tier.value,
                current_usage=current_usage,
                limit=limit,
                reset_time=reset_time.isoformat()
            )
            raise RateLimitExceeded(
                user_id=user_id,
                user_tier=user_tier,
                limit=limit,
                current_usage=current_usage,
                reset_time=reset_time
            )

        logger.debug(
            "DM question limit check passed",
            user_id=user_id,
            user_tier=user_tier.value,
            current_usage=current_usage,
            limit=limit
        )

    def increment_dm_usage(self, user_id: str) -> int:
        """
        Increment the daily DM question counter for a user.

        The counter expires at the next midnight UTC.

        Args:
            user_id: The user ID to increment

        Returns:
            The new DM question usage count after incrementing

        Raises:
            RedisServiceError: If Redis operation fails
        """
        # Same single-instant key/TTL derivation as increment_usage.
        now = datetime.now(timezone.utc)
        return self._increment_counter(
            self._get_dm_question_key(user_id, now.date()),
            now,
            user_id,
            expiry_msg="Set expiration on new DM question key",
            incr_msg="Incremented DM question counter",
        )

    def get_remaining_dm_questions(self, user_id: str, user_tier: UserTier) -> int:
        """
        Get the number of remaining DM questions for a user today.

        Args:
            user_id: The user ID to check
            user_tier: The user's subscription tier

        Returns:
            Number of DM questions remaining (-1 if unlimited, 0 if limit reached)
        """
        limit = self.get_dm_question_limit_for_tier(user_tier)
        if limit == self.UNLIMITED:
            return self.UNLIMITED

        current_usage = self.get_current_dm_usage(user_id)
        remaining = max(0, limit - current_usage)
        logger.debug(
            "Calculated remaining DM questions",
            user_id=user_id,
            user_tier=user_tier.value,
            current_usage=current_usage,
            limit=limit,
            remaining=remaining
        )
        return remaining

    def reset_dm_usage(self, user_id: str) -> bool:
        """
        Reset today's DM question counter for a user (admin/testing helper).

        Args:
            user_id: The user ID to reset

        Returns:
            True if the counter was deleted, False if it didn't exist

        Raises:
            RedisServiceError: If Redis operation fails
        """
        key = self._get_dm_question_key(user_id)
        deleted = self.redis.delete(key)
        logger.info(
            "Reset DM question counter",
            user_id=user_id,
            deleted=deleted > 0
        )
        return deleted > 0