first commit
This commit is contained in:
602
api/app/services/rate_limiter_service.py
Normal file
602
api/app/services/rate_limiter_service.py
Normal file
@@ -0,0 +1,602 @@
|
||||
"""
|
||||
Rate Limiter Service
|
||||
|
||||
This module implements tier-based rate limiting for AI requests using Redis
|
||||
for distributed counting. Each user tier has a different daily limit for
|
||||
AI-generated turns.
|
||||
|
||||
Usage:
|
||||
from app.services.rate_limiter_service import RateLimiterService, RateLimitExceeded
|
||||
from app.ai.model_selector import UserTier
|
||||
|
||||
# Initialize service
|
||||
rate_limiter = RateLimiterService()
|
||||
|
||||
# Check and increment usage
|
||||
try:
|
||||
rate_limiter.check_rate_limit("user_123", UserTier.FREE)
|
||||
rate_limiter.increment_usage("user_123")
|
||||
except RateLimitExceeded as e:
|
||||
print(f"Rate limit exceeded: {e}")
|
||||
|
||||
# Get remaining turns
|
||||
remaining = rate_limiter.get_remaining_turns("user_123", UserTier.FREE)
|
||||
"""
|
||||
|
||||
from datetime import date, datetime, timezone, timedelta
|
||||
from typing import Optional
|
||||
|
||||
from app.services.redis_service import RedisService, RedisServiceError
|
||||
from app.ai.model_selector import UserTier
|
||||
from app.utils.logging import get_logger
|
||||
|
||||
|
||||
# Initialize logger
|
||||
logger = get_logger(__file__)
|
||||
|
||||
|
||||
class RateLimitExceeded(Exception):
|
||||
"""
|
||||
Raised when a user has exceeded their daily rate limit.
|
||||
|
||||
Attributes:
|
||||
user_id: The user who exceeded the limit
|
||||
user_tier: The user's subscription tier
|
||||
limit: The daily limit for their tier
|
||||
current_usage: The current usage count
|
||||
reset_time: UTC timestamp when the limit resets
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
user_id: str,
|
||||
user_tier: UserTier,
|
||||
limit: int,
|
||||
current_usage: int,
|
||||
reset_time: datetime
|
||||
):
|
||||
self.user_id = user_id
|
||||
self.user_tier = user_tier
|
||||
self.limit = limit
|
||||
self.current_usage = current_usage
|
||||
self.reset_time = reset_time
|
||||
|
||||
message = (
|
||||
f"Rate limit exceeded for user {user_id} ({user_tier.value} tier). "
|
||||
f"Used {current_usage}/{limit} turns. Resets at {reset_time.isoformat()}"
|
||||
)
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class RateLimiterService:
|
||||
"""
|
||||
Service for managing tier-based rate limiting.
|
||||
|
||||
This service uses Redis to track daily AI usage per user and enforces
|
||||
limits based on subscription tier. Counters reset daily at midnight UTC.
|
||||
|
||||
Tier Limits:
|
||||
- Free: 20 turns/day
|
||||
- Basic: 50 turns/day
|
||||
- Premium: 100 turns/day
|
||||
- Elite: 200 turns/day
|
||||
|
||||
Attributes:
|
||||
redis: RedisService instance for counter storage
|
||||
tier_limits: Mapping of tier to daily turn limit
|
||||
"""
|
||||
|
||||
# Daily turn limits per tier
|
||||
TIER_LIMITS = {
|
||||
UserTier.FREE: 20,
|
||||
UserTier.BASIC: 50,
|
||||
UserTier.PREMIUM: 100,
|
||||
UserTier.ELITE: 200,
|
||||
}
|
||||
|
||||
# Daily DM question limits per tier
|
||||
DM_QUESTION_LIMITS = {
|
||||
UserTier.FREE: 10,
|
||||
UserTier.BASIC: 20,
|
||||
UserTier.PREMIUM: 50,
|
||||
UserTier.ELITE: -1, # -1 means unlimited
|
||||
}
|
||||
|
||||
# Redis key prefix for rate limit counters
|
||||
KEY_PREFIX = "rate_limit:daily:"
|
||||
DM_QUESTION_PREFIX = "rate_limit:dm_questions:"
|
||||
|
||||
def __init__(self, redis_service: Optional[RedisService] = None):
|
||||
"""
|
||||
Initialize the rate limiter service.
|
||||
|
||||
Args:
|
||||
redis_service: Optional RedisService instance. If not provided,
|
||||
a new instance will be created.
|
||||
"""
|
||||
self.redis = redis_service or RedisService()
|
||||
|
||||
logger.info(
|
||||
"RateLimiterService initialized",
|
||||
tier_limits=self.TIER_LIMITS
|
||||
)
|
||||
|
||||
def _get_daily_key(self, user_id: str, day: Optional[date] = None) -> str:
|
||||
"""
|
||||
Generate the Redis key for a user's daily counter.
|
||||
|
||||
Args:
|
||||
user_id: The user ID
|
||||
day: The date (defaults to today UTC)
|
||||
|
||||
Returns:
|
||||
Redis key in format "rate_limit:daily:user_id:YYYY-MM-DD"
|
||||
"""
|
||||
if day is None:
|
||||
day = datetime.now(timezone.utc).date()
|
||||
|
||||
return f"{self.KEY_PREFIX}{user_id}:{day.isoformat()}"
|
||||
|
||||
def _get_seconds_until_midnight_utc(self) -> int:
|
||||
"""
|
||||
Calculate seconds remaining until midnight UTC.
|
||||
|
||||
Returns:
|
||||
Number of seconds until the next UTC midnight
|
||||
"""
|
||||
now = datetime.now(timezone.utc)
|
||||
tomorrow = datetime(
|
||||
now.year, now.month, now.day,
|
||||
tzinfo=timezone.utc
|
||||
) + timedelta(days=1)
|
||||
|
||||
return int((tomorrow - now).total_seconds())
|
||||
|
||||
def _get_reset_time(self) -> datetime:
|
||||
"""
|
||||
Get the UTC datetime when the rate limit resets.
|
||||
|
||||
Returns:
|
||||
Datetime of next midnight UTC
|
||||
"""
|
||||
now = datetime.now(timezone.utc)
|
||||
return datetime(
|
||||
now.year, now.month, now.day,
|
||||
tzinfo=timezone.utc
|
||||
) + timedelta(days=1)
|
||||
|
||||
def get_limit_for_tier(self, user_tier: UserTier) -> int:
|
||||
"""
|
||||
Get the daily turn limit for a specific tier.
|
||||
|
||||
Args:
|
||||
user_tier: The user's subscription tier
|
||||
|
||||
Returns:
|
||||
Daily turn limit for the tier
|
||||
"""
|
||||
return self.TIER_LIMITS.get(user_tier, self.TIER_LIMITS[UserTier.FREE])
|
||||
|
||||
def get_current_usage(self, user_id: str) -> int:
|
||||
"""
|
||||
Get the current daily usage count for a user.
|
||||
|
||||
Args:
|
||||
user_id: The user ID to check
|
||||
|
||||
Returns:
|
||||
Current usage count (0 if no usage today)
|
||||
|
||||
Raises:
|
||||
RedisServiceError: If Redis operation fails
|
||||
"""
|
||||
key = self._get_daily_key(user_id)
|
||||
|
||||
try:
|
||||
value = self.redis.get(key)
|
||||
usage = int(value) if value else 0
|
||||
|
||||
logger.debug(
|
||||
"Retrieved current usage",
|
||||
user_id=user_id,
|
||||
usage=usage
|
||||
)
|
||||
|
||||
return usage
|
||||
|
||||
except (ValueError, TypeError) as e:
|
||||
logger.error(
|
||||
"Invalid usage value in Redis",
|
||||
user_id=user_id,
|
||||
error=str(e)
|
||||
)
|
||||
return 0
|
||||
|
||||
def check_rate_limit(self, user_id: str, user_tier: UserTier) -> None:
|
||||
"""
|
||||
Check if a user has exceeded their daily rate limit.
|
||||
|
||||
This method checks the current usage against the tier limit and
|
||||
raises an exception if the limit has been reached.
|
||||
|
||||
Args:
|
||||
user_id: The user ID to check
|
||||
user_tier: The user's subscription tier
|
||||
|
||||
Raises:
|
||||
RateLimitExceeded: If the user has reached their daily limit
|
||||
RedisServiceError: If Redis operation fails
|
||||
"""
|
||||
current_usage = self.get_current_usage(user_id)
|
||||
limit = self.get_limit_for_tier(user_tier)
|
||||
|
||||
if current_usage >= limit:
|
||||
reset_time = self._get_reset_time()
|
||||
|
||||
logger.warning(
|
||||
"Rate limit exceeded",
|
||||
user_id=user_id,
|
||||
user_tier=user_tier.value,
|
||||
current_usage=current_usage,
|
||||
limit=limit,
|
||||
reset_time=reset_time.isoformat()
|
||||
)
|
||||
|
||||
raise RateLimitExceeded(
|
||||
user_id=user_id,
|
||||
user_tier=user_tier,
|
||||
limit=limit,
|
||||
current_usage=current_usage,
|
||||
reset_time=reset_time
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
"Rate limit check passed",
|
||||
user_id=user_id,
|
||||
user_tier=user_tier.value,
|
||||
current_usage=current_usage,
|
||||
limit=limit
|
||||
)
|
||||
|
||||
def increment_usage(self, user_id: str) -> int:
|
||||
"""
|
||||
Increment the daily usage counter for a user.
|
||||
|
||||
This method should be called after successfully processing an AI request.
|
||||
The counter will automatically expire at midnight UTC.
|
||||
|
||||
Args:
|
||||
user_id: The user ID to increment
|
||||
|
||||
Returns:
|
||||
The new usage count after incrementing
|
||||
|
||||
Raises:
|
||||
RedisServiceError: If Redis operation fails
|
||||
"""
|
||||
key = self._get_daily_key(user_id)
|
||||
|
||||
# Increment the counter
|
||||
new_count = self.redis.incr(key)
|
||||
|
||||
# Set expiration if this is the first increment (new_count == 1)
|
||||
# This ensures the key expires at midnight UTC
|
||||
if new_count == 1:
|
||||
ttl = self._get_seconds_until_midnight_utc()
|
||||
self.redis.expire(key, ttl)
|
||||
|
||||
logger.debug(
|
||||
"Set expiration on new rate limit key",
|
||||
user_id=user_id,
|
||||
ttl=ttl
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Incremented usage counter",
|
||||
user_id=user_id,
|
||||
new_count=new_count
|
||||
)
|
||||
|
||||
return new_count
|
||||
|
||||
def get_remaining_turns(self, user_id: str, user_tier: UserTier) -> int:
|
||||
"""
|
||||
Get the number of remaining turns for a user today.
|
||||
|
||||
Args:
|
||||
user_id: The user ID to check
|
||||
user_tier: The user's subscription tier
|
||||
|
||||
Returns:
|
||||
Number of turns remaining (0 if limit reached)
|
||||
"""
|
||||
current_usage = self.get_current_usage(user_id)
|
||||
limit = self.get_limit_for_tier(user_tier)
|
||||
|
||||
remaining = max(0, limit - current_usage)
|
||||
|
||||
logger.debug(
|
||||
"Calculated remaining turns",
|
||||
user_id=user_id,
|
||||
user_tier=user_tier.value,
|
||||
current_usage=current_usage,
|
||||
limit=limit,
|
||||
remaining=remaining
|
||||
)
|
||||
|
||||
return remaining
|
||||
|
||||
def get_usage_info(self, user_id: str, user_tier: UserTier) -> dict:
|
||||
"""
|
||||
Get comprehensive usage information for a user.
|
||||
|
||||
Args:
|
||||
user_id: The user ID to check
|
||||
user_tier: The user's subscription tier
|
||||
|
||||
Returns:
|
||||
Dictionary with usage info:
|
||||
- user_id: User identifier
|
||||
- user_tier: Subscription tier
|
||||
- current_usage: Current daily usage
|
||||
- daily_limit: Daily limit for tier
|
||||
- remaining: Remaining turns
|
||||
- reset_time: ISO format UTC reset time
|
||||
- is_limited: Whether limit has been reached
|
||||
"""
|
||||
current_usage = self.get_current_usage(user_id)
|
||||
limit = self.get_limit_for_tier(user_tier)
|
||||
remaining = max(0, limit - current_usage)
|
||||
reset_time = self._get_reset_time()
|
||||
|
||||
info = {
|
||||
"user_id": user_id,
|
||||
"user_tier": user_tier.value,
|
||||
"current_usage": current_usage,
|
||||
"daily_limit": limit,
|
||||
"remaining": remaining,
|
||||
"reset_time": reset_time.isoformat(),
|
||||
"is_limited": current_usage >= limit
|
||||
}
|
||||
|
||||
logger.debug("Retrieved usage info", **info)
|
||||
|
||||
return info
|
||||
|
||||
def reset_usage(self, user_id: str) -> bool:
|
||||
"""
|
||||
Reset the daily usage counter for a user.
|
||||
|
||||
This is primarily for admin/testing purposes.
|
||||
|
||||
Args:
|
||||
user_id: The user ID to reset
|
||||
|
||||
Returns:
|
||||
True if the counter was deleted, False if it didn't exist
|
||||
|
||||
Raises:
|
||||
RedisServiceError: If Redis operation fails
|
||||
"""
|
||||
key = self._get_daily_key(user_id)
|
||||
deleted = self.redis.delete(key)
|
||||
|
||||
logger.info(
|
||||
"Reset usage counter",
|
||||
user_id=user_id,
|
||||
deleted=deleted > 0
|
||||
)
|
||||
|
||||
return deleted > 0
|
||||
|
||||
# ===== DM QUESTION RATE LIMITING =====
|
||||
|
||||
def _get_dm_question_key(self, user_id: str, day: Optional[date] = None) -> str:
|
||||
"""
|
||||
Generate the Redis key for a user's daily DM question counter.
|
||||
|
||||
Args:
|
||||
user_id: The user ID
|
||||
day: The date (defaults to today UTC)
|
||||
|
||||
Returns:
|
||||
Redis key in format "rate_limit:dm_questions:user_id:YYYY-MM-DD"
|
||||
"""
|
||||
if day is None:
|
||||
day = datetime.now(timezone.utc).date()
|
||||
|
||||
return f"{self.DM_QUESTION_PREFIX}{user_id}:{day.isoformat()}"
|
||||
|
||||
def get_dm_question_limit_for_tier(self, user_tier: UserTier) -> int:
|
||||
"""
|
||||
Get the daily DM question limit for a specific tier.
|
||||
|
||||
Args:
|
||||
user_tier: The user's subscription tier
|
||||
|
||||
Returns:
|
||||
Daily DM question limit for the tier (-1 for unlimited)
|
||||
"""
|
||||
return self.DM_QUESTION_LIMITS.get(user_tier, self.DM_QUESTION_LIMITS[UserTier.FREE])
|
||||
|
||||
def get_current_dm_usage(self, user_id: str) -> int:
|
||||
"""
|
||||
Get the current daily DM question usage count for a user.
|
||||
|
||||
Args:
|
||||
user_id: The user ID to check
|
||||
|
||||
Returns:
|
||||
Current DM question usage count (0 if no usage today)
|
||||
|
||||
Raises:
|
||||
RedisServiceError: If Redis operation fails
|
||||
"""
|
||||
key = self._get_dm_question_key(user_id)
|
||||
|
||||
try:
|
||||
value = self.redis.get(key)
|
||||
usage = int(value) if value else 0
|
||||
|
||||
logger.debug(
|
||||
"Retrieved current DM question usage",
|
||||
user_id=user_id,
|
||||
usage=usage
|
||||
)
|
||||
|
||||
return usage
|
||||
|
||||
except (ValueError, TypeError) as e:
|
||||
logger.error(
|
||||
"Invalid DM question usage value in Redis",
|
||||
user_id=user_id,
|
||||
error=str(e)
|
||||
)
|
||||
return 0
|
||||
|
||||
def check_dm_question_limit(self, user_id: str, user_tier: UserTier) -> None:
|
||||
"""
|
||||
Check if a user has exceeded their daily DM question limit.
|
||||
|
||||
Args:
|
||||
user_id: The user ID to check
|
||||
user_tier: The user's subscription tier
|
||||
|
||||
Raises:
|
||||
RateLimitExceeded: If the user has reached their daily DM question limit
|
||||
RedisServiceError: If Redis operation fails
|
||||
"""
|
||||
limit = self.get_dm_question_limit_for_tier(user_tier)
|
||||
|
||||
# -1 means unlimited
|
||||
if limit == -1:
|
||||
logger.debug(
|
||||
"DM question limit check passed (unlimited)",
|
||||
user_id=user_id,
|
||||
user_tier=user_tier.value
|
||||
)
|
||||
return
|
||||
|
||||
current_usage = self.get_current_dm_usage(user_id)
|
||||
|
||||
if current_usage >= limit:
|
||||
reset_time = self._get_reset_time()
|
||||
|
||||
logger.warning(
|
||||
"DM question limit exceeded",
|
||||
user_id=user_id,
|
||||
user_tier=user_tier.value,
|
||||
current_usage=current_usage,
|
||||
limit=limit,
|
||||
reset_time=reset_time.isoformat()
|
||||
)
|
||||
|
||||
raise RateLimitExceeded(
|
||||
user_id=user_id,
|
||||
user_tier=user_tier,
|
||||
limit=limit,
|
||||
current_usage=current_usage,
|
||||
reset_time=reset_time
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
"DM question limit check passed",
|
||||
user_id=user_id,
|
||||
user_tier=user_tier.value,
|
||||
current_usage=current_usage,
|
||||
limit=limit
|
||||
)
|
||||
|
||||
def increment_dm_usage(self, user_id: str) -> int:
|
||||
"""
|
||||
Increment the daily DM question counter for a user.
|
||||
|
||||
Args:
|
||||
user_id: The user ID to increment
|
||||
|
||||
Returns:
|
||||
The new DM question usage count after incrementing
|
||||
|
||||
Raises:
|
||||
RedisServiceError: If Redis operation fails
|
||||
"""
|
||||
key = self._get_dm_question_key(user_id)
|
||||
|
||||
# Increment the counter
|
||||
new_count = self.redis.incr(key)
|
||||
|
||||
# Set expiration if this is the first increment
|
||||
if new_count == 1:
|
||||
ttl = self._get_seconds_until_midnight_utc()
|
||||
self.redis.expire(key, ttl)
|
||||
|
||||
logger.debug(
|
||||
"Set expiration on new DM question key",
|
||||
user_id=user_id,
|
||||
ttl=ttl
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Incremented DM question counter",
|
||||
user_id=user_id,
|
||||
new_count=new_count
|
||||
)
|
||||
|
||||
return new_count
|
||||
|
||||
def get_remaining_dm_questions(self, user_id: str, user_tier: UserTier) -> int:
|
||||
"""
|
||||
Get the number of remaining DM questions for a user today.
|
||||
|
||||
Args:
|
||||
user_id: The user ID to check
|
||||
user_tier: The user's subscription tier
|
||||
|
||||
Returns:
|
||||
Number of DM questions remaining (-1 if unlimited, 0 if limit reached)
|
||||
"""
|
||||
limit = self.get_dm_question_limit_for_tier(user_tier)
|
||||
|
||||
# -1 means unlimited
|
||||
if limit == -1:
|
||||
return -1
|
||||
|
||||
current_usage = self.get_current_dm_usage(user_id)
|
||||
remaining = max(0, limit - current_usage)
|
||||
|
||||
logger.debug(
|
||||
"Calculated remaining DM questions",
|
||||
user_id=user_id,
|
||||
user_tier=user_tier.value,
|
||||
current_usage=current_usage,
|
||||
limit=limit,
|
||||
remaining=remaining
|
||||
)
|
||||
|
||||
return remaining
|
||||
|
||||
def reset_dm_usage(self, user_id: str) -> bool:
|
||||
"""
|
||||
Reset the daily DM question counter for a user.
|
||||
|
||||
This is primarily for admin/testing purposes.
|
||||
|
||||
Args:
|
||||
user_id: The user ID to reset
|
||||
|
||||
Returns:
|
||||
True if the counter was deleted, False if it didn't exist
|
||||
|
||||
Raises:
|
||||
RedisServiceError: If Redis operation fails
|
||||
"""
|
||||
key = self._get_dm_question_key(user_id)
|
||||
deleted = self.redis.delete(key)
|
||||
|
||||
logger.info(
|
||||
"Reset DM question counter",
|
||||
user_id=user_id,
|
||||
deleted=deleted > 0
|
||||
)
|
||||
|
||||
return deleted > 0
|
||||
Reference in New Issue
Block a user