"""
Rate Limiter Service

This module implements tier-based rate limiting for AI requests using Redis
for distributed counting. Each user tier has a different daily limit for
AI-generated turns, plus a separate daily limit for DM questions.

Usage:
    from app.services.rate_limiter_service import RateLimiterService, RateLimitExceeded
    from app.ai.model_selector import UserTier

    # Initialize service
    rate_limiter = RateLimiterService()

    # Check and increment usage
    try:
        rate_limiter.check_rate_limit("user_123", UserTier.FREE)
        rate_limiter.increment_usage("user_123")
    except RateLimitExceeded as e:
        print(f"Rate limit exceeded: {e}")

    # Get remaining turns
    remaining = rate_limiter.get_remaining_turns("user_123", UserTier.FREE)
"""

import math
from datetime import date, datetime, timezone, timedelta
from typing import Optional

from app.services.redis_service import RedisService, RedisServiceError
from app.ai.model_selector import UserTier
from app.utils.logging import get_logger

# Initialize logger
logger = get_logger(__file__)


class RateLimitExceeded(Exception):
    """
    Raised when a user has exceeded a daily rate limit.

    Attributes:
        user_id: The user who exceeded the limit
        user_tier: The user's subscription tier
        limit: The daily limit for their tier
        current_usage: The current usage count
        reset_time: UTC timestamp when the limit resets
        resource: Name of the limited resource (e.g. "turns" or
            "DM questions"); only used to word the error message
    """

    def __init__(
        self,
        user_id: str,
        user_tier: UserTier,
        limit: int,
        current_usage: int,
        reset_time: datetime,
        resource: str = "turns",
    ):
        self.user_id = user_id
        self.user_tier = user_tier
        self.limit = limit
        self.current_usage = current_usage
        self.reset_time = reset_time
        self.resource = resource

        message = (
            f"Rate limit exceeded for user {user_id} ({user_tier.value} tier). "
            f"Used {current_usage}/{limit} {resource}. Resets at {reset_time.isoformat()}"
        )
        super().__init__(message)


