Files
Aniworld/tests/api/test_rate_limiting_edge_cases.py

186 lines
6.0 KiB
Python

"""Edge case tests for rate limiting.
Tests the rate_limit_dependency from src/server/utils/dependencies.py
under various edge conditions: multiple IPs, window resets, bursts, etc.
"""
import time
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from fastapi import HTTPException
from src.server.utils.dependencies import (
_RATE_LIMIT_BUCKETS,
_RATE_LIMIT_WINDOW_SECONDS,
RateLimitRecord,
rate_limit_dependency,
)
def _make_request(ip: str = "127.0.0.1") -> MagicMock:
"""Create a mock Request with a given client IP."""
req = MagicMock()
req.client = MagicMock()
req.client.host = ip
return req
@pytest.fixture(autouse=True)
def _clear_buckets():
    """Give every test an empty rate-limit store and leave none behind.

    ``_RATE_LIMIT_BUCKETS`` is shared state imported from the dependencies
    module, so it is wiped both before the test body runs (setup) and after
    it finishes (teardown, the code following ``yield``).
    """
    wipe = _RATE_LIMIT_BUCKETS.clear
    wipe()
    yield
    wipe()
class TestRateLimitBasic:
    """Core allow/deny behaviour of ``rate_limit_dependency`` for one IP."""

    @pytest.mark.asyncio
    @patch("src.server.utils.dependencies.settings")
    async def test_first_request_passes(self, mock_settings):
        """A brand-new client IP is let through on its very first call."""
        mock_settings.api_rate_limit = 100
        fresh_client = _make_request("10.0.0.1")
        # Returning without raising means the request was admitted.
        await rate_limit_dependency(fresh_client)

    @pytest.mark.asyncio
    @patch("src.server.utils.dependencies.settings")
    async def test_exceeding_limit_raises_429(self, mock_settings):
        """The call after ``api_rate_limit`` admitted calls is rejected with 429."""
        allowed = 3
        mock_settings.api_rate_limit = allowed
        client = _make_request("10.0.0.2")
        for _attempt in range(allowed):
            await rate_limit_dependency(client)
        with pytest.raises(HTTPException) as excinfo:
            await rate_limit_dependency(client)
        assert excinfo.value.status_code == 429
class TestMultipleIPs:
    """Per-IP isolation of rate limit counters."""

    @pytest.mark.asyncio
    @patch("src.server.utils.dependencies.settings")
    async def test_different_ips_have_separate_buckets(self, mock_settings):
        """Each client IP is counted independently of every other IP."""
        mock_settings.api_rate_limit = 2
        first = _make_request("1.1.1.1")
        second = _make_request("2.2.2.2")
        # Spend the full allowance (2) on each client, one client at a time.
        for client in (first, second):
            for _ in range(2):
                await rate_limit_dependency(client)
        # Both clients are now exhausted, so the next call for each is refused.
        for client in (first, second):
            with pytest.raises(HTTPException):
                await rate_limit_dependency(client)

    @pytest.mark.asyncio
    @patch("src.server.utils.dependencies.settings")
    async def test_one_ip_limited_does_not_affect_other(self, mock_settings):
        """Exhausting one IP's allowance leaves another IP unaffected."""
        mock_settings.api_rate_limit = 1
        noisy = _make_request("10.0.0.1")
        quiet = _make_request("10.0.0.2")

        await rate_limit_dependency(noisy)
        with pytest.raises(HTTPException):
            await rate_limit_dependency(noisy)

        # The quiet client has its own, still-unused allowance.
        await rate_limit_dependency(quiet)
class TestWindowReset:
    """Rate limit window expiration and reset."""

    @pytest.mark.asyncio
    @patch("src.server.utils.dependencies.settings")
    async def test_requests_allowed_after_window_expires(self, mock_settings):
        """A blocked client is admitted again once its window has expired."""
        mock_settings.api_rate_limit = 1
        client_ip = "10.0.0.3"
        req = _make_request(client_ip)

        await rate_limit_dependency(req)
        with pytest.raises(HTTPException) as exc_info:
            await rate_limit_dependency(req)
        assert exc_info.value.status_code == 429

        # Expire the window by back-dating its start past the window length.
        # Assert the bucket exists instead of silently skipping the rewind:
        # otherwise a missing/renamed key surfaces later as a confusing 429.
        bucket = _RATE_LIMIT_BUCKETS.get(client_ip)
        assert bucket is not None, f"no rate limit bucket recorded for {client_ip}"
        bucket.window_start -= _RATE_LIMIT_WINDOW_SECONDS + 1

        # The expired window must no longer block this client.
        await rate_limit_dependency(req)
class TestBurstTraffic:
    """Behaviour under a rapid burst of back-to-back requests."""

    @pytest.mark.asyncio
    @patch("src.server.utils.dependencies.settings")
    async def test_burst_up_to_limit_passes(self, mock_settings):
        """Every request in a burst exactly as large as the limit is admitted."""
        burst_size = 50
        mock_settings.api_rate_limit = burst_size
        client = _make_request("10.0.0.4")
        for _request_no in range(burst_size):
            await rate_limit_dependency(client)

    @pytest.mark.asyncio
    @patch("src.server.utils.dependencies.settings")
    async def test_burst_over_limit_blocked(self, mock_settings):
        """The 51st request of a burst against a limit of 50 is refused."""
        burst_size = 50
        mock_settings.api_rate_limit = burst_size
        client = _make_request("10.0.0.5")
        for _request_no in range(burst_size):
            await rate_limit_dependency(client)
        with pytest.raises(HTTPException) as excinfo:
            await rate_limit_dependency(client)
        assert excinfo.value.status_code == 429
class TestMissingClient:
    """Requests without client information."""

    @staticmethod
    def _clientless_request() -> MagicMock:
        """Return a fresh mock Request whose ``client`` attribute is None."""
        req = MagicMock()
        req.client = None
        return req

    @pytest.mark.asyncio
    @patch("src.server.utils.dependencies.settings")
    async def test_no_client_uses_unknown_key(self, mock_settings):
        """When request.client is None, 'unknown' is used as key."""
        mock_settings.api_rate_limit = 100
        await rate_limit_dependency(self._clientless_request())
        assert "unknown" in _RATE_LIMIT_BUCKETS

    @pytest.mark.asyncio
    @patch("src.server.utils.dependencies.settings")
    async def test_multiple_unknown_share_bucket(self, mock_settings):
        """All client-less requests share the 'unknown' bucket.

        A distinct request object is used for every call, so the test shows
        the allowance is shared by key rather than by request identity.
        """
        mock_settings.api_rate_limit = 2
        await rate_limit_dependency(self._clientless_request())
        await rate_limit_dependency(self._clientless_request())
        with pytest.raises(HTTPException) as exc_info:
            await rate_limit_dependency(self._clientless_request())
        assert exc_info.value.status_code == 429
class TestRateLimitRecord:
    """Unit tests for the RateLimitRecord dataclass."""

    def test_creation(self):
        """Both constructor fields are stored exactly as given."""
        started = time.time()
        rec = RateLimitRecord(count=5, window_start=started)
        assert rec.count == 5
        # window_start is read and rewritten by the window-reset tests, so
        # pin that it round-trips through the constructor as well.
        assert rec.window_start == started

    def test_mutation(self):
        """Count can be incremented in place."""
        rec = RateLimitRecord(count=0, window_start=time.time())
        rec.count += 1
        assert rec.count == 1