188 lines
6.7 KiB
Python
188 lines
6.7 KiB
Python
"""Tests for the global rate limiter and rate limit middleware."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
|
|
from httpx import AsyncClient
|
|
|
|
_SETUP_PAYLOAD = {
|
|
"master_password": "Mysecretpass1!",
|
|
"database_path": "bangui.db",
|
|
"fail2ban_socket": "/var/run/fail2ban/fail2ban.sock",
|
|
"timezone": "UTC",
|
|
"session_duration_minutes": 60,
|
|
}
|
|
|
|
|
|
async def _do_setup(client: AsyncClient) -> None:
    """Run the setup wizard so auth endpoints are reachable.

    Posts ``_SETUP_PAYLOAD`` to ``/api/v1/setup`` and asserts a 201 response.

    Args:
        client: HTTP client bound to the application under test.

    Raises:
        AssertionError: if setup does not return 201. The message carries the
            status code and response body so a failing setup is diagnosable
            without re-running the test.
    """
    resp = await client.post("/api/v1/setup", json=_SETUP_PAYLOAD)
    assert resp.status_code == 201, f"setup failed ({resp.status_code}): {resp.text}"
|
|
|
|
|
|
class TestGlobalRateLimiter:
    """Unit tests for ``GlobalRateLimiter``, exercised directly (no HTTP)."""

    async def test_check_allowed_returns_true_initially(self) -> None:
        """First request should always be allowed, with a zero retry-after."""
        from app.utils.rate_limiter import GlobalRateLimiter

        limiter = GlobalRateLimiter(max_requests=5, window_seconds=60)
        is_allowed, retry_after = limiter.check_allowed("192.168.1.1")

        assert is_allowed is True
        assert retry_after == 0.0

    async def test_check_allowed_blocks_after_limit(self) -> None:
        """Requests beyond the limit should be blocked with a positive retry-after."""
        from app.utils.rate_limiter import GlobalRateLimiter

        limiter = GlobalRateLimiter(max_requests=2, window_seconds=60)

        # The first two requests fit within max_requests=2.
        assert limiter.check_allowed("192.168.1.1")[0] is True
        assert limiter.check_allowed("192.168.1.1")[0] is True

        # The third request exceeds the limit.
        is_allowed, retry_after = limiter.check_allowed("192.168.1.1")
        assert is_allowed is False
        assert retry_after > 0

    async def test_check_allowed_per_ip_isolation(self) -> None:
        """Different IPs should have independent limits."""
        from app.utils.rate_limiter import GlobalRateLimiter

        limiter = GlobalRateLimiter(max_requests=2, window_seconds=60)

        # IP1 exhausts its allowance.
        assert limiter.check_allowed("192.168.1.1")[0] is True
        assert limiter.check_allowed("192.168.1.1")[0] is True
        assert limiter.check_allowed("192.168.1.1")[0] is False

        # IP2 is unaffected and gets its own full allowance.
        assert limiter.check_allowed("192.168.1.2")[0] is True
        assert limiter.check_allowed("192.168.1.2")[0] is True
        assert limiter.check_allowed("192.168.1.2")[0] is False

    async def test_retry_after_decreases_over_time(self) -> None:
        """Retry-after for a blocked IP should shrink as time passes."""
        from app.utils.rate_limiter import GlobalRateLimiter

        limiter = GlobalRateLimiter(max_requests=2, window_seconds=10)

        # Exhaust the allowance; the third check must be rejected.
        limiter.check_allowed("192.168.1.1")
        limiter.check_allowed("192.168.1.1")
        allowed_1, retry_after_1 = limiter.check_allowed("192.168.1.1")
        assert allowed_1 is False
        assert retry_after_1 > 0

        # Wait well inside the 10 s window. The IP must still be blocked:
        # an allowed check reports a retry-after of 0.0, which would make the
        # final comparison pass without proving anything.
        await asyncio.sleep(2)
        allowed_2, retry_after_2 = limiter.check_allowed("192.168.1.1")
        assert allowed_2 is False

        assert 0 < retry_after_2 < retry_after_1

    async def test_get_state(self) -> None:
        """get_state should return request counts per IP."""
        from app.utils.rate_limiter import GlobalRateLimiter

        limiter = GlobalRateLimiter(max_requests=5, window_seconds=60)

        limiter.check_allowed("192.168.1.1")
        limiter.check_allowed("192.168.1.1")
        limiter.check_allowed("192.168.1.2")

        state = limiter.get_state()
        assert state["192.168.1.1"] == 2
        assert state["192.168.1.2"] == 1

    async def test_cleanup_expired(self) -> None:
        """Cleanup should remove IPs with no requests inside the window."""
        from app.utils.rate_limiter import GlobalRateLimiter

        limiter = GlobalRateLimiter(max_requests=5, window_seconds=1)

        limiter.check_allowed("192.168.1.1")
        assert "192.168.1.1" in limiter.get_state()

        # Sleep past the 1 s window so the single request expires.
        await asyncio.sleep(1.5)

        limiter.cleanup_expired()
        assert "192.168.1.1" not in limiter.get_state()

    async def test_reset(self) -> None:
        """Reset should clear all tracked requests."""
        from app.utils.rate_limiter import GlobalRateLimiter

        limiter = GlobalRateLimiter(max_requests=5, window_seconds=60)

        limiter.check_allowed("192.168.1.1")
        limiter.check_allowed("192.168.1.2")

        limiter.reset()
        state = limiter.get_state()
        assert len(state) == 0
|
|
|
|
|
|
class TestRateLimitMiddleware:
    """Test the RateLimitMiddleware end to end via HTTP requests."""

    @staticmethod
    def _fresh_global_limiter(client: AsyncClient):
        """Return the app's ``global_rate_limiter`` after clearing its history.

        Reaches the ASGI app through the client's private ``_transport``
        attribute. NOTE(review): this relies on httpx internals; a fixture
        exposing the app directly would be sturdier.
        """
        limiter = client._transport.app.state.global_rate_limiter
        limiter.reset()
        return limiter

    async def test_global_rate_limit_blocks_excess_requests(self, client: AsyncClient) -> None:
        """Global rate limit should block requests exceeding the per-IP limit."""
        await _do_setup(client)

        # Clear history after setup so the setup request is not counted.
        limiter = self._fresh_global_limiter(client)

        # Lower the limit temporarily. Each request is checked by two
        # middleware instances, so it consumes two slots: with a limit of 7,
        # requests 1-3 use six slots and request 4 exceeds the limit.
        # NOTE(review): the double check looks unintended; if it is fixed,
        # these numbers must be adjusted.
        original_max = limiter.max_requests
        limiter.max_requests = 7

        try:
            for i in range(3):
                response = await client.get("/api/v1/health")
                assert response.status_code == 200, f"Request {i + 1} failed"

            response = await client.get("/api/v1/health")
            assert response.status_code == 429
            assert response.json()["code"] == "rate_limit_exceeded"
            assert "Retry-After" in response.headers
        finally:
            # Restore the limit and drop the over-limit history so no
            # exhausted state is left behind on the app's limiter.
            limiter.max_requests = original_max
            limiter.reset()

    async def test_rate_limit_includes_retry_after_header(self, client: AsyncClient) -> None:
        """Rate limit response should include a positive Retry-After header."""
        await _do_setup(client)

        # Clear history after setup so the setup request is not counted.
        limiter = self._fresh_global_limiter(client)

        # Two middleware instances check each request (two slots per
        # request): with a limit of 3, request 1 uses two slots and
        # request 2 exceeds the limit on its second check.
        original_max = limiter.max_requests
        limiter.max_requests = 3

        try:
            response = await client.get("/api/v1/health")
            assert response.status_code == 200

            response = await client.get("/api/v1/health")
            assert response.status_code == 429
            assert "Retry-After" in response.headers

            retry_after = int(response.headers["Retry-After"])
            assert retry_after > 0
            # Assumes the app's limiter window is 60 s - TODO confirm.
            assert retry_after <= 60
        finally:
            # Restore the limit and drop the over-limit history so no
            # exhausted state is left behind on the app's limiter.
            limiter.max_requests = original_max
            limiter.reset()
|