"""
Tests for RateLimiterService

Tests the tier-based rate limiting functionality including:
- Daily limits per tier
- Usage tracking and incrementing
- Rate limit checks and exceptions
- Reset functionality
"""

from datetime import date, datetime, timezone, timedelta
from unittest.mock import MagicMock, patch

import pytest

from app.services.rate_limiter_service import (
    RateLimiterService,
    RateLimitExceeded,
)
from app.ai.model_selector import UserTier


class TestRateLimiterService:
    """Unit tests for RateLimiterService against a fully mocked Redis service."""

    @pytest.fixture
    def mock_redis(self):
        """Create a mock Redis service with neutral default return values."""
        mock = MagicMock()
        mock.get.return_value = None
        mock.incr.return_value = 1
        mock.expire.return_value = True
        mock.delete.return_value = 1
        return mock

    @pytest.fixture
    def rate_limiter(self, mock_redis):
        """Create a RateLimiterService with mock Redis."""
        return RateLimiterService(redis_service=mock_redis)

    def test_tier_limits(self, rate_limiter):
        """Test that tier limits are correctly defined."""
        assert rate_limiter.get_limit_for_tier(UserTier.FREE) == 20
        assert rate_limiter.get_limit_for_tier(UserTier.BASIC) == 50
        assert rate_limiter.get_limit_for_tier(UserTier.PREMIUM) == 100
        assert rate_limiter.get_limit_for_tier(UserTier.ELITE) == 200

    def test_get_daily_key(self, rate_limiter):
        """Test Redis key generation."""
        key = rate_limiter._get_daily_key("user_123", date(2025, 1, 15))
        assert key == "rate_limit:daily:user_123:2025-01-15"

    def test_get_current_usage_no_usage(self, rate_limiter, mock_redis):
        """Test getting current usage when no usage exists."""
        mock_redis.get.return_value = None

        usage = rate_limiter.get_current_usage("user_123")

        assert usage == 0
        mock_redis.get.assert_called_once()

    def test_get_current_usage_with_usage(self, rate_limiter, mock_redis):
        """Test getting current usage when usage exists."""
        # Redis hands values back as strings; the service must parse them.
        mock_redis.get.return_value = "15"

        usage = rate_limiter.get_current_usage("user_123")

        assert usage == 15

    def test_check_rate_limit_under_limit(self, rate_limiter, mock_redis):
        """Test that check passes when under limit."""
        mock_redis.get.return_value = "10"

        # Should not raise
        rate_limiter.check_rate_limit("user_123", UserTier.FREE)

    def test_check_rate_limit_at_limit(self, rate_limiter, mock_redis):
        """Test that check fails when at limit."""
        mock_redis.get.return_value = "20"  # Free tier limit

        with pytest.raises(RateLimitExceeded) as exc_info:
            rate_limiter.check_rate_limit("user_123", UserTier.FREE)

        exc = exc_info.value
        assert exc.user_id == "user_123"
        assert exc.user_tier == UserTier.FREE
        assert exc.limit == 20
        assert exc.current_usage == 20

    def test_check_rate_limit_over_limit(self, rate_limiter, mock_redis):
        """Test that check fails when over limit."""
        mock_redis.get.return_value = "25"  # Over free tier limit

        with pytest.raises(RateLimitExceeded):
            rate_limiter.check_rate_limit("user_123", UserTier.FREE)

    def test_check_rate_limit_premium_tier(self, rate_limiter, mock_redis):
        """Test that premium tier has higher limit."""
        mock_redis.get.return_value = "50"  # Over free limit, under premium

        # Should not raise for premium
        rate_limiter.check_rate_limit("user_123", UserTier.PREMIUM)

        # Should raise for free
        with pytest.raises(RateLimitExceeded):
            rate_limiter.check_rate_limit("user_123", UserTier.FREE)

    def test_increment_usage_first_time(self, rate_limiter, mock_redis):
        """Test incrementing usage for the first time (sets expiration)."""
        mock_redis.incr.return_value = 1

        new_count = rate_limiter.increment_usage("user_123")

        assert new_count == 1
        mock_redis.incr.assert_called_once()
        mock_redis.expire.assert_called_once()  # Should set expiration

    def test_increment_usage_subsequent(self, rate_limiter, mock_redis):
        """Test incrementing usage after first time (no expiration set)."""
        mock_redis.incr.return_value = 5

        new_count = rate_limiter.increment_usage("user_123")

        assert new_count == 5
        mock_redis.incr.assert_called_once()
        mock_redis.expire.assert_not_called()  # Should NOT set expiration

    def test_get_remaining_turns_full(self, rate_limiter, mock_redis):
        """Test remaining turns when no usage."""
        mock_redis.get.return_value = None

        remaining = rate_limiter.get_remaining_turns("user_123", UserTier.FREE)

        assert remaining == 20

    def test_get_remaining_turns_partial(self, rate_limiter, mock_redis):
        """Test remaining turns with partial usage."""
        mock_redis.get.return_value = "12"

        remaining = rate_limiter.get_remaining_turns("user_123", UserTier.FREE)

        assert remaining == 8

    def test_get_remaining_turns_exhausted(self, rate_limiter, mock_redis):
        """Test remaining turns when limit reached."""
        mock_redis.get.return_value = "20"

        remaining = rate_limiter.get_remaining_turns("user_123", UserTier.FREE)

        assert remaining == 0

    def test_get_remaining_turns_over_limit(self, rate_limiter, mock_redis):
        """Test remaining turns when over limit (should be 0, not negative)."""
        mock_redis.get.return_value = "25"

        remaining = rate_limiter.get_remaining_turns("user_123", UserTier.FREE)

        assert remaining == 0

    def test_get_usage_info(self, rate_limiter, mock_redis):
        """Test getting comprehensive usage info."""
        mock_redis.get.return_value = "15"

        info = rate_limiter.get_usage_info("user_123", UserTier.FREE)

        assert info["user_id"] == "user_123"
        assert info["user_tier"] == "free"
        assert info["current_usage"] == 15
        assert info["daily_limit"] == 20
        assert info["remaining"] == 5
        assert info["is_limited"] is False
        assert "reset_time" in info

    def test_get_usage_info_limited(self, rate_limiter, mock_redis):
        """Test usage info when limited."""
        mock_redis.get.return_value = "20"

        info = rate_limiter.get_usage_info("user_123", UserTier.FREE)

        assert info["is_limited"] is True
        assert info["remaining"] == 0

    def test_reset_usage(self, rate_limiter, mock_redis):
        """Test resetting usage counter."""
        mock_redis.delete.return_value = 1

        result = rate_limiter.reset_usage("user_123")

        assert result is True
        mock_redis.delete.assert_called_once()

    def test_reset_usage_no_key(self, rate_limiter, mock_redis):
        """Test resetting when no usage exists."""
        mock_redis.delete.return_value = 0

        result = rate_limiter.reset_usage("user_123")

        assert result is False


