603 lines
17 KiB
Python
603 lines
17 KiB
Python
"""
|
|
Rate Limiter Service
|
|
|
|
This module implements tier-based rate limiting for AI requests using Redis
for distributed counting. Each user tier has a separate daily limit for
AI-generated turns and for DM questions.
|
|
|
|
Usage:
|
|
from app.services.rate_limiter_service import RateLimiterService, RateLimitExceeded
|
|
from app.ai.model_selector import UserTier
|
|
|
|
# Initialize service
|
|
rate_limiter = RateLimiterService()
|
|
|
|
# Check and increment usage
|
|
try:
    rate_limiter.check_rate_limit("user_123", UserTier.FREE)
    rate_limiter.increment_usage("user_123")
except RateLimitExceeded as e:
    print(f"Rate limit exceeded: {e}")
|
|
|
|
# Get remaining turns
|
|
remaining = rate_limiter.get_remaining_turns("user_123", UserTier.FREE)
|
|
"""
|
|
|
|
from datetime import date, datetime, timezone, timedelta
|
|
from typing import Optional
|
|
|
|
from app.services.redis_service import RedisService, RedisServiceError
|
|
from app.ai.model_selector import UserTier
|
|
from app.utils.logging import get_logger
|
|
|
|
|
|
# Initialize logger
|
|
# Module-level logger used by RateLimiterService. Its calls pass structured
# fields as keyword arguments (e.g. user_id=...), so get_logger presumably
# returns a structlog-style logger — TODO confirm.
# NOTE(review): loggers are conventionally keyed by __name__; confirm that
# app.utils.logging.get_logger really expects a file path here.
logger = get_logger(__file__)
|
|
|
|
|
|
class RateLimitExceeded(Exception):
    """
    Signal that a user has used up a daily rate-limit allowance.

    The structured details are kept as attributes so callers can build
    their own responses (e.g. an HTTP 429 with a reset time). The exception
    message always describes the usage in "turns".

    Attributes:
        user_id: Identifier of the user who hit the limit
        user_tier: Subscription tier the limit was derived from
        limit: Maximum number of uses allowed for the day
        current_usage: Number of uses already recorded for the day
        reset_time: UTC datetime at which the allowance resets
    """

    def __init__(
        self,
        user_id: str,
        user_tier: UserTier,
        limit: int,
        current_usage: int,
        reset_time: datetime
    ):
        # Store the raw details first, then derive the human-readable message.
        self.user_id, self.user_tier = user_id, user_tier
        self.limit, self.current_usage = limit, current_usage
        self.reset_time = reset_time

        summary = f"Rate limit exceeded for user {user_id} ({user_tier.value} tier). "
        detail = f"Used {current_usage}/{limit} turns. Resets at {reset_time.isoformat()}"
        super().__init__(summary + detail)
