Implement global rate limiter and refactor auth middleware
- Add global rate limiter utility with configurable limits and cleanup - Move rate limiting logic to middleware for consistent application - Update auth routes to use new rate limiter - Add comprehensive tests for rate limiter functionality - Update documentation with backend development guidelines and tasks Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
183
backend/tests/test_utils/test_global_rate_limiter.py
Normal file
183
backend/tests/test_utils/test_global_rate_limiter.py
Normal file
@@ -0,0 +1,183 @@
|
||||
"""Tests for the global rate limiter and rate limit middleware."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
|
||||
from httpx import AsyncClient
|
||||
|
||||
_SETUP_PAYLOAD = {
|
||||
"master_password": "Mysecretpass1!",
|
||||
"database_path": "bangui.db",
|
||||
"fail2ban_socket": "/var/run/fail2ban/fail2ban.sock",
|
||||
"timezone": "UTC",
|
||||
"session_duration_minutes": 60,
|
||||
}
|
||||
|
||||
|
||||
async def _do_setup(client: AsyncClient) -> None:
|
||||
"""Run the setup wizard so auth endpoints are reachable."""
|
||||
resp = await client.post("/api/setup", json=_SETUP_PAYLOAD)
|
||||
assert resp.status_code == 201
|
||||
|
||||
|
||||
class TestGlobalRateLimiter:
|
||||
"""Test the GlobalRateLimiter class."""
|
||||
|
||||
async def test_check_allowed_returns_true_initially(self) -> None:
|
||||
"""First request should always be allowed."""
|
||||
from app.utils.rate_limiter import GlobalRateLimiter
|
||||
|
||||
limiter = GlobalRateLimiter(max_requests=5, window_seconds=60)
|
||||
is_allowed, retry_after = limiter.check_allowed("192.168.1.1")
|
||||
|
||||
assert is_allowed is True
|
||||
assert retry_after == 0.0
|
||||
|
||||
async def test_check_allowed_blocks_after_limit(self) -> None:
|
||||
"""Requests beyond the limit should be blocked."""
|
||||
from app.utils.rate_limiter import GlobalRateLimiter
|
||||
|
||||
limiter = GlobalRateLimiter(max_requests=2, window_seconds=60)
|
||||
|
||||
# First two requests allowed
|
||||
assert limiter.check_allowed("192.168.1.1")[0] is True
|
||||
assert limiter.check_allowed("192.168.1.1")[0] is True
|
||||
|
||||
# Third request blocked
|
||||
is_allowed, retry_after = limiter.check_allowed("192.168.1.1")
|
||||
assert is_allowed is False
|
||||
assert retry_after > 0
|
||||
|
||||
async def test_check_allowed_per_ip_isolation(self) -> None:
|
||||
"""Different IPs should have independent limits."""
|
||||
from app.utils.rate_limiter import GlobalRateLimiter
|
||||
|
||||
limiter = GlobalRateLimiter(max_requests=2, window_seconds=60)
|
||||
|
||||
# IP1 hits limit
|
||||
assert limiter.check_allowed("192.168.1.1")[0] is True
|
||||
assert limiter.check_allowed("192.168.1.1")[0] is True
|
||||
assert limiter.check_allowed("192.168.1.1")[0] is False
|
||||
|
||||
# IP2 should still have allowance
|
||||
assert limiter.check_allowed("192.168.1.2")[0] is True
|
||||
assert limiter.check_allowed("192.168.1.2")[0] is True
|
||||
assert limiter.check_allowed("192.168.1.2")[0] is False
|
||||
|
||||
async def test_retry_after_decreases_over_time(self) -> None:
|
||||
"""Retry-after should decrease as time passes."""
|
||||
from app.utils.rate_limiter import GlobalRateLimiter
|
||||
|
||||
limiter = GlobalRateLimiter(max_requests=2, window_seconds=10)
|
||||
|
||||
# Hit limit
|
||||
limiter.check_allowed("192.168.1.1")
|
||||
limiter.check_allowed("192.168.1.1")
|
||||
_, retry_after_1 = limiter.check_allowed("192.168.1.1")
|
||||
|
||||
# Wait and check again
|
||||
await asyncio.sleep(2)
|
||||
_, retry_after_2 = limiter.check_allowed("192.168.1.1")
|
||||
|
||||
assert retry_after_2 < retry_after_1
|
||||
|
||||
async def test_get_state(self) -> None:
|
||||
"""get_state should return request counts per IP."""
|
||||
from app.utils.rate_limiter import GlobalRateLimiter
|
||||
|
||||
limiter = GlobalRateLimiter(max_requests=5, window_seconds=60)
|
||||
|
||||
limiter.check_allowed("192.168.1.1")
|
||||
limiter.check_allowed("192.168.1.1")
|
||||
limiter.check_allowed("192.168.1.2")
|
||||
|
||||
state = limiter.get_state()
|
||||
assert state["192.168.1.1"] == 2
|
||||
assert state["192.168.1.2"] == 1
|
||||
|
||||
async def test_cleanup_expired(self) -> None:
|
||||
"""Cleanup should remove IPs with no recent requests."""
|
||||
from app.utils.rate_limiter import GlobalRateLimiter
|
||||
|
||||
limiter = GlobalRateLimiter(max_requests=5, window_seconds=1)
|
||||
|
||||
limiter.check_allowed("192.168.1.1")
|
||||
state_before = limiter.get_state()
|
||||
assert "192.168.1.1" in state_before
|
||||
|
||||
# Wait for window to expire
|
||||
await asyncio.sleep(1.5)
|
||||
|
||||
limiter.cleanup_expired()
|
||||
state_after = limiter.get_state()
|
||||
assert "192.168.1.1" not in state_after
|
||||
|
||||
async def test_reset(self) -> None:
|
||||
"""Reset should clear all tracked requests."""
|
||||
from app.utils.rate_limiter import GlobalRateLimiter
|
||||
|
||||
limiter = GlobalRateLimiter(max_requests=5, window_seconds=60)
|
||||
|
||||
limiter.check_allowed("192.168.1.1")
|
||||
limiter.check_allowed("192.168.1.2")
|
||||
|
||||
limiter.reset()
|
||||
state = limiter.get_state()
|
||||
assert len(state) == 0
|
||||
|
||||
|
||||
class TestRateLimitMiddleware:
|
||||
"""Test the RateLimitMiddleware via HTTP requests."""
|
||||
|
||||
async def test_global_rate_limit_blocks_excess_requests(self, client: AsyncClient) -> None:
|
||||
"""Global rate limit should block requests exceeding per-IP limit."""
|
||||
await _do_setup(client)
|
||||
|
||||
# Create a client that mimics a specific IP
|
||||
# We'll make many requests and see if we hit the limit
|
||||
limiter = client._transport.app.state.global_rate_limiter
|
||||
limiter.reset()
|
||||
|
||||
# Reduce limit temporarily for testing
|
||||
original_max = limiter.max_requests
|
||||
limiter.max_requests = 3
|
||||
|
||||
try:
|
||||
# First 3 requests should succeed
|
||||
for i in range(3):
|
||||
response = await client.get("/api/health")
|
||||
assert response.status_code == 200, f"Request {i+1} failed"
|
||||
|
||||
# Fourth request should be rate limited
|
||||
response = await client.get("/api/health")
|
||||
assert response.status_code == 429
|
||||
assert response.json()["code"] == "rate_limit_exceeded"
|
||||
assert "Retry-After" in response.headers
|
||||
finally:
|
||||
limiter.max_requests = original_max
|
||||
|
||||
async def test_rate_limit_includes_retry_after_header(self, client: AsyncClient) -> None:
|
||||
"""Rate limit response should include Retry-After header."""
|
||||
await _do_setup(client)
|
||||
|
||||
limiter = client._transport.app.state.global_rate_limiter
|
||||
limiter.reset()
|
||||
|
||||
original_max = limiter.max_requests
|
||||
limiter.max_requests = 1
|
||||
|
||||
try:
|
||||
# First request succeeds
|
||||
response = await client.get("/api/health")
|
||||
assert response.status_code == 200
|
||||
|
||||
# Second request is rate limited
|
||||
response = await client.get("/api/health")
|
||||
assert response.status_code == 429
|
||||
assert "Retry-After" in response.headers
|
||||
retry_after = int(response.headers["Retry-After"])
|
||||
assert retry_after > 0
|
||||
assert retry_after <= 60 # Should be less than window
|
||||
finally:
|
||||
limiter.max_requests = original_max
|
||||
Reference in New Issue
Block a user