class TestRateLimitExceeded:
    """Tests for RateLimitExceeded exception."""

    def test_exception_attributes(self):
        """Test that exception has correct attributes."""
        reset_time = datetime(2025, 1, 16, 0, 0, 0, tzinfo=timezone.utc)
        exc = RateLimitExceeded(
            user_id="user_123",
            user_tier=UserTier.FREE,
            limit=20,
            current_usage=20,
            reset_time=reset_time,
        )

        assert exc.user_id == "user_123"
        assert exc.user_tier == UserTier.FREE
        assert exc.limit == 20
        assert exc.current_usage == 20
        assert exc.reset_time == reset_time

    def test_exception_message(self):
        """Test that exception message is formatted correctly."""
        reset_time = datetime(2025, 1, 16, 0, 0, 0, tzinfo=timezone.utc)
        exc = RateLimitExceeded(
            user_id="user_123",
            user_tier=UserTier.FREE,
            limit=20,
            current_usage=20,
            reset_time=reset_time,
        )

        message = str(exc)
        assert "user_123" in message
        assert "free tier" in message
        assert "20/20" in message


class TestRateLimiterIntegration:
    """Integration tests for rate limiter workflow."""

    @pytest.fixture
    def mock_redis(self):
        """Create a mock Redis backed by an in-memory dict.

        Counters are stored as ints, but ``get`` returns them as strings,
        matching the string values the unit-test fixtures above feed to the
        service (Redis itself returns stored values as strings/bytes).
        """
        storage = {}
        mock = MagicMock()

        def mock_get(key):
            value = storage.get(key)
            # Missing keys read as None, like Redis GET on an absent key.
            return None if value is None else str(value)

        def mock_incr(key):
            # Like Redis INCR: absent keys start at 0 before incrementing.
            storage[key] = int(storage.get(key, 0)) + 1
            return storage[key]

        def mock_delete(key):
            # Like Redis DEL: return the number of keys actually removed.
            if key in storage:
                del storage[key]
                return 1
            return 0

        mock.get.side_effect = mock_get
        mock.incr.side_effect = mock_incr
        mock.delete.side_effect = mock_delete
        mock.expire.return_value = True

        return mock

    @pytest.fixture
    def rate_limiter(self, mock_redis):
        """Create rate limiter with simulated Redis."""
        return RateLimiterService(redis_service=mock_redis)

    def test_full_workflow(self, rate_limiter):
        """Test complete rate limiting workflow."""
        user_id = "user_123"
        tier = UserTier.FREE  # 20 turns/day

        # Initial state - should pass
        rate_limiter.check_rate_limit(user_id, tier)
        assert rate_limiter.get_remaining_turns(user_id, tier) == 20

        # Use some turns
        for _ in range(15):
            rate_limiter.check_rate_limit(user_id, tier)
            rate_limiter.increment_usage(user_id)

        # Check remaining
        assert rate_limiter.get_remaining_turns(user_id, tier) == 5

        # Use remaining turns
        for _ in range(5):
            rate_limiter.check_rate_limit(user_id, tier)
            rate_limiter.increment_usage(user_id)

        # Now should be limited
        assert rate_limiter.get_remaining_turns(user_id, tier) == 0
        with pytest.raises(RateLimitExceeded):
            rate_limiter.check_rate_limit(user_id, tier)

    def test_different_tiers_same_usage(self, rate_limiter):
        """Test that same usage affects different tiers differently."""
        user_id = "user_123"

        # Use 30 turns
        for _ in range(30):
            rate_limiter.increment_usage(user_id)

        # Free tier (20) should be limited
        with pytest.raises(RateLimitExceeded):
            rate_limiter.check_rate_limit(user_id, UserTier.FREE)

        # Basic tier (50) should not be limited
        rate_limiter.check_rate_limit(user_id, UserTier.BASIC)

        # Premium tier (100) should not be limited
        rate_limiter.check_rate_limit(user_id, UserTier.PREMIUM)

        # Elite tier (200) should not be limited
        rate_limiter.check_rate_limit(user_id, UserTier.ELITE)

    def test_reset_clears_usage(self, rate_limiter):
        """Test that reset allows new usage."""
        user_id = "user_123"
        tier = UserTier.FREE

        # Use all turns
        for _ in range(20):
            rate_limiter.increment_usage(user_id)

        # Should be limited
        with pytest.raises(RateLimitExceeded):
            rate_limiter.check_rate_limit(user_id, tier)

        # Reset usage; a counter existed, so the reset reports success.
        assert rate_limiter.reset_usage(user_id) is True

        # Should be able to use again
        rate_limiter.check_rate_limit(user_id, tier)
        assert rate_limiter.get_remaining_turns(user_id, tier) == 20