186 lines
6.0 KiB
Python
186 lines
6.0 KiB
Python
"""Edge case tests for rate limiting.

Tests the rate_limit_dependency from src/server/utils/dependencies.py
under various edge conditions: multiple IPs, window resets, bursts, etc.
"""
|
|
|
|
import time
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import pytest
|
|
from fastapi import HTTPException
|
|
|
|
from src.server.utils.dependencies import (
|
|
_RATE_LIMIT_BUCKETS,
|
|
_RATE_LIMIT_WINDOW_SECONDS,
|
|
RateLimitRecord,
|
|
rate_limit_dependency,
|
|
)
|
|
|
|
|
|
def _make_request(ip: str = "127.0.0.1") -> MagicMock:
|
|
"""Create a mock Request with a given client IP."""
|
|
req = MagicMock()
|
|
req.client = MagicMock()
|
|
req.client.host = ip
|
|
return req
|
|
|
|
|
|
@pytest.fixture(autouse=True)
def _clear_buckets():
    """Run every test against an empty rate-limit store, and empty it afterwards.

    ``_RATE_LIMIT_BUCKETS`` is imported at module level and shared by all
    tests, so without this reset one test's counters would leak into the next.
    """
    reset_buckets = _RATE_LIMIT_BUCKETS.clear
    reset_buckets()  # setup: start from a clean slate
    yield
    reset_buckets()  # teardown: drop whatever this test recorded
|
|
|
|
|
|
class TestRateLimitBasic:
    """Core allow/deny behaviour of ``rate_limit_dependency`` for a single IP."""

    @pytest.mark.asyncio
    @patch("src.server.utils.dependencies.settings")
    async def test_first_request_passes(self, mock_settings):
        """A brand-new IP is let through on its very first request."""
        mock_settings.api_rate_limit = 100
        first_request = _make_request("10.0.0.1")
        await rate_limit_dependency(first_request)

    @pytest.mark.asyncio
    @patch("src.server.utils.dependencies.settings")
    async def test_exceeding_limit_raises_429(self, mock_settings):
        """Once an IP has used its whole quota, the next call is rejected with 429."""
        limit = 3
        mock_settings.api_rate_limit = limit
        request = _make_request("10.0.0.2")

        # Consume the entire allowance for the current window.
        for _ in range(limit):
            await rate_limit_dependency(request)

        # One request beyond the limit must be refused.
        with pytest.raises(HTTPException) as exc_info:
            await rate_limit_dependency(request)
        assert exc_info.value.status_code == 429
|
|
|
|
|
|
class TestMultipleIPs:
    """Rate limiting with different client IPs."""

    @pytest.mark.asyncio
    @patch("src.server.utils.dependencies.settings")
    async def test_different_ips_have_separate_buckets(self, mock_settings):
        """Each IP has its own counter and is limited independently."""
        mock_settings.api_rate_limit = 2
        ip1 = _make_request("1.1.1.1")
        ip2 = _make_request("2.2.2.2")

        # Each IP uses its full allowance of 2 requests. If the two IPs shared
        # one counter, the third call in this sequence would already be refused.
        for req in (ip1, ip1, ip2, ip2):
            await rate_limit_dependency(req)

        # Both IPs are now at their limit; the next request from each must be a
        # 429 specifically, not just any HTTPException.
        for req in (ip1, ip2):
            with pytest.raises(HTTPException) as exc_info:
                await rate_limit_dependency(req)
            assert exc_info.value.status_code == 429

    @pytest.mark.asyncio
    @patch("src.server.utils.dependencies.settings")
    async def test_one_ip_limited_does_not_affect_other(self, mock_settings):
        """Blocking one IP doesn't block another."""
        mock_settings.api_rate_limit = 1
        ip_blocked = _make_request("10.0.0.1")
        ip_ok = _make_request("10.0.0.2")

        await rate_limit_dependency(ip_blocked)
        with pytest.raises(HTTPException) as exc_info:
            await rate_limit_dependency(ip_blocked)
        assert exc_info.value.status_code == 429

        # The other IP is untouched and still gets its first request through.
        await rate_limit_dependency(ip_ok)