class RateLimiterService:
    """
    Service for managing tier-based rate limiting.

    This service uses Redis to track daily usage per user and enforces limits
    based on subscription tier. Counters are keyed by UTC date and are given
    a TTL that expires them at the next midnight UTC.

    Tier Limits (AI turns per day):
        - Free: 20
        - Basic: 50
        - Premium: 100
        - Elite: 200

    DM Question Limits (per day):
        - Free: 10
        - Basic: 20
        - Premium: 50
        - Elite: unlimited

    Note:
        Checking a limit (``check_*``) and recording usage (``increment_*``)
        are separate Redis calls, so concurrent requests for the same user
        can both pass the check before either increments; the limit may be
        overshot slightly under concurrency.

    Attributes:
        redis: RedisService instance for counter storage
        TIER_LIMITS: Mapping of tier to daily turn limit
        DM_QUESTION_LIMITS: Mapping of tier to daily DM question limit
            (``UNLIMITED`` means no limit)
    """

    # Sentinel limit value meaning "no limit"
    UNLIMITED = -1

    # Daily turn limits per tier
    TIER_LIMITS = {
        UserTier.FREE: 20,
        UserTier.BASIC: 50,
        UserTier.PREMIUM: 100,
        UserTier.ELITE: 200,
    }

    # Daily DM question limits per tier
    DM_QUESTION_LIMITS = {
        UserTier.FREE: 10,
        UserTier.BASIC: 20,
        UserTier.PREMIUM: 50,
        UserTier.ELITE: UNLIMITED,
    }

    # Redis key prefixes for rate limit counters
    KEY_PREFIX = "rate_limit:daily:"
    DM_QUESTION_PREFIX = "rate_limit:dm_questions:"

    def __init__(self, redis_service: Optional[RedisService] = None):
        """
        Initialize the rate limiter service.

        Args:
            redis_service: Optional RedisService instance. If not provided,
                a new instance will be created.
        """
        self.redis = redis_service or RedisService()
        logger.info(
            "RateLimiterService initialized",
            tier_limits=self.TIER_LIMITS
        )

    # ===== SHARED HELPERS =====

    def _build_key(self, prefix: str, user_id: str, day: Optional[date] = None) -> str:
        """
        Build a per-user, per-day Redis key.

        Args:
            prefix: Key namespace (e.g. KEY_PREFIX or DM_QUESTION_PREFIX)
            user_id: The user ID
            day: The date (defaults to today UTC)

        Returns:
            Redis key in format "<prefix><user_id>:YYYY-MM-DD"
        """
        if day is None:
            day = datetime.now(timezone.utc).date()
        return f"{prefix}{user_id}:{day.isoformat()}"

    def _get_daily_key(self, user_id: str, day: Optional[date] = None) -> str:
        """
        Generate the Redis key for a user's daily turn counter.

        Args:
            user_id: The user ID
            day: The date (defaults to today UTC)

        Returns:
            Redis key in format "rate_limit:daily:user_id:YYYY-MM-DD"
        """
        return self._build_key(self.KEY_PREFIX, user_id, day)

    def _get_reset_time(self, now: Optional[datetime] = None) -> datetime:
        """
        Get the UTC datetime when the daily limits reset.

        Args:
            now: Reference instant; must be timezone-aware UTC.
                Defaults to the current time.

        Returns:
            Datetime of the next midnight UTC after ``now``
        """
        if now is None:
            now = datetime.now(timezone.utc)
        start_of_today = datetime(now.year, now.month, now.day, tzinfo=timezone.utc)
        return start_of_today + timedelta(days=1)

    def _get_seconds_until_midnight_utc(self, now: Optional[datetime] = None) -> int:
        """
        Calculate whole seconds remaining until the next midnight UTC.

        Args:
            now: Reference instant; must be timezone-aware UTC.
                Defaults to the current time.

        Returns:
            Seconds until the next UTC midnight, rounded up and never less
            than 1.
        """
        if now is None:
            now = datetime.now(timezone.utc)
        remaining = (self._get_reset_time(now) - now).total_seconds()
        # Round up and clamp: truncating could yield 0 just before midnight,
        # and Redis deletes a key immediately when EXPIRE gets a timeout <= 0.
        return max(1, math.ceil(remaining))

    def _read_counter(
        self,
        key: str,
        user_id: str,
        *,
        ok_message: str,
        invalid_message: str,
    ) -> int:
        """
        Read an integer counter from Redis.

        A missing or empty value counts as 0. A value that cannot be parsed as
        an int is logged and also treated as 0.

        Args:
            key: Redis key of the counter
            user_id: The user ID (for logging only)
            ok_message: Debug log message on a successful read
            invalid_message: Error log message when the stored value is invalid

        Returns:
            The counter value (0 if absent or invalid)

        Raises:
            RedisServiceError: If Redis operation fails
        """
        try:
            value = self.redis.get(key)
            usage = int(value) if value else 0
        except (ValueError, TypeError) as e:
            logger.error(
                invalid_message,
                user_id=user_id,
                error=str(e)
            )
            return 0

        logger.debug(
            ok_message,
            user_id=user_id,
            usage=usage
        )
        return usage

    def _increment_counter(
        self,
        key: str,
        user_id: str,
        now: datetime,
        *,
        expiry_message: str,
        increment_message: str,
    ) -> int:
        """
        Atomically increment a daily counter and set its TTL on creation.

        Args:
            key: Redis key of the counter (built from ``now``'s date)
            user_id: The user ID (for logging only)
            now: The UTC instant the key was derived from; the TTL is computed
                from the same instant so the key's date and expiry agree
            expiry_message: Debug log message when the TTL is set
            increment_message: Info log message after incrementing

        Returns:
            The new counter value after incrementing

        Raises:
            RedisServiceError: If Redis operation fails
        """
        new_count = self.redis.incr(key)

        # A count of 1 means this INCR created the key, so it has no TTL yet.
        # Expire it at the next midnight UTC.
        if new_count == 1:
            ttl = self._get_seconds_until_midnight_utc(now)
            self.redis.expire(key, ttl)
            logger.debug(
                expiry_message,
                user_id=user_id,
                ttl=ttl
            )

        logger.info(
            increment_message,
            user_id=user_id,
            new_count=new_count
        )
        return new_count

    # ===== AI TURN RATE LIMITING =====

    def get_limit_for_tier(self, user_tier: UserTier) -> int:
        """
        Get the daily turn limit for a specific tier.

        Args:
            user_tier: The user's subscription tier

        Returns:
            Daily turn limit for the tier (the FREE limit for unknown tiers)
        """
        return self.TIER_LIMITS.get(user_tier, self.TIER_LIMITS[UserTier.FREE])

    def get_current_usage(self, user_id: str) -> int:
        """
        Get the current daily turn usage count for a user.

        Args:
            user_id: The user ID to check

        Returns:
            Current usage count (0 if no usage today)

        Raises:
            RedisServiceError: If Redis operation fails
        """
        return self._read_counter(
            self._get_daily_key(user_id),
            user_id,
            ok_message="Retrieved current usage",
            invalid_message="Invalid usage value in Redis",
        )

    def check_rate_limit(self, user_id: str, user_tier: UserTier) -> None:
        """
        Check if a user has exceeded their daily turn limit.

        This only reads the counter; call increment_usage() to record usage.

        Args:
            user_id: The user ID to check
            user_tier: The user's subscription tier

        Raises:
            RateLimitExceeded: If the user has reached their daily limit
            RedisServiceError: If Redis operation fails
        """
        current_usage = self.get_current_usage(user_id)
        limit = self.get_limit_for_tier(user_tier)

        if current_usage >= limit:
            reset_time = self._get_reset_time()
            logger.warning(
                "Rate limit exceeded",
                user_id=user_id,
                user_tier=user_tier.value,
                current_usage=current_usage,
                limit=limit,
                reset_time=reset_time.isoformat()
            )
            raise RateLimitExceeded(
                user_id=user_id,
                user_tier=user_tier,
                limit=limit,
                current_usage=current_usage,
                reset_time=reset_time
            )

        logger.debug(
            "Rate limit check passed",
            user_id=user_id,
            user_tier=user_tier.value,
            current_usage=current_usage,
            limit=limit
        )

    def increment_usage(self, user_id: str) -> int:
        """
        Increment the daily turn counter for a user.

        This method should be called after successfully processing an AI
        request. A newly created counter is set to expire at midnight UTC.

        Args:
            user_id: The user ID to increment

        Returns:
            The new usage count after incrementing

        Raises:
            RedisServiceError: If Redis operation fails
        """
        # Read the clock once so the key's date and its TTL agree at midnight
        now = datetime.now(timezone.utc)
        key = self._get_daily_key(user_id, now.date())
        return self._increment_counter(
            key,
            user_id,
            now,
            expiry_message="Set expiration on new rate limit key",
            increment_message="Incremented usage counter",
        )

    def get_remaining_turns(self, user_id: str, user_tier: UserTier) -> int:
        """
        Get the number of remaining turns for a user today.

        Args:
            user_id: The user ID to check
            user_tier: The user's subscription tier

        Returns:
            Number of turns remaining (0 if limit reached)
        """
        current_usage = self.get_current_usage(user_id)
        limit = self.get_limit_for_tier(user_tier)
        remaining = max(0, limit - current_usage)

        logger.debug(
            "Calculated remaining turns",
            user_id=user_id,
            user_tier=user_tier.value,
            current_usage=current_usage,
            limit=limit,
            remaining=remaining
        )
        return remaining

    def get_usage_info(self, user_id: str, user_tier: UserTier) -> dict:
        """
        Get comprehensive turn usage information for a user.

        Args:
            user_id: The user ID to check
            user_tier: The user's subscription tier

        Returns:
            Dictionary with usage info:
                - user_id: User identifier
                - user_tier: Subscription tier
                - current_usage: Current daily usage
                - daily_limit: Daily limit for tier
                - remaining: Remaining turns
                - reset_time: ISO format UTC reset time
                - is_limited: Whether limit has been reached
        """
        current_usage = self.get_current_usage(user_id)
        limit = self.get_limit_for_tier(user_tier)
        remaining = max(0, limit - current_usage)
        reset_time = self._get_reset_time()

        info = {
            "user_id": user_id,
            "user_tier": user_tier.value,
            "current_usage": current_usage,
            "daily_limit": limit,
            "remaining": remaining,
            "reset_time": reset_time.isoformat(),
            "is_limited": current_usage >= limit
        }

        logger.debug("Retrieved usage info", **info)
        return info

    def reset_usage(self, user_id: str) -> bool:
        """
        Reset the daily turn counter for a user.

        This is primarily for admin/testing purposes.

        Args:
            user_id: The user ID to reset

        Returns:
            True if the counter was deleted, False if it didn't exist

        Raises:
            RedisServiceError: If Redis operation fails
        """
        key = self._get_daily_key(user_id)
        deleted = self.redis.delete(key)

        logger.info(
            "Reset usage counter",
            user_id=user_id,
            deleted=deleted > 0
        )
        return deleted > 0

    # ===== DM QUESTION RATE LIMITING =====

    def _get_dm_question_key(self, user_id: str, day: Optional[date] = None) -> str:
        """
        Generate the Redis key for a user's daily DM question counter.

        Args:
            user_id: The user ID
            day: The date (defaults to today UTC)

        Returns:
            Redis key in format "rate_limit:dm_questions:user_id:YYYY-MM-DD"
        """
        return self._build_key(self.DM_QUESTION_PREFIX, user_id, day)

    def get_dm_question_limit_for_tier(self, user_tier: UserTier) -> int:
        """
        Get the daily DM question limit for a specific tier.

        Args:
            user_tier: The user's subscription tier

        Returns:
            Daily DM question limit for the tier (-1 for unlimited; the FREE
            limit for unknown tiers)
        """
        return self.DM_QUESTION_LIMITS.get(user_tier, self.DM_QUESTION_LIMITS[UserTier.FREE])

    def get_current_dm_usage(self, user_id: str) -> int:
        """
        Get the current daily DM question usage count for a user.

        Args:
            user_id: The user ID to check

        Returns:
            Current DM question usage count (0 if no usage today)

        Raises:
            RedisServiceError: If Redis operation fails
        """
        return self._read_counter(
            self._get_dm_question_key(user_id),
            user_id,
            ok_message="Retrieved current DM question usage",
            invalid_message="Invalid DM question usage value in Redis",
        )

    def check_dm_question_limit(self, user_id: str, user_tier: UserTier) -> None:
        """
        Check if a user has exceeded their daily DM question limit.

        Tiers with an unlimited quota always pass without touching Redis.

        Args:
            user_id: The user ID to check
            user_tier: The user's subscription tier

        Raises:
            RateLimitExceeded: If the user has reached their daily DM question limit
            RedisServiceError: If Redis operation fails
        """
        limit = self.get_dm_question_limit_for_tier(user_tier)

        if limit == self.UNLIMITED:
            logger.debug(
                "DM question limit check passed (unlimited)",
                user_id=user_id,
                user_tier=user_tier.value
            )
            return

        current_usage = self.get_current_dm_usage(user_id)

        if current_usage >= limit:
            reset_time = self._get_reset_time()
            logger.warning(
                "DM question limit exceeded",
                user_id=user_id,
                user_tier=user_tier.value,
                current_usage=current_usage,
                limit=limit,
                reset_time=reset_time.isoformat()
            )
            raise RateLimitExceeded(
                user_id=user_id,
                user_tier=user_tier,
                limit=limit,
                current_usage=current_usage,
                reset_time=reset_time,
                resource="DM questions",
            )

        logger.debug(
            "DM question limit check passed",
            user_id=user_id,
            user_tier=user_tier.value,
            current_usage=current_usage,
            limit=limit
        )

    def increment_dm_usage(self, user_id: str) -> int:
        """
        Increment the daily DM question counter for a user.

        A newly created counter is set to expire at midnight UTC.

        Args:
            user_id: The user ID to increment

        Returns:
            The new DM question usage count after incrementing

        Raises:
            RedisServiceError: If Redis operation fails
        """
        # Read the clock once so the key's date and its TTL agree at midnight
        now = datetime.now(timezone.utc)
        key = self._get_dm_question_key(user_id, now.date())
        return self._increment_counter(
            key,
            user_id,
            now,
            expiry_message="Set expiration on new DM question key",
            increment_message="Incremented DM question counter",
        )

    def get_remaining_dm_questions(self, user_id: str, user_tier: UserTier) -> int:
        """
        Get the number of remaining DM questions for a user today.

        Args:
            user_id: The user ID to check
            user_tier: The user's subscription tier

        Returns:
            Number of DM questions remaining (-1 if unlimited, 0 if limit reached)
        """
        limit = self.get_dm_question_limit_for_tier(user_tier)

        if limit == self.UNLIMITED:
            return self.UNLIMITED

        current_usage = self.get_current_dm_usage(user_id)
        remaining = max(0, limit - current_usage)

        logger.debug(
            "Calculated remaining DM questions",
            user_id=user_id,
            user_tier=user_tier.value,
            current_usage=current_usage,
            limit=limit,
            remaining=remaining
        )
        return remaining

    def reset_dm_usage(self, user_id: str) -> bool:
        """
        Reset the daily DM question counter for a user.

        This is primarily for admin/testing purposes.

        Args:
            user_id: The user ID to reset

        Returns:
            True if the counter was deleted, False if it didn't exist

        Raises:
            RedisServiceError: If Redis operation fails
        """
        key = self._get_dm_question_key(user_id)
        deleted = self.redis.delete(key)

        logger.info(
            "Reset DM question counter",
            user_id=user_id,
            deleted=deleted > 0
        )
        return deleted > 0