first commit

This commit is contained in:
2025-11-24 23:10:55 -06:00
commit 8315fa51c9
279 changed files with 74600 additions and 0 deletions

View File

@@ -0,0 +1,342 @@
"""
Tests for RateLimiterService
Tests the tier-based rate limiting functionality including:
- Daily limits per tier
- Usage tracking and incrementing
- Rate limit checks and exceptions
- Reset functionality
"""
import pytest
from unittest.mock import MagicMock, patch
from datetime import datetime, timezone, timedelta
from app.services.rate_limiter_service import (
RateLimiterService,
RateLimitExceeded,
)
from app.ai.model_selector import UserTier
class TestRateLimiterService:
"""Tests for RateLimiterService."""
@pytest.fixture
def mock_redis(self):
"""Create a mock Redis service."""
mock = MagicMock()
mock.get.return_value = None
mock.incr.return_value = 1
mock.expire.return_value = True
mock.delete.return_value = 1
return mock
@pytest.fixture
def rate_limiter(self, mock_redis):
"""Create a RateLimiterService with mock Redis."""
return RateLimiterService(redis_service=mock_redis)
def test_tier_limits(self, rate_limiter):
"""Test that tier limits are correctly defined."""
assert rate_limiter.get_limit_for_tier(UserTier.FREE) == 20
assert rate_limiter.get_limit_for_tier(UserTier.BASIC) == 50
assert rate_limiter.get_limit_for_tier(UserTier.PREMIUM) == 100
assert rate_limiter.get_limit_for_tier(UserTier.ELITE) == 200
def test_get_daily_key(self, rate_limiter):
"""Test Redis key generation."""
from datetime import date
key = rate_limiter._get_daily_key("user_123", date(2025, 1, 15))
assert key == "rate_limit:daily:user_123:2025-01-15"
def test_get_current_usage_no_usage(self, rate_limiter, mock_redis):
"""Test getting current usage when no usage exists."""
mock_redis.get.return_value = None
usage = rate_limiter.get_current_usage("user_123")
assert usage == 0
mock_redis.get.assert_called_once()
def test_get_current_usage_with_usage(self, rate_limiter, mock_redis):
"""Test getting current usage when usage exists."""
mock_redis.get.return_value = "15"
usage = rate_limiter.get_current_usage("user_123")
assert usage == 15
def test_check_rate_limit_under_limit(self, rate_limiter, mock_redis):
"""Test that check passes when under limit."""
mock_redis.get.return_value = "10"
# Should not raise
rate_limiter.check_rate_limit("user_123", UserTier.FREE)
def test_check_rate_limit_at_limit(self, rate_limiter, mock_redis):
"""Test that check fails when at limit."""
mock_redis.get.return_value = "20" # Free tier limit
with pytest.raises(RateLimitExceeded) as exc_info:
rate_limiter.check_rate_limit("user_123", UserTier.FREE)
exc = exc_info.value
assert exc.user_id == "user_123"
assert exc.user_tier == UserTier.FREE
assert exc.limit == 20
assert exc.current_usage == 20
def test_check_rate_limit_over_limit(self, rate_limiter, mock_redis):
"""Test that check fails when over limit."""
mock_redis.get.return_value = "25" # Over free tier limit
with pytest.raises(RateLimitExceeded):
rate_limiter.check_rate_limit("user_123", UserTier.FREE)
def test_check_rate_limit_premium_tier(self, rate_limiter, mock_redis):
"""Test that premium tier has higher limit."""
mock_redis.get.return_value = "50" # Over free limit, under premium
# Should not raise for premium
rate_limiter.check_rate_limit("user_123", UserTier.PREMIUM)
# Should raise for free
with pytest.raises(RateLimitExceeded):
rate_limiter.check_rate_limit("user_123", UserTier.FREE)
def test_increment_usage_first_time(self, rate_limiter, mock_redis):
"""Test incrementing usage for the first time (sets expiration)."""
mock_redis.incr.return_value = 1
new_count = rate_limiter.increment_usage("user_123")
assert new_count == 1
mock_redis.incr.assert_called_once()
mock_redis.expire.assert_called_once() # Should set expiration
def test_increment_usage_subsequent(self, rate_limiter, mock_redis):
"""Test incrementing usage after first time (no expiration set)."""
mock_redis.incr.return_value = 5
new_count = rate_limiter.increment_usage("user_123")
assert new_count == 5
mock_redis.incr.assert_called_once()
mock_redis.expire.assert_not_called() # Should NOT set expiration
def test_get_remaining_turns_full(self, rate_limiter, mock_redis):
"""Test remaining turns when no usage."""
mock_redis.get.return_value = None
remaining = rate_limiter.get_remaining_turns("user_123", UserTier.FREE)
assert remaining == 20
def test_get_remaining_turns_partial(self, rate_limiter, mock_redis):
"""Test remaining turns with partial usage."""
mock_redis.get.return_value = "12"
remaining = rate_limiter.get_remaining_turns("user_123", UserTier.FREE)
assert remaining == 8
def test_get_remaining_turns_exhausted(self, rate_limiter, mock_redis):
"""Test remaining turns when limit reached."""
mock_redis.get.return_value = "20"
remaining = rate_limiter.get_remaining_turns("user_123", UserTier.FREE)
assert remaining == 0
def test_get_remaining_turns_over_limit(self, rate_limiter, mock_redis):
"""Test remaining turns when over limit (should be 0, not negative)."""
mock_redis.get.return_value = "25"
remaining = rate_limiter.get_remaining_turns("user_123", UserTier.FREE)
assert remaining == 0
def test_get_usage_info(self, rate_limiter, mock_redis):
"""Test getting comprehensive usage info."""
mock_redis.get.return_value = "15"
info = rate_limiter.get_usage_info("user_123", UserTier.FREE)
assert info["user_id"] == "user_123"
assert info["user_tier"] == "free"
assert info["current_usage"] == 15
assert info["daily_limit"] == 20
assert info["remaining"] == 5
assert info["is_limited"] is False
assert "reset_time" in info
def test_get_usage_info_limited(self, rate_limiter, mock_redis):
"""Test usage info when limited."""
mock_redis.get.return_value = "20"
info = rate_limiter.get_usage_info("user_123", UserTier.FREE)
assert info["is_limited"] is True
assert info["remaining"] == 0
def test_reset_usage(self, rate_limiter, mock_redis):
"""Test resetting usage counter."""
mock_redis.delete.return_value = 1
result = rate_limiter.reset_usage("user_123")
assert result is True
mock_redis.delete.assert_called_once()
def test_reset_usage_no_key(self, rate_limiter, mock_redis):
"""Test resetting when no usage exists."""
mock_redis.delete.return_value = 0
result = rate_limiter.reset_usage("user_123")
assert result is False
class TestRateLimitExceeded:
"""Tests for RateLimitExceeded exception."""
def test_exception_attributes(self):
"""Test that exception has correct attributes."""
reset_time = datetime(2025, 1, 16, 0, 0, 0, tzinfo=timezone.utc)
exc = RateLimitExceeded(
user_id="user_123",
user_tier=UserTier.FREE,
limit=20,
current_usage=20,
reset_time=reset_time
)
assert exc.user_id == "user_123"
assert exc.user_tier == UserTier.FREE
assert exc.limit == 20
assert exc.current_usage == 20
assert exc.reset_time == reset_time
def test_exception_message(self):
"""Test that exception message is formatted correctly."""
reset_time = datetime(2025, 1, 16, 0, 0, 0, tzinfo=timezone.utc)
exc = RateLimitExceeded(
user_id="user_123",
user_tier=UserTier.FREE,
limit=20,
current_usage=20,
reset_time=reset_time
)
message = str(exc)
assert "user_123" in message
assert "free tier" in message
assert "20/20" in message
class TestRateLimiterIntegration:
"""Integration tests for rate limiter workflow."""
@pytest.fixture
def mock_redis(self):
"""Create a mock Redis that simulates real behavior."""
storage = {}
mock = MagicMock()
def mock_get(key):
return storage.get(key)
def mock_incr(key):
if key not in storage:
storage[key] = 0
storage[key] = int(storage[key]) + 1
return storage[key]
def mock_delete(key):
if key in storage:
del storage[key]
return 1
return 0
mock.get.side_effect = mock_get
mock.incr.side_effect = mock_incr
mock.delete.side_effect = mock_delete
mock.expire.return_value = True
return mock
@pytest.fixture
def rate_limiter(self, mock_redis):
"""Create rate limiter with simulated Redis."""
return RateLimiterService(redis_service=mock_redis)
def test_full_workflow(self, rate_limiter):
"""Test complete rate limiting workflow."""
user_id = "user_123"
tier = UserTier.FREE # 20 turns/day
# Initial state - should pass
rate_limiter.check_rate_limit(user_id, tier)
assert rate_limiter.get_remaining_turns(user_id, tier) == 20
# Use some turns
for i in range(15):
rate_limiter.check_rate_limit(user_id, tier)
rate_limiter.increment_usage(user_id)
# Check remaining
assert rate_limiter.get_remaining_turns(user_id, tier) == 5
# Use remaining turns
for i in range(5):
rate_limiter.check_rate_limit(user_id, tier)
rate_limiter.increment_usage(user_id)
# Now should be limited
assert rate_limiter.get_remaining_turns(user_id, tier) == 0
with pytest.raises(RateLimitExceeded):
rate_limiter.check_rate_limit(user_id, tier)
def test_different_tiers_same_usage(self, rate_limiter):
"""Test that same usage affects different tiers differently."""
user_id = "user_123"
# Use 30 turns
for _ in range(30):
rate_limiter.increment_usage(user_id)
# Free tier (20) should be limited
with pytest.raises(RateLimitExceeded):
rate_limiter.check_rate_limit(user_id, UserTier.FREE)
# Basic tier (50) should not be limited
rate_limiter.check_rate_limit(user_id, UserTier.BASIC)
# Premium tier (100) should not be limited
rate_limiter.check_rate_limit(user_id, UserTier.PREMIUM)
def test_reset_clears_usage(self, rate_limiter):
"""Test that reset allows new usage."""
user_id = "user_123"
tier = UserTier.FREE
# Use all turns
for _ in range(20):
rate_limiter.increment_usage(user_id)
# Should be limited
with pytest.raises(RateLimitExceeded):
rate_limiter.check_rate_limit(user_id, tier)
# Reset usage
rate_limiter.reset_usage(user_id)
# Should be able to use again
rate_limiter.check_rate_limit(user_id, tier)
assert rate_limiter.get_remaining_turns(user_id, tier) == 20