|
|
|
|
|
|
class TestWindowReset:
    """Rate limit window expiration and reset."""

    @pytest.mark.asyncio
    @patch("src.server.utils.dependencies.settings")
    async def test_requests_allowed_after_window_expires(self, mock_settings):
        """The counter resets once the window has expired."""
        mock_settings.api_rate_limit = 1
        ip = "10.0.0.3"
        req = _make_request(ip)

        await rate_limit_dependency(req)
        with pytest.raises(HTTPException) as exc_info:
            await rate_limit_dependency(req)
        assert exc_info.value.status_code == 429

        # Expire the window by moving its start into the past instead of
        # sleeping. Fail loudly if no bucket is stored under this IP, rather
        # than silently skipping the rewind and hitting a confusing 429 below.
        bucket = _RATE_LIMIT_BUCKETS.get(ip)
        assert bucket is not None, f"no rate-limit bucket recorded for {ip!r}"
        bucket.window_start -= _RATE_LIMIT_WINDOW_SECONDS + 1

        # A fresh window has started, so this request should pass.
        await rate_limit_dependency(req)
|
|
|
|
|
|
class TestBurstTraffic:
    """Back-to-back request bursts from a single IP."""

    @pytest.mark.asyncio
    @patch("src.server.utils.dependencies.settings")
    async def test_burst_up_to_limit_passes(self, mock_settings):
        """Every request in a burst exactly as large as the limit is accepted."""
        burst_size = 50
        mock_settings.api_rate_limit = burst_size
        client_request = _make_request("10.0.0.4")
        for _ in range(burst_size):
            await rate_limit_dependency(client_request)

    @pytest.mark.asyncio
    @patch("src.server.utils.dependencies.settings")
    async def test_burst_over_limit_blocked(self, mock_settings):
        """The request right after a full burst of 50 (request 51) gets a 429."""
        burst_size = 50
        mock_settings.api_rate_limit = burst_size
        client_request = _make_request("10.0.0.5")

        # Use up the whole quota in one rapid burst.
        for _ in range(burst_size):
            await rate_limit_dependency(client_request)

        with pytest.raises(HTTPException) as exc_info:
            await rate_limit_dependency(client_request)
        assert exc_info.value.status_code == 429
|
|
|
|
|
|
class TestMissingClient:
    """Requests without client information."""

    @staticmethod
    def _clientless_request() -> MagicMock:
        """Return a fresh mock request whose ``client`` attribute is ``None``."""
        req = MagicMock()
        req.client = None
        return req

    @pytest.mark.asyncio
    @patch("src.server.utils.dependencies.settings")
    async def test_no_client_uses_unknown_key(self, mock_settings):
        """When request.client is None, 'unknown' is used as the bucket key."""
        mock_settings.api_rate_limit = 100
        await rate_limit_dependency(self._clientless_request())
        assert "unknown" in _RATE_LIMIT_BUCKETS

    @pytest.mark.asyncio
    @patch("src.server.utils.dependencies.settings")
    async def test_multiple_unknown_share_bucket(self, mock_settings):
        """Distinct client-less requests all count against the 'unknown' bucket."""
        mock_settings.api_rate_limit = 2
        # Use a separate request object per call so the test really exercises
        # *different* client-less requests, not one object sent three times.
        await rate_limit_dependency(self._clientless_request())
        await rate_limit_dependency(self._clientless_request())
        with pytest.raises(HTTPException) as exc_info:
            await rate_limit_dependency(self._clientless_request())
        assert exc_info.value.status_code == 429
|
|
|
|
|
|
class TestRateLimitRecord:
    """Direct checks on the ``RateLimitRecord`` record type."""

    def test_creation(self):
        """Constructing a record keeps the count it was given."""
        now = time.time()
        record = RateLimitRecord(count=5, window_start=now)
        assert record.count == 5

    def test_mutation(self):
        """A record's ``count`` field is writable and can be bumped by one."""
        record = RateLimitRecord(count=0, window_start=time.time())
        record.count = record.count + 1
        assert record.count == 1
|