first commit

This commit is contained in:
2025-11-24 23:10:55 -06:00
commit 8315fa51c9
279 changed files with 74600 additions and 0 deletions

View File

@@ -0,0 +1,602 @@
"""
Rate Limiter Service
This module implements tier-based rate limiting for AI requests using Redis
for distributed counting. Each user tier has a different daily limit for
AI-generated turns.
Usage:
from app.services.rate_limiter_service import RateLimiterService, RateLimitExceeded
from app.ai.model_selector import UserTier
# Initialize service
rate_limiter = RateLimiterService()
# Check and increment usage
try:
rate_limiter.check_rate_limit("user_123", UserTier.FREE)
rate_limiter.increment_usage("user_123")
except RateLimitExceeded as e:
print(f"Rate limit exceeded: {e}")
# Get remaining turns
remaining = rate_limiter.get_remaining_turns("user_123", UserTier.FREE)
"""
from datetime import date, datetime, timezone, timedelta
from typing import Optional
from app.services.redis_service import RedisService, RedisServiceError
from app.ai.model_selector import UserTier
from app.utils.logging import get_logger
# Initialize logger
logger = get_logger(__file__)
class RateLimitExceeded(Exception):
"""
Raised when a user has exceeded their daily rate limit.
Attributes:
user_id: The user who exceeded the limit
user_tier: The user's subscription tier
limit: The daily limit for their tier
current_usage: The current usage count
reset_time: UTC timestamp when the limit resets
"""
def __init__(
self,
user_id: str,
user_tier: UserTier,
limit: int,
current_usage: int,
reset_time: datetime
):
self.user_id = user_id
self.user_tier = user_tier
self.limit = limit
self.current_usage = current_usage
self.reset_time = reset_time
message = (
f"Rate limit exceeded for user {user_id} ({user_tier.value} tier). "
f"Used {current_usage}/{limit} turns. Resets at {reset_time.isoformat()}"
)
super().__init__(message)
class RateLimiterService:
"""
Service for managing tier-based rate limiting.
This service uses Redis to track daily AI usage per user and enforces
limits based on subscription tier. Counters reset daily at midnight UTC.
Tier Limits:
- Free: 20 turns/day
- Basic: 50 turns/day
- Premium: 100 turns/day
- Elite: 200 turns/day
Attributes:
redis: RedisService instance for counter storage
tier_limits: Mapping of tier to daily turn limit
"""
# Daily turn limits per tier
TIER_LIMITS = {
UserTier.FREE: 20,
UserTier.BASIC: 50,
UserTier.PREMIUM: 100,
UserTier.ELITE: 200,
}
# Daily DM question limits per tier
DM_QUESTION_LIMITS = {
UserTier.FREE: 10,
UserTier.BASIC: 20,
UserTier.PREMIUM: 50,
UserTier.ELITE: -1, # -1 means unlimited
}
# Redis key prefix for rate limit counters
KEY_PREFIX = "rate_limit:daily:"
DM_QUESTION_PREFIX = "rate_limit:dm_questions:"
def __init__(self, redis_service: Optional[RedisService] = None):
"""
Initialize the rate limiter service.
Args:
redis_service: Optional RedisService instance. If not provided,
a new instance will be created.
"""
self.redis = redis_service or RedisService()
logger.info(
"RateLimiterService initialized",
tier_limits=self.TIER_LIMITS
)
def _get_daily_key(self, user_id: str, day: Optional[date] = None) -> str:
"""
Generate the Redis key for a user's daily counter.
Args:
user_id: The user ID
day: The date (defaults to today UTC)
Returns:
Redis key in format "rate_limit:daily:user_id:YYYY-MM-DD"
"""
if day is None:
day = datetime.now(timezone.utc).date()
return f"{self.KEY_PREFIX}{user_id}:{day.isoformat()}"
def _get_seconds_until_midnight_utc(self) -> int:
"""
Calculate seconds remaining until midnight UTC.
Returns:
Number of seconds until the next UTC midnight
"""
now = datetime.now(timezone.utc)
tomorrow = datetime(
now.year, now.month, now.day,
tzinfo=timezone.utc
) + timedelta(days=1)
return int((tomorrow - now).total_seconds())
def _get_reset_time(self) -> datetime:
"""
Get the UTC datetime when the rate limit resets.
Returns:
Datetime of next midnight UTC
"""
now = datetime.now(timezone.utc)
return datetime(
now.year, now.month, now.day,
tzinfo=timezone.utc
) + timedelta(days=1)
def get_limit_for_tier(self, user_tier: UserTier) -> int:
"""
Get the daily turn limit for a specific tier.
Args:
user_tier: The user's subscription tier
Returns:
Daily turn limit for the tier
"""
return self.TIER_LIMITS.get(user_tier, self.TIER_LIMITS[UserTier.FREE])
def get_current_usage(self, user_id: str) -> int:
"""
Get the current daily usage count for a user.
Args:
user_id: The user ID to check
Returns:
Current usage count (0 if no usage today)
Raises:
RedisServiceError: If Redis operation fails
"""
key = self._get_daily_key(user_id)
try:
value = self.redis.get(key)
usage = int(value) if value else 0
logger.debug(
"Retrieved current usage",
user_id=user_id,
usage=usage
)
return usage
except (ValueError, TypeError) as e:
logger.error(
"Invalid usage value in Redis",
user_id=user_id,
error=str(e)
)
return 0
def check_rate_limit(self, user_id: str, user_tier: UserTier) -> None:
"""
Check if a user has exceeded their daily rate limit.
This method checks the current usage against the tier limit and
raises an exception if the limit has been reached.
Args:
user_id: The user ID to check
user_tier: The user's subscription tier
Raises:
RateLimitExceeded: If the user has reached their daily limit
RedisServiceError: If Redis operation fails
"""
current_usage = self.get_current_usage(user_id)
limit = self.get_limit_for_tier(user_tier)
if current_usage >= limit:
reset_time = self._get_reset_time()
logger.warning(
"Rate limit exceeded",
user_id=user_id,
user_tier=user_tier.value,
current_usage=current_usage,
limit=limit,
reset_time=reset_time.isoformat()
)
raise RateLimitExceeded(
user_id=user_id,
user_tier=user_tier,
limit=limit,
current_usage=current_usage,
reset_time=reset_time
)
logger.debug(
"Rate limit check passed",
user_id=user_id,
user_tier=user_tier.value,
current_usage=current_usage,
limit=limit
)
def increment_usage(self, user_id: str) -> int:
"""
Increment the daily usage counter for a user.
This method should be called after successfully processing an AI request.
The counter will automatically expire at midnight UTC.
Args:
user_id: The user ID to increment
Returns:
The new usage count after incrementing
Raises:
RedisServiceError: If Redis operation fails
"""
key = self._get_daily_key(user_id)
# Increment the counter
new_count = self.redis.incr(key)
# Set expiration if this is the first increment (new_count == 1)
# This ensures the key expires at midnight UTC
if new_count == 1:
ttl = self._get_seconds_until_midnight_utc()
self.redis.expire(key, ttl)
logger.debug(
"Set expiration on new rate limit key",
user_id=user_id,
ttl=ttl
)
logger.info(
"Incremented usage counter",
user_id=user_id,
new_count=new_count
)
return new_count
def get_remaining_turns(self, user_id: str, user_tier: UserTier) -> int:
"""
Get the number of remaining turns for a user today.
Args:
user_id: The user ID to check
user_tier: The user's subscription tier
Returns:
Number of turns remaining (0 if limit reached)
"""
current_usage = self.get_current_usage(user_id)
limit = self.get_limit_for_tier(user_tier)
remaining = max(0, limit - current_usage)
logger.debug(
"Calculated remaining turns",
user_id=user_id,
user_tier=user_tier.value,
current_usage=current_usage,
limit=limit,
remaining=remaining
)
return remaining
def get_usage_info(self, user_id: str, user_tier: UserTier) -> dict:
"""
Get comprehensive usage information for a user.
Args:
user_id: The user ID to check
user_tier: The user's subscription tier
Returns:
Dictionary with usage info:
- user_id: User identifier
- user_tier: Subscription tier
- current_usage: Current daily usage
- daily_limit: Daily limit for tier
- remaining: Remaining turns
- reset_time: ISO format UTC reset time
- is_limited: Whether limit has been reached
"""
current_usage = self.get_current_usage(user_id)
limit = self.get_limit_for_tier(user_tier)
remaining = max(0, limit - current_usage)
reset_time = self._get_reset_time()
info = {
"user_id": user_id,
"user_tier": user_tier.value,
"current_usage": current_usage,
"daily_limit": limit,
"remaining": remaining,
"reset_time": reset_time.isoformat(),
"is_limited": current_usage >= limit
}
logger.debug("Retrieved usage info", **info)
return info
def reset_usage(self, user_id: str) -> bool:
"""
Reset the daily usage counter for a user.
This is primarily for admin/testing purposes.
Args:
user_id: The user ID to reset
Returns:
True if the counter was deleted, False if it didn't exist
Raises:
RedisServiceError: If Redis operation fails
"""
key = self._get_daily_key(user_id)
deleted = self.redis.delete(key)
logger.info(
"Reset usage counter",
user_id=user_id,
deleted=deleted > 0
)
return deleted > 0
# ===== DM QUESTION RATE LIMITING =====
def _get_dm_question_key(self, user_id: str, day: Optional[date] = None) -> str:
"""
Generate the Redis key for a user's daily DM question counter.
Args:
user_id: The user ID
day: The date (defaults to today UTC)
Returns:
Redis key in format "rate_limit:dm_questions:user_id:YYYY-MM-DD"
"""
if day is None:
day = datetime.now(timezone.utc).date()
return f"{self.DM_QUESTION_PREFIX}{user_id}:{day.isoformat()}"
def get_dm_question_limit_for_tier(self, user_tier: UserTier) -> int:
"""
Get the daily DM question limit for a specific tier.
Args:
user_tier: The user's subscription tier
Returns:
Daily DM question limit for the tier (-1 for unlimited)
"""
return self.DM_QUESTION_LIMITS.get(user_tier, self.DM_QUESTION_LIMITS[UserTier.FREE])
def get_current_dm_usage(self, user_id: str) -> int:
"""
Get the current daily DM question usage count for a user.
Args:
user_id: The user ID to check
Returns:
Current DM question usage count (0 if no usage today)
Raises:
RedisServiceError: If Redis operation fails
"""
key = self._get_dm_question_key(user_id)
try:
value = self.redis.get(key)
usage = int(value) if value else 0
logger.debug(
"Retrieved current DM question usage",
user_id=user_id,
usage=usage
)
return usage
except (ValueError, TypeError) as e:
logger.error(
"Invalid DM question usage value in Redis",
user_id=user_id,
error=str(e)
)
return 0
def check_dm_question_limit(self, user_id: str, user_tier: UserTier) -> None:
"""
Check if a user has exceeded their daily DM question limit.
Args:
user_id: The user ID to check
user_tier: The user's subscription tier
Raises:
RateLimitExceeded: If the user has reached their daily DM question limit
RedisServiceError: If Redis operation fails
"""
limit = self.get_dm_question_limit_for_tier(user_tier)
# -1 means unlimited
if limit == -1:
logger.debug(
"DM question limit check passed (unlimited)",
user_id=user_id,
user_tier=user_tier.value
)
return
current_usage = self.get_current_dm_usage(user_id)
if current_usage >= limit:
reset_time = self._get_reset_time()
logger.warning(
"DM question limit exceeded",
user_id=user_id,
user_tier=user_tier.value,
current_usage=current_usage,
limit=limit,
reset_time=reset_time.isoformat()
)
raise RateLimitExceeded(
user_id=user_id,
user_tier=user_tier,
limit=limit,
current_usage=current_usage,
reset_time=reset_time
)
logger.debug(
"DM question limit check passed",
user_id=user_id,
user_tier=user_tier.value,
current_usage=current_usage,
limit=limit
)
def increment_dm_usage(self, user_id: str) -> int:
"""
Increment the daily DM question counter for a user.
Args:
user_id: The user ID to increment
Returns:
The new DM question usage count after incrementing
Raises:
RedisServiceError: If Redis operation fails
"""
key = self._get_dm_question_key(user_id)
# Increment the counter
new_count = self.redis.incr(key)
# Set expiration if this is the first increment
if new_count == 1:
ttl = self._get_seconds_until_midnight_utc()
self.redis.expire(key, ttl)
logger.debug(
"Set expiration on new DM question key",
user_id=user_id,
ttl=ttl
)
logger.info(
"Incremented DM question counter",
user_id=user_id,
new_count=new_count
)
return new_count
def get_remaining_dm_questions(self, user_id: str, user_tier: UserTier) -> int:
"""
Get the number of remaining DM questions for a user today.
Args:
user_id: The user ID to check
user_tier: The user's subscription tier
Returns:
Number of DM questions remaining (-1 if unlimited, 0 if limit reached)
"""
limit = self.get_dm_question_limit_for_tier(user_tier)
# -1 means unlimited
if limit == -1:
return -1
current_usage = self.get_current_dm_usage(user_id)
remaining = max(0, limit - current_usage)
logger.debug(
"Calculated remaining DM questions",
user_id=user_id,
user_tier=user_tier.value,
current_usage=current_usage,
limit=limit,
remaining=remaining
)
return remaining
def reset_dm_usage(self, user_id: str) -> bool:
"""
Reset the daily DM question counter for a user.
This is primarily for admin/testing purposes.
Args:
user_id: The user ID to reset
Returns:
True if the counter was deleted, False if it didn't exist
Raises:
RedisServiceError: If Redis operation fails
"""
key = self._get_dm_question_key(user_id)
deleted = self.redis.delete(key)
logger.info(
"Reset DM question counter",
user_id=user_id,
deleted=deleted > 0
)
return deleted > 0