first commit
This commit is contained in:
342
api/tests/test_rate_limiter_service.py
Normal file
342
api/tests/test_rate_limiter_service.py
Normal file
@@ -0,0 +1,342 @@
|
||||
"""
|
||||
Tests for RateLimiterService
|
||||
|
||||
Tests the tier-based rate limiting functionality including:
|
||||
- Daily limits per tier
|
||||
- Usage tracking and incrementing
|
||||
- Rate limit checks and exceptions
|
||||
- Reset functionality
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
from datetime import datetime, timezone, timedelta
|
||||
|
||||
from app.services.rate_limiter_service import (
|
||||
RateLimiterService,
|
||||
RateLimitExceeded,
|
||||
)
|
||||
from app.ai.model_selector import UserTier
|
||||
|
||||
|
||||
class TestRateLimiterService:
|
||||
"""Tests for RateLimiterService."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_redis(self):
|
||||
"""Create a mock Redis service."""
|
||||
mock = MagicMock()
|
||||
mock.get.return_value = None
|
||||
mock.incr.return_value = 1
|
||||
mock.expire.return_value = True
|
||||
mock.delete.return_value = 1
|
||||
return mock
|
||||
|
||||
@pytest.fixture
|
||||
def rate_limiter(self, mock_redis):
|
||||
"""Create a RateLimiterService with mock Redis."""
|
||||
return RateLimiterService(redis_service=mock_redis)
|
||||
|
||||
def test_tier_limits(self, rate_limiter):
|
||||
"""Test that tier limits are correctly defined."""
|
||||
assert rate_limiter.get_limit_for_tier(UserTier.FREE) == 20
|
||||
assert rate_limiter.get_limit_for_tier(UserTier.BASIC) == 50
|
||||
assert rate_limiter.get_limit_for_tier(UserTier.PREMIUM) == 100
|
||||
assert rate_limiter.get_limit_for_tier(UserTier.ELITE) == 200
|
||||
|
||||
def test_get_daily_key(self, rate_limiter):
|
||||
"""Test Redis key generation."""
|
||||
from datetime import date
|
||||
|
||||
key = rate_limiter._get_daily_key("user_123", date(2025, 1, 15))
|
||||
assert key == "rate_limit:daily:user_123:2025-01-15"
|
||||
|
||||
def test_get_current_usage_no_usage(self, rate_limiter, mock_redis):
|
||||
"""Test getting current usage when no usage exists."""
|
||||
mock_redis.get.return_value = None
|
||||
|
||||
usage = rate_limiter.get_current_usage("user_123")
|
||||
|
||||
assert usage == 0
|
||||
mock_redis.get.assert_called_once()
|
||||
|
||||
def test_get_current_usage_with_usage(self, rate_limiter, mock_redis):
|
||||
"""Test getting current usage when usage exists."""
|
||||
mock_redis.get.return_value = "15"
|
||||
|
||||
usage = rate_limiter.get_current_usage("user_123")
|
||||
|
||||
assert usage == 15
|
||||
|
||||
def test_check_rate_limit_under_limit(self, rate_limiter, mock_redis):
|
||||
"""Test that check passes when under limit."""
|
||||
mock_redis.get.return_value = "10"
|
||||
|
||||
# Should not raise
|
||||
rate_limiter.check_rate_limit("user_123", UserTier.FREE)
|
||||
|
||||
def test_check_rate_limit_at_limit(self, rate_limiter, mock_redis):
|
||||
"""Test that check fails when at limit."""
|
||||
mock_redis.get.return_value = "20" # Free tier limit
|
||||
|
||||
with pytest.raises(RateLimitExceeded) as exc_info:
|
||||
rate_limiter.check_rate_limit("user_123", UserTier.FREE)
|
||||
|
||||
exc = exc_info.value
|
||||
assert exc.user_id == "user_123"
|
||||
assert exc.user_tier == UserTier.FREE
|
||||
assert exc.limit == 20
|
||||
assert exc.current_usage == 20
|
||||
|
||||
def test_check_rate_limit_over_limit(self, rate_limiter, mock_redis):
|
||||
"""Test that check fails when over limit."""
|
||||
mock_redis.get.return_value = "25" # Over free tier limit
|
||||
|
||||
with pytest.raises(RateLimitExceeded):
|
||||
rate_limiter.check_rate_limit("user_123", UserTier.FREE)
|
||||
|
||||
def test_check_rate_limit_premium_tier(self, rate_limiter, mock_redis):
|
||||
"""Test that premium tier has higher limit."""
|
||||
mock_redis.get.return_value = "50" # Over free limit, under premium
|
||||
|
||||
# Should not raise for premium
|
||||
rate_limiter.check_rate_limit("user_123", UserTier.PREMIUM)
|
||||
|
||||
# Should raise for free
|
||||
with pytest.raises(RateLimitExceeded):
|
||||
rate_limiter.check_rate_limit("user_123", UserTier.FREE)
|
||||
|
||||
def test_increment_usage_first_time(self, rate_limiter, mock_redis):
|
||||
"""Test incrementing usage for the first time (sets expiration)."""
|
||||
mock_redis.incr.return_value = 1
|
||||
|
||||
new_count = rate_limiter.increment_usage("user_123")
|
||||
|
||||
assert new_count == 1
|
||||
mock_redis.incr.assert_called_once()
|
||||
mock_redis.expire.assert_called_once() # Should set expiration
|
||||
|
||||
def test_increment_usage_subsequent(self, rate_limiter, mock_redis):
|
||||
"""Test incrementing usage after first time (no expiration set)."""
|
||||
mock_redis.incr.return_value = 5
|
||||
|
||||
new_count = rate_limiter.increment_usage("user_123")
|
||||
|
||||
assert new_count == 5
|
||||
mock_redis.incr.assert_called_once()
|
||||
mock_redis.expire.assert_not_called() # Should NOT set expiration
|
||||
|
||||
def test_get_remaining_turns_full(self, rate_limiter, mock_redis):
|
||||
"""Test remaining turns when no usage."""
|
||||
mock_redis.get.return_value = None
|
||||
|
||||
remaining = rate_limiter.get_remaining_turns("user_123", UserTier.FREE)
|
||||
|
||||
assert remaining == 20
|
||||
|
||||
def test_get_remaining_turns_partial(self, rate_limiter, mock_redis):
|
||||
"""Test remaining turns with partial usage."""
|
||||
mock_redis.get.return_value = "12"
|
||||
|
||||
remaining = rate_limiter.get_remaining_turns("user_123", UserTier.FREE)
|
||||
|
||||
assert remaining == 8
|
||||
|
||||
def test_get_remaining_turns_exhausted(self, rate_limiter, mock_redis):
|
||||
"""Test remaining turns when limit reached."""
|
||||
mock_redis.get.return_value = "20"
|
||||
|
||||
remaining = rate_limiter.get_remaining_turns("user_123", UserTier.FREE)
|
||||
|
||||
assert remaining == 0
|
||||
|
||||
def test_get_remaining_turns_over_limit(self, rate_limiter, mock_redis):
|
||||
"""Test remaining turns when over limit (should be 0, not negative)."""
|
||||
mock_redis.get.return_value = "25"
|
||||
|
||||
remaining = rate_limiter.get_remaining_turns("user_123", UserTier.FREE)
|
||||
|
||||
assert remaining == 0
|
||||
|
||||
def test_get_usage_info(self, rate_limiter, mock_redis):
|
||||
"""Test getting comprehensive usage info."""
|
||||
mock_redis.get.return_value = "15"
|
||||
|
||||
info = rate_limiter.get_usage_info("user_123", UserTier.FREE)
|
||||
|
||||
assert info["user_id"] == "user_123"
|
||||
assert info["user_tier"] == "free"
|
||||
assert info["current_usage"] == 15
|
||||
assert info["daily_limit"] == 20
|
||||
assert info["remaining"] == 5
|
||||
assert info["is_limited"] is False
|
||||
assert "reset_time" in info
|
||||
|
||||
def test_get_usage_info_limited(self, rate_limiter, mock_redis):
|
||||
"""Test usage info when limited."""
|
||||
mock_redis.get.return_value = "20"
|
||||
|
||||
info = rate_limiter.get_usage_info("user_123", UserTier.FREE)
|
||||
|
||||
assert info["is_limited"] is True
|
||||
assert info["remaining"] == 0
|
||||
|
||||
def test_reset_usage(self, rate_limiter, mock_redis):
|
||||
"""Test resetting usage counter."""
|
||||
mock_redis.delete.return_value = 1
|
||||
|
||||
result = rate_limiter.reset_usage("user_123")
|
||||
|
||||
assert result is True
|
||||
mock_redis.delete.assert_called_once()
|
||||
|
||||
def test_reset_usage_no_key(self, rate_limiter, mock_redis):
|
||||
"""Test resetting when no usage exists."""
|
||||
mock_redis.delete.return_value = 0
|
||||
|
||||
result = rate_limiter.reset_usage("user_123")
|
||||
|
||||
assert result is False
|
||||
|
||||
|
||||
class TestRateLimitExceeded:
|
||||
"""Tests for RateLimitExceeded exception."""
|
||||
|
||||
def test_exception_attributes(self):
|
||||
"""Test that exception has correct attributes."""
|
||||
reset_time = datetime(2025, 1, 16, 0, 0, 0, tzinfo=timezone.utc)
|
||||
|
||||
exc = RateLimitExceeded(
|
||||
user_id="user_123",
|
||||
user_tier=UserTier.FREE,
|
||||
limit=20,
|
||||
current_usage=20,
|
||||
reset_time=reset_time
|
||||
)
|
||||
|
||||
assert exc.user_id == "user_123"
|
||||
assert exc.user_tier == UserTier.FREE
|
||||
assert exc.limit == 20
|
||||
assert exc.current_usage == 20
|
||||
assert exc.reset_time == reset_time
|
||||
|
||||
def test_exception_message(self):
|
||||
"""Test that exception message is formatted correctly."""
|
||||
reset_time = datetime(2025, 1, 16, 0, 0, 0, tzinfo=timezone.utc)
|
||||
|
||||
exc = RateLimitExceeded(
|
||||
user_id="user_123",
|
||||
user_tier=UserTier.FREE,
|
||||
limit=20,
|
||||
current_usage=20,
|
||||
reset_time=reset_time
|
||||
)
|
||||
|
||||
message = str(exc)
|
||||
assert "user_123" in message
|
||||
assert "free tier" in message
|
||||
assert "20/20" in message
|
||||
|
||||
|
||||
class TestRateLimiterIntegration:
|
||||
"""Integration tests for rate limiter workflow."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_redis(self):
|
||||
"""Create a mock Redis that simulates real behavior."""
|
||||
storage = {}
|
||||
|
||||
mock = MagicMock()
|
||||
|
||||
def mock_get(key):
|
||||
return storage.get(key)
|
||||
|
||||
def mock_incr(key):
|
||||
if key not in storage:
|
||||
storage[key] = 0
|
||||
storage[key] = int(storage[key]) + 1
|
||||
return storage[key]
|
||||
|
||||
def mock_delete(key):
|
||||
if key in storage:
|
||||
del storage[key]
|
||||
return 1
|
||||
return 0
|
||||
|
||||
mock.get.side_effect = mock_get
|
||||
mock.incr.side_effect = mock_incr
|
||||
mock.delete.side_effect = mock_delete
|
||||
mock.expire.return_value = True
|
||||
|
||||
return mock
|
||||
|
||||
@pytest.fixture
|
||||
def rate_limiter(self, mock_redis):
|
||||
"""Create rate limiter with simulated Redis."""
|
||||
return RateLimiterService(redis_service=mock_redis)
|
||||
|
||||
def test_full_workflow(self, rate_limiter):
|
||||
"""Test complete rate limiting workflow."""
|
||||
user_id = "user_123"
|
||||
tier = UserTier.FREE # 20 turns/day
|
||||
|
||||
# Initial state - should pass
|
||||
rate_limiter.check_rate_limit(user_id, tier)
|
||||
assert rate_limiter.get_remaining_turns(user_id, tier) == 20
|
||||
|
||||
# Use some turns
|
||||
for i in range(15):
|
||||
rate_limiter.check_rate_limit(user_id, tier)
|
||||
rate_limiter.increment_usage(user_id)
|
||||
|
||||
# Check remaining
|
||||
assert rate_limiter.get_remaining_turns(user_id, tier) == 5
|
||||
|
||||
# Use remaining turns
|
||||
for i in range(5):
|
||||
rate_limiter.check_rate_limit(user_id, tier)
|
||||
rate_limiter.increment_usage(user_id)
|
||||
|
||||
# Now should be limited
|
||||
assert rate_limiter.get_remaining_turns(user_id, tier) == 0
|
||||
|
||||
with pytest.raises(RateLimitExceeded):
|
||||
rate_limiter.check_rate_limit(user_id, tier)
|
||||
|
||||
def test_different_tiers_same_usage(self, rate_limiter):
|
||||
"""Test that same usage affects different tiers differently."""
|
||||
user_id = "user_123"
|
||||
|
||||
# Use 30 turns
|
||||
for _ in range(30):
|
||||
rate_limiter.increment_usage(user_id)
|
||||
|
||||
# Free tier (20) should be limited
|
||||
with pytest.raises(RateLimitExceeded):
|
||||
rate_limiter.check_rate_limit(user_id, UserTier.FREE)
|
||||
|
||||
# Basic tier (50) should not be limited
|
||||
rate_limiter.check_rate_limit(user_id, UserTier.BASIC)
|
||||
|
||||
# Premium tier (100) should not be limited
|
||||
rate_limiter.check_rate_limit(user_id, UserTier.PREMIUM)
|
||||
|
||||
def test_reset_clears_usage(self, rate_limiter):
|
||||
"""Test that reset allows new usage."""
|
||||
user_id = "user_123"
|
||||
tier = UserTier.FREE
|
||||
|
||||
# Use all turns
|
||||
for _ in range(20):
|
||||
rate_limiter.increment_usage(user_id)
|
||||
|
||||
# Should be limited
|
||||
with pytest.raises(RateLimitExceeded):
|
||||
rate_limiter.check_rate_limit(user_id, tier)
|
||||
|
||||
# Reset usage
|
||||
rate_limiter.reset_usage(user_id)
|
||||
|
||||
# Should be able to use again
|
||||
rate_limiter.check_rate_limit(user_id, tier)
|
||||
assert rate_limiter.get_remaining_turns(user_id, tier) == 20
|
||||
Reference in New Issue
Block a user