Add Task 7 edge case and regression tests
This commit is contained in:
185
tests/api/test_rate_limiting_edge_cases.py
Normal file
185
tests/api/test_rate_limiting_edge_cases.py
Normal file
@@ -0,0 +1,185 @@
|
||||
"""Edge case tests for rate limiting.
|
||||
|
||||
Tests the rate_limit_dependency from src/server/utils/dependencies.py
|
||||
under various edge conditions: multiple IPs, window resets, bursts, etc.
|
||||
"""
|
||||
|
||||
import time
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
|
||||
from src.server.utils.dependencies import (
|
||||
_RATE_LIMIT_BUCKETS,
|
||||
_RATE_LIMIT_WINDOW_SECONDS,
|
||||
RateLimitRecord,
|
||||
rate_limit_dependency,
|
||||
)
|
||||
|
||||
|
||||
def _make_request(ip: str = "127.0.0.1") -> MagicMock:
|
||||
"""Create a mock Request with a given client IP."""
|
||||
req = MagicMock()
|
||||
req.client = MagicMock()
|
||||
req.client.host = ip
|
||||
return req
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _clear_buckets():
|
||||
"""Clear rate limit buckets before each test."""
|
||||
_RATE_LIMIT_BUCKETS.clear()
|
||||
yield
|
||||
_RATE_LIMIT_BUCKETS.clear()
|
||||
|
||||
|
||||
class TestRateLimitBasic:
|
||||
"""Basic rate limit behaviour."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("src.server.utils.dependencies.settings")
|
||||
async def test_first_request_passes(self, mock_settings):
|
||||
"""First request from a new IP is always allowed."""
|
||||
mock_settings.api_rate_limit = 100
|
||||
await rate_limit_dependency(_make_request("10.0.0.1"))
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("src.server.utils.dependencies.settings")
|
||||
async def test_exceeding_limit_raises_429(self, mock_settings):
|
||||
"""Exceeding max requests within a window raises 429."""
|
||||
mock_settings.api_rate_limit = 3
|
||||
req = _make_request("10.0.0.2")
|
||||
for _ in range(3):
|
||||
await rate_limit_dependency(req)
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await rate_limit_dependency(req)
|
||||
assert exc_info.value.status_code == 429
|
||||
|
||||
|
||||
class TestMultipleIPs:
|
||||
"""Rate limiting with different client IPs."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("src.server.utils.dependencies.settings")
|
||||
async def test_different_ips_have_separate_buckets(self, mock_settings):
|
||||
"""Each IP has its own counter."""
|
||||
mock_settings.api_rate_limit = 2
|
||||
ip1 = _make_request("1.1.1.1")
|
||||
ip2 = _make_request("2.2.2.2")
|
||||
|
||||
# Both use 2 requests
|
||||
await rate_limit_dependency(ip1)
|
||||
await rate_limit_dependency(ip1)
|
||||
await rate_limit_dependency(ip2)
|
||||
await rate_limit_dependency(ip2)
|
||||
|
||||
# ip1 is at limit, ip2 is at limit
|
||||
with pytest.raises(HTTPException):
|
||||
await rate_limit_dependency(ip1)
|
||||
with pytest.raises(HTTPException):
|
||||
await rate_limit_dependency(ip2)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("src.server.utils.dependencies.settings")
|
||||
async def test_one_ip_limited_does_not_affect_other(self, mock_settings):
|
||||
"""Blocking one IP doesn't block another."""
|
||||
mock_settings.api_rate_limit = 1
|
||||
ip_blocked = _make_request("10.0.0.1")
|
||||
ip_ok = _make_request("10.0.0.2")
|
||||
|
||||
await rate_limit_dependency(ip_blocked)
|
||||
with pytest.raises(HTTPException):
|
||||
await rate_limit_dependency(ip_blocked)
|
||||
|
||||
# Other IP still fine
|
||||
await rate_limit_dependency(ip_ok)
|
||||
|
||||
|
||||
class TestWindowReset:
|
||||
"""Rate limit window expiration and reset."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("src.server.utils.dependencies.settings")
|
||||
async def test_requests_allowed_after_window_expires(self, mock_settings):
|
||||
"""Counter resets when window expires."""
|
||||
mock_settings.api_rate_limit = 1
|
||||
req = _make_request("10.0.0.3")
|
||||
|
||||
await rate_limit_dependency(req)
|
||||
with pytest.raises(HTTPException):
|
||||
await rate_limit_dependency(req)
|
||||
|
||||
# Manually expire the window
|
||||
bucket = _RATE_LIMIT_BUCKETS.get("10.0.0.3")
|
||||
if bucket:
|
||||
bucket.window_start -= _RATE_LIMIT_WINDOW_SECONDS + 1
|
||||
|
||||
# Should pass now
|
||||
await rate_limit_dependency(req)
|
||||
|
||||
|
||||
class TestBurstTraffic:
|
||||
"""Burst traffic handling."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("src.server.utils.dependencies.settings")
|
||||
async def test_burst_up_to_limit_passes(self, mock_settings):
|
||||
"""Rapid requests up to the limit all pass."""
|
||||
mock_settings.api_rate_limit = 50
|
||||
req = _make_request("10.0.0.4")
|
||||
for _ in range(50):
|
||||
await rate_limit_dependency(req)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("src.server.utils.dependencies.settings")
|
||||
async def test_burst_over_limit_blocked(self, mock_settings):
|
||||
"""Request 51 in a burst of 50 is blocked."""
|
||||
mock_settings.api_rate_limit = 50
|
||||
req = _make_request("10.0.0.5")
|
||||
for _ in range(50):
|
||||
await rate_limit_dependency(req)
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await rate_limit_dependency(req)
|
||||
assert exc_info.value.status_code == 429
|
||||
|
||||
|
||||
class TestMissingClient:
|
||||
"""Requests without client information."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("src.server.utils.dependencies.settings")
|
||||
async def test_no_client_uses_unknown_key(self, mock_settings):
|
||||
"""When request.client is None, 'unknown' is used as key."""
|
||||
mock_settings.api_rate_limit = 100
|
||||
req = MagicMock()
|
||||
req.client = None
|
||||
await rate_limit_dependency(req)
|
||||
assert "unknown" in _RATE_LIMIT_BUCKETS
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("src.server.utils.dependencies.settings")
|
||||
async def test_multiple_unknown_share_bucket(self, mock_settings):
|
||||
"""All client-less requests share the 'unknown' bucket."""
|
||||
mock_settings.api_rate_limit = 2
|
||||
req = MagicMock()
|
||||
req.client = None
|
||||
await rate_limit_dependency(req)
|
||||
await rate_limit_dependency(req)
|
||||
with pytest.raises(HTTPException):
|
||||
await rate_limit_dependency(req)
|
||||
|
||||
|
||||
class TestRateLimitRecord:
|
||||
"""Unit tests for the RateLimitRecord dataclass."""
|
||||
|
||||
def test_creation(self):
|
||||
"""Can create a RateLimitRecord."""
|
||||
rec = RateLimitRecord(count=5, window_start=time.time())
|
||||
assert rec.count == 5
|
||||
|
||||
def test_mutation(self):
|
||||
"""Count can be incremented."""
|
||||
rec = RateLimitRecord(count=0, window_start=time.time())
|
||||
rec.count += 1
|
||||
assert rec.count == 1
|
||||
Reference in New Issue
Block a user