|
|
|
|
|
|
class RateLimiterService:
    """
    Service for managing tier-based rate limiting.

    This service uses Redis to track daily usage per user and enforces
    limits based on subscription tier. Two independent daily counters are
    kept per user: AI turns and DM questions. Each counter key embeds the
    UTC date it belongs to and is given a TTL that ends at the following
    midnight UTC, so counters reset daily.

    Tier Limits (AI turns / DM questions per day):
        - Free: 20 / 10
        - Basic: 50 / 20
        - Premium: 100 / 50
        - Elite: 200 / unlimited

    NOTE: check_* and increment_* are separate Redis round-trips, so
    concurrent requests for the same user can all pass the check before any
    of them increments; under concurrency a limit may be slightly exceeded.

    Attributes:
        redis: RedisService instance for counter storage
        TIER_LIMITS: Mapping of tier to daily AI turn limit
        DM_QUESTION_LIMITS: Mapping of tier to daily DM question limit
            (UNLIMITED means no limit)
    """

    # Sentinel limit value meaning "no limit"
    UNLIMITED = -1

    # Daily turn limits per tier
    TIER_LIMITS = {
        UserTier.FREE: 20,
        UserTier.BASIC: 50,
        UserTier.PREMIUM: 100,
        UserTier.ELITE: 200,
    }

    # Daily DM question limits per tier
    DM_QUESTION_LIMITS = {
        UserTier.FREE: 10,
        UserTier.BASIC: 20,
        UserTier.PREMIUM: 50,
        UserTier.ELITE: -1,  # -1 (UNLIMITED) means unlimited
    }

    # Redis key prefixes for the per-user daily counters
    KEY_PREFIX = "rate_limit:daily:"
    DM_QUESTION_PREFIX = "rate_limit:dm_questions:"

    def __init__(self, redis_service: Optional[RedisService] = None):
        """
        Initialize the rate limiter service.

        Args:
            redis_service: Optional RedisService instance. If not provided,
                a new instance will be created.
        """
        self.redis = redis_service or RedisService()

        logger.info(
            "RateLimiterService initialized",
            tier_limits=self.TIER_LIMITS
        )

    # ===== KEY / TIME HELPERS =====

    def _build_key(self, prefix: str, user_id: str, day: Optional[date]) -> str:
        """Build "<prefix><user_id>:<YYYY-MM-DD>", defaulting day to today UTC."""
        if day is None:
            day = datetime.now(timezone.utc).date()

        return f"{prefix}{user_id}:{day.isoformat()}"

    def _get_daily_key(self, user_id: str, day: Optional[date] = None) -> str:
        """
        Generate the Redis key for a user's daily counter.

        Args:
            user_id: The user ID
            day: The date (defaults to today UTC)

        Returns:
            Redis key in format "rate_limit:daily:user_id:YYYY-MM-DD"
        """
        return self._build_key(self.KEY_PREFIX, user_id, day)

    def _get_reset_time(self, now: Optional[datetime] = None) -> datetime:
        """
        Get the UTC datetime when the rate limit resets.

        Args:
            now: Reference instant (defaults to the current UTC time). It is
                converted to UTC with ``astimezone`` before use.

        Returns:
            Datetime of the first midnight UTC strictly after ``now``
        """
        if now is None:
            now = datetime.now(timezone.utc)
        now = now.astimezone(timezone.utc)

        start_of_day = datetime(now.year, now.month, now.day, tzinfo=timezone.utc)
        return start_of_day + timedelta(days=1)

    def _get_seconds_until_midnight_utc(self, now: Optional[datetime] = None) -> int:
        """
        Calculate seconds remaining until midnight UTC.

        Args:
            now: Reference instant (defaults to the current UTC time)

        Returns:
            Whole seconds until the next UTC midnight, never less than 1.
            Redis EXPIRE with a non-positive timeout deletes the key instead
            of expiring it, which would silently drop a counter created in
            the last fraction of a second before midnight.
        """
        if now is None:
            now = datetime.now(timezone.utc)
        now = now.astimezone(timezone.utc)

        remaining = (self._get_reset_time(now) - now).total_seconds()
        return max(1, int(remaining))

    # ===== SHARED COUNTER OPERATIONS =====

    def _read_counter(self, key: str, user_id: str, label: str) -> int:
        """
        Read a daily counter, treating a missing or malformed value as 0.

        Args:
            key: Redis key of the counter
            user_id: User the counter belongs to (for logging)
            label: Counter name used in log messages (e.g. "usage")

        Returns:
            The stored count, or 0 if absent or not an integer

        Raises:
            RedisServiceError: If Redis operation fails
        """
        try:
            value = self.redis.get(key)
            usage = int(value) if value else 0

            logger.debug(
                f"Retrieved current {label}",
                user_id=user_id,
                usage=usage
            )

            return usage

        except (ValueError, TypeError) as e:
            logger.error(
                f"Invalid {label} value in Redis",
                user_id=user_id,
                error=str(e)
            )
            return 0

    def _increment_counter(
        self,
        key: str,
        user_id: str,
        now: datetime,
        expire_message: str,
        increment_message: str
    ) -> int:
        """
        Increment a daily counter and give a freshly created key its TTL.

        Args:
            key: Redis key of the counter; must have been built from ``now``
            user_id: User the counter belongs to (for logging)
            now: The instant used to build ``key``. The TTL is computed from
                the same instant so the key expires at the end of the UTC day
                it is named after, even if the call straddles midnight.
            expire_message: Debug log message emitted when a TTL is set
            increment_message: Info log message emitted after incrementing

        Returns:
            The new count after incrementing

        Raises:
            RedisServiceError: If Redis operation fails
        """
        new_count = self.redis.incr(key)

        # INCR creates the key with value 1, so a count of 1 means this call
        # created it and is responsible for attaching the expiry.
        if new_count == 1:
            ttl = self._get_seconds_until_midnight_utc(now)
            self.redis.expire(key, ttl)

            logger.debug(
                expire_message,
                user_id=user_id,
                ttl=ttl
            )

        logger.info(
            increment_message,
            user_id=user_id,
            new_count=new_count
        )

        return new_count

    def _delete_counter(self, key: str, user_id: str, message: str) -> bool:
        """
        Delete a daily counter and log the outcome.

        Returns:
            True if the counter was deleted, False if it didn't exist

        Raises:
            RedisServiceError: If Redis operation fails
        """
        deleted = self.redis.delete(key)

        logger.info(
            message,
            user_id=user_id,
            deleted=deleted > 0
        )

        return deleted > 0

    def _enforce_limit(
        self,
        user_id: str,
        user_tier: UserTier,
        current_usage: int,
        limit: int,
        exceeded_message: str,
        passed_message: str
    ) -> None:
        """
        Raise RateLimitExceeded when current_usage has reached limit.

        Args:
            user_id: The user being checked
            user_tier: The user's subscription tier
            current_usage: Count already recorded for today
            limit: Finite daily limit to compare against
            exceeded_message: Warning log message when the limit is reached
            passed_message: Debug log message when the check passes

        Raises:
            RateLimitExceeded: If current_usage >= limit
        """
        if current_usage >= limit:
            reset_time = self._get_reset_time()

            logger.warning(
                exceeded_message,
                user_id=user_id,
                user_tier=user_tier.value,
                current_usage=current_usage,
                limit=limit,
                reset_time=reset_time.isoformat()
            )

            raise RateLimitExceeded(
                user_id=user_id,
                user_tier=user_tier,
                limit=limit,
                current_usage=current_usage,
                reset_time=reset_time
            )

        logger.debug(
            passed_message,
            user_id=user_id,
            user_tier=user_tier.value,
            current_usage=current_usage,
            limit=limit
        )

    # ===== AI TURN RATE LIMITING =====

    def get_limit_for_tier(self, user_tier: UserTier) -> int:
        """
        Get the daily turn limit for a specific tier.

        Args:
            user_tier: The user's subscription tier

        Returns:
            Daily turn limit for the tier (unknown tiers get the Free limit)
        """
        return self.TIER_LIMITS.get(user_tier, self.TIER_LIMITS[UserTier.FREE])

    def get_current_usage(self, user_id: str) -> int:
        """
        Get the current daily usage count for a user.

        Args:
            user_id: The user ID to check

        Returns:
            Current usage count (0 if no usage today)

        Raises:
            RedisServiceError: If Redis operation fails
        """
        return self._read_counter(self._get_daily_key(user_id), user_id, "usage")

    def check_rate_limit(self, user_id: str, user_tier: UserTier) -> None:
        """
        Check if a user has exceeded their daily rate limit.

        This method checks the current usage against the tier limit and
        raises an exception if the limit has been reached. It does not
        increment the counter; call increment_usage after the request.

        Args:
            user_id: The user ID to check
            user_tier: The user's subscription tier

        Raises:
            RateLimitExceeded: If the user has reached their daily limit
            RedisServiceError: If Redis operation fails
        """
        current_usage = self.get_current_usage(user_id)
        limit = self.get_limit_for_tier(user_tier)

        self._enforce_limit(
            user_id,
            user_tier,
            current_usage,
            limit,
            "Rate limit exceeded",
            "Rate limit check passed"
        )

    def increment_usage(self, user_id: str) -> int:
        """
        Increment the daily usage counter for a user.

        This method should be called after successfully processing an AI request.
        The counter will automatically expire at midnight UTC.

        Args:
            user_id: The user ID to increment

        Returns:
            The new usage count after incrementing

        Raises:
            RedisServiceError: If Redis operation fails
        """
        # One instant for both the key's date and its TTL (see _increment_counter).
        now = datetime.now(timezone.utc)
        key = self._get_daily_key(user_id, now.date())

        return self._increment_counter(
            key,
            user_id,
            now,
            "Set expiration on new rate limit key",
            "Incremented usage counter"
        )

    def get_remaining_turns(self, user_id: str, user_tier: UserTier) -> int:
        """
        Get the number of remaining turns for a user today.

        Args:
            user_id: The user ID to check
            user_tier: The user's subscription tier

        Returns:
            Number of turns remaining (0 if limit reached)

        Raises:
            RedisServiceError: If Redis operation fails
        """
        current_usage = self.get_current_usage(user_id)
        limit = self.get_limit_for_tier(user_tier)

        remaining = max(0, limit - current_usage)

        logger.debug(
            "Calculated remaining turns",
            user_id=user_id,
            user_tier=user_tier.value,
            current_usage=current_usage,
            limit=limit,
            remaining=remaining
        )

        return remaining

    def get_usage_info(self, user_id: str, user_tier: UserTier) -> dict:
        """
        Get comprehensive usage information for a user.

        Args:
            user_id: The user ID to check
            user_tier: The user's subscription tier

        Returns:
            Dictionary with usage info:
                - user_id: User identifier
                - user_tier: Subscription tier
                - current_usage: Current daily usage
                - daily_limit: Daily limit for tier
                - remaining: Remaining turns
                - reset_time: ISO format UTC reset time
                - is_limited: Whether limit has been reached

        Raises:
            RedisServiceError: If Redis operation fails
        """
        current_usage = self.get_current_usage(user_id)
        limit = self.get_limit_for_tier(user_tier)
        remaining = max(0, limit - current_usage)
        reset_time = self._get_reset_time()

        info = {
            "user_id": user_id,
            "user_tier": user_tier.value,
            "current_usage": current_usage,
            "daily_limit": limit,
            "remaining": remaining,
            "reset_time": reset_time.isoformat(),
            "is_limited": current_usage >= limit
        }

        logger.debug("Retrieved usage info", **info)

        return info

    def reset_usage(self, user_id: str) -> bool:
        """
        Reset the daily usage counter for a user.

        This is primarily for admin/testing purposes.

        Args:
            user_id: The user ID to reset

        Returns:
            True if the counter was deleted, False if it didn't exist

        Raises:
            RedisServiceError: If Redis operation fails
        """
        return self._delete_counter(
            self._get_daily_key(user_id),
            user_id,
            "Reset usage counter"
        )

    # ===== DM QUESTION RATE LIMITING =====

    def _get_dm_question_key(self, user_id: str, day: Optional[date] = None) -> str:
        """
        Generate the Redis key for a user's daily DM question counter.

        Args:
            user_id: The user ID
            day: The date (defaults to today UTC)

        Returns:
            Redis key in format "rate_limit:dm_questions:user_id:YYYY-MM-DD"
        """
        return self._build_key(self.DM_QUESTION_PREFIX, user_id, day)

    def get_dm_question_limit_for_tier(self, user_tier: UserTier) -> int:
        """
        Get the daily DM question limit for a specific tier.

        Args:
            user_tier: The user's subscription tier

        Returns:
            Daily DM question limit for the tier (-1 for unlimited; unknown
            tiers get the Free limit)
        """
        return self.DM_QUESTION_LIMITS.get(user_tier, self.DM_QUESTION_LIMITS[UserTier.FREE])

    def get_current_dm_usage(self, user_id: str) -> int:
        """
        Get the current daily DM question usage count for a user.

        Args:
            user_id: The user ID to check

        Returns:
            Current DM question usage count (0 if no usage today)

        Raises:
            RedisServiceError: If Redis operation fails
        """
        return self._read_counter(
            self._get_dm_question_key(user_id),
            user_id,
            "DM question usage"
        )

    def check_dm_question_limit(self, user_id: str, user_tier: UserTier) -> None:
        """
        Check if a user has exceeded their daily DM question limit.

        Unlimited tiers pass without touching Redis. Note that the raised
        RateLimitExceeded formats its message in terms of "turns".

        Args:
            user_id: The user ID to check
            user_tier: The user's subscription tier

        Raises:
            RateLimitExceeded: If the user has reached their daily DM question limit
            RedisServiceError: If Redis operation fails
        """
        limit = self.get_dm_question_limit_for_tier(user_tier)

        if limit == self.UNLIMITED:
            logger.debug(
                "DM question limit check passed (unlimited)",
                user_id=user_id,
                user_tier=user_tier.value
            )
            return

        current_usage = self.get_current_dm_usage(user_id)

        self._enforce_limit(
            user_id,
            user_tier,
            current_usage,
            limit,
            "DM question limit exceeded",
            "DM question limit check passed"
        )

    def increment_dm_usage(self, user_id: str) -> int:
        """
        Increment the daily DM question counter for a user.

        The counter will automatically expire at midnight UTC.

        Args:
            user_id: The user ID to increment

        Returns:
            The new DM question usage count after incrementing

        Raises:
            RedisServiceError: If Redis operation fails
        """
        # One instant for both the key's date and its TTL (see _increment_counter).
        now = datetime.now(timezone.utc)
        key = self._get_dm_question_key(user_id, now.date())

        return self._increment_counter(
            key,
            user_id,
            now,
            "Set expiration on new DM question key",
            "Incremented DM question counter"
        )

    def get_remaining_dm_questions(self, user_id: str, user_tier: UserTier) -> int:
        """
        Get the number of remaining DM questions for a user today.

        Args:
            user_id: The user ID to check
            user_tier: The user's subscription tier

        Returns:
            Number of DM questions remaining (-1 if unlimited, 0 if limit reached)

        Raises:
            RedisServiceError: If Redis operation fails
        """
        limit = self.get_dm_question_limit_for_tier(user_tier)

        if limit == self.UNLIMITED:
            return self.UNLIMITED

        current_usage = self.get_current_dm_usage(user_id)
        remaining = max(0, limit - current_usage)

        logger.debug(
            "Calculated remaining DM questions",
            user_id=user_id,
            user_tier=user_tier.value,
            current_usage=current_usage,
            limit=limit,
            remaining=remaining
        )

        return remaining

    def reset_dm_usage(self, user_id: str) -> bool:
        """
        Reset the daily DM question counter for a user.

        This is primarily for admin/testing purposes.

        Args:
            user_id: The user ID to reset

        Returns:
            True if the counter was deleted, False if it didn't exist

        Raises:
            RedisServiceError: If Redis operation fails
        """
        return self._delete_counter(
            self._get_dm_question_key(user_id),
            user_id,
            "Reset DM question counter"
        )
|