"""
Tests for RateLimiterService.

Tests the tier-based rate limiting functionality, including:
- Daily limits per tier
- Usage tracking and incrementing
- Rate limit checks and exceptions
- Reset functionality
"""
|
|
|
|
import pytest
|
|
from unittest.mock import MagicMock, patch
|
|
from datetime import datetime, timezone, timedelta
|
|
|
|
from app.services.rate_limiter_service import (
|
|
RateLimiterService,
|
|
RateLimitExceeded,
|
|
)
|
|
from app.ai.model_selector import UserTier
|
|
|
|
|
|
class TestRateLimiterService:
    """Unit tests for RateLimiterService backed by a stubbed Redis client."""

    @pytest.fixture
    def mock_redis(self):
        """Provide a MagicMock Redis stub preloaded with default return values."""
        return MagicMock(
            **{
                "get.return_value": None,
                "incr.return_value": 1,
                "expire.return_value": True,
                "delete.return_value": 1,
            }
        )

    @pytest.fixture
    def rate_limiter(self, mock_redis):
        """Build the service under test around the stubbed Redis client."""
        service = RateLimiterService(redis_service=mock_redis)
        return service

    def test_tier_limits(self, rate_limiter):
        """Each tier reports its expected daily limit."""
        expected_limits = [
            (UserTier.FREE, 20),
            (UserTier.BASIC, 50),
            (UserTier.PREMIUM, 100),
            (UserTier.ELITE, 200),
        ]
        for tier, expected in expected_limits:
            assert rate_limiter.get_limit_for_tier(tier) == expected

    def test_get_daily_key(self, rate_limiter):
        """The daily Redis key combines the user id and the ISO date."""
        from datetime import date

        day = date(2025, 1, 15)
        generated = rate_limiter._get_daily_key("user_123", day)

        assert generated == "rate_limit:daily:user_123:2025-01-15"

    def test_get_current_usage_no_usage(self, rate_limiter, mock_redis):
        """A missing counter (get returns None) is reported as zero usage."""
        mock_redis.get.return_value = None

        assert rate_limiter.get_current_usage("user_123") == 0
        mock_redis.get.assert_called_once()

    def test_get_current_usage_with_usage(self, rate_limiter, mock_redis):
        """A stored string counter is returned as an integer."""
        mock_redis.get.return_value = "15"

        assert rate_limiter.get_current_usage("user_123") == 15

    def test_check_rate_limit_under_limit(self, rate_limiter, mock_redis):
        """No exception is raised while usage is below the tier limit."""
        mock_redis.get.return_value = "10"

        # Passing means simply returning without raising.
        rate_limiter.check_rate_limit("user_123", UserTier.FREE)

    def test_check_rate_limit_at_limit(self, rate_limiter, mock_redis):
        """Usage equal to the limit raises, and the exception carries context."""
        mock_redis.get.return_value = "20"  # exactly the free tier limit

        with pytest.raises(RateLimitExceeded) as raised:
            rate_limiter.check_rate_limit("user_123", UserTier.FREE)

        error = raised.value
        assert error.user_id == "user_123"
        assert error.user_tier == UserTier.FREE
        assert error.limit == 20
        assert error.current_usage == 20

    def test_check_rate_limit_over_limit(self, rate_limiter, mock_redis):
        """Usage above the limit raises RateLimitExceeded."""
        mock_redis.get.return_value = "25"  # beyond the free tier limit

        with pytest.raises(RateLimitExceeded):
            rate_limiter.check_rate_limit("user_123", UserTier.FREE)

    def test_check_rate_limit_premium_tier(self, rate_limiter, mock_redis):
        """The same usage passes for premium but is rejected for free."""
        mock_redis.get.return_value = "50"  # above free, below premium

        # Premium still has headroom at this usage.
        rate_limiter.check_rate_limit("user_123", UserTier.PREMIUM)

        # Free is already past its limit.
        with pytest.raises(RateLimitExceeded):
            rate_limiter.check_rate_limit("user_123", UserTier.FREE)

    def test_increment_usage_first_time(self, rate_limiter, mock_redis):
        """The first increment (counter becomes 1) also sets an expiration."""
        mock_redis.incr.return_value = 1

        updated = rate_limiter.increment_usage("user_123")

        assert updated == 1
        mock_redis.incr.assert_called_once()
        mock_redis.expire.assert_called_once()  # expiration must be set

    def test_increment_usage_subsequent(self, rate_limiter, mock_redis):
        """Later increments (counter above 1) leave the expiration alone."""
        mock_redis.incr.return_value = 5

        updated = rate_limiter.increment_usage("user_123")

        assert updated == 5
        mock_redis.incr.assert_called_once()
        mock_redis.expire.assert_not_called()  # expiration must NOT be set

    def test_get_remaining_turns_full(self, rate_limiter, mock_redis):
        """With no recorded usage the whole free allowance remains."""
        mock_redis.get.return_value = None

        turns_left = rate_limiter.get_remaining_turns("user_123", UserTier.FREE)

        assert turns_left == 20

    def test_get_remaining_turns_partial(self, rate_limiter, mock_redis):
        """Remaining turns equal the limit minus the recorded usage."""
        mock_redis.get.return_value = "12"

        turns_left = rate_limiter.get_remaining_turns("user_123", UserTier.FREE)

        assert turns_left == 8

    def test_get_remaining_turns_exhausted(self, rate_limiter, mock_redis):
        """Usage equal to the limit leaves zero turns."""
        mock_redis.get.return_value = "20"

        turns_left = rate_limiter.get_remaining_turns("user_123", UserTier.FREE)

        assert turns_left == 0

    def test_get_remaining_turns_over_limit(self, rate_limiter, mock_redis):
        """Usage above the limit is reported as zero, never negative."""
        mock_redis.get.return_value = "25"

        turns_left = rate_limiter.get_remaining_turns("user_123", UserTier.FREE)

        assert turns_left == 0

    def test_get_usage_info(self, rate_limiter, mock_redis):
        """The usage summary exposes every expected field for a non-limited user."""
        mock_redis.get.return_value = "15"

        summary = rate_limiter.get_usage_info("user_123", UserTier.FREE)

        assert summary["user_id"] == "user_123"
        assert summary["user_tier"] == "free"
        assert summary["current_usage"] == 15
        assert summary["daily_limit"] == 20
        assert summary["remaining"] == 5
        assert summary["is_limited"] is False
        assert "reset_time" in summary

    def test_get_usage_info_limited(self, rate_limiter, mock_redis):
        """At the limit the summary flags the user as limited with nothing left."""
        mock_redis.get.return_value = "20"

        summary = rate_limiter.get_usage_info("user_123", UserTier.FREE)

        assert summary["is_limited"] is True
        assert summary["remaining"] == 0

    def test_reset_usage(self, rate_limiter, mock_redis):
        """Reset returns True when the delete call reports one removed key."""
        mock_redis.delete.return_value = 1

        was_reset = rate_limiter.reset_usage("user_123")

        assert was_reset is True
        mock_redis.delete.assert_called_once()

    def test_reset_usage_no_key(self, rate_limiter, mock_redis):
        """Reset returns False when the delete call reports no removed key."""
        mock_redis.delete.return_value = 0

        was_reset = rate_limiter.reset_usage("user_123")

        assert was_reset is False
|
|
|
|
|
|
class TestRateLimitExceeded:
    """Tests for the RateLimitExceeded exception type."""

    @staticmethod
    def _make_exceeded(reset_time):
        """Build the free-tier 20/20 exception shared by the tests below."""
        return RateLimitExceeded(
            user_id="user_123",
            user_tier=UserTier.FREE,
            limit=20,
            current_usage=20,
            reset_time=reset_time,
        )

    def test_exception_attributes(self):
        """Constructor arguments are exposed unchanged as attributes."""
        midnight_utc = datetime(2025, 1, 16, 0, 0, 0, tzinfo=timezone.utc)

        error = self._make_exceeded(midnight_utc)

        assert error.user_id == "user_123"
        assert error.user_tier == UserTier.FREE
        assert error.limit == 20
        assert error.current_usage == 20
        assert error.reset_time == midnight_utc

    def test_exception_message(self):
        """The string form mentions the user, the tier and the usage ratio."""
        midnight_utc = datetime(2025, 1, 16, 0, 0, 0, tzinfo=timezone.utc)

        error = self._make_exceeded(midnight_utc)

        text = str(error)
        assert "user_123" in text
        assert "free tier" in text
        assert "20/20" in text
|
|
|
|
|
|
class TestRateLimiterIntegration:
    """Integration tests for the rate limiter workflow.

    The service runs against a MagicMock whose get/incr/delete calls are
    backed by an in-memory dict, so counters persist across calls within a
    single test.
    """

    @pytest.fixture
    def mock_redis(self):
        """Create a mock Redis whose get/incr/delete share one dict.

        ``expire`` is stubbed to return True and never evicts anything, so a
        key stays in place until ``delete`` is called for it.
        """
        storage = {}

        def mock_get(key):
            # Missing keys yield None.
            return storage.get(key)

        def mock_incr(key):
            # Missing keys count up from zero; the new value is stored and returned.
            storage[key] = int(storage.get(key, 0)) + 1
            return storage[key]

        def mock_delete(key):
            # Report how many keys were removed (1 or 0).
            if key in storage:
                del storage[key]
                return 1
            return 0

        mock = MagicMock()
        mock.get.side_effect = mock_get
        mock.incr.side_effect = mock_incr
        mock.delete.side_effect = mock_delete
        mock.expire.return_value = True

        return mock

    @pytest.fixture
    def rate_limiter(self, mock_redis):
        """Create the rate limiter around the dict-backed mock Redis."""
        return RateLimiterService(redis_service=mock_redis)

    def test_full_workflow(self, rate_limiter):
        """Consume a free-tier allowance step by step until it is exhausted."""
        user_id = "user_123"
        tier = UserTier.FREE  # 20 turns/day

        # Initial state: nothing used, so the check passes.
        rate_limiter.check_rate_limit(user_id, tier)
        assert rate_limiter.get_remaining_turns(user_id, tier) == 20

        # Use 15 turns, checking before each one.
        for _ in range(15):
            rate_limiter.check_rate_limit(user_id, tier)
            rate_limiter.increment_usage(user_id)

        assert rate_limiter.get_remaining_turns(user_id, tier) == 5

        # Use the remaining 5 turns.
        for _ in range(5):
            rate_limiter.check_rate_limit(user_id, tier)
            rate_limiter.increment_usage(user_id)

        # The allowance is now exhausted.
        assert rate_limiter.get_remaining_turns(user_id, tier) == 0

        with pytest.raises(RateLimitExceeded):
            rate_limiter.check_rate_limit(user_id, tier)

    def test_different_tiers_same_usage(self, rate_limiter):
        """Test that the same usage affects different tiers differently."""
        user_id = "user_123"

        # Record 30 turns.
        for _ in range(30):
            rate_limiter.increment_usage(user_id)

        # Free tier (20) should be limited.
        with pytest.raises(RateLimitExceeded):
            rate_limiter.check_rate_limit(user_id, UserTier.FREE)

        # Basic tier (50) should not be limited.
        rate_limiter.check_rate_limit(user_id, UserTier.BASIC)

        # Premium tier (100) should not be limited.
        rate_limiter.check_rate_limit(user_id, UserTier.PREMIUM)

    def test_reset_clears_usage(self, rate_limiter):
        """Test that a reset restores the full allowance."""
        user_id = "user_123"
        tier = UserTier.FREE

        # Use all 20 free-tier turns.
        for _ in range(20):
            rate_limiter.increment_usage(user_id)

        # The user is limited at this point.
        with pytest.raises(RateLimitExceeded):
            rate_limiter.check_rate_limit(user_id, tier)

        rate_limiter.reset_usage(user_id)

        # After the reset the check passes and the full allowance is back.
        rate_limiter.check_rate_limit(user_id, tier)
        assert rate_limiter.get_remaining_turns(user_id, tier) == 20
|