"""Tests for the global rate limiter and rate limit middleware.""" from __future__ import annotations import asyncio from httpx import AsyncClient _SETUP_PAYLOAD = { "master_password": "Mysecretpass1!", "database_path": "bangui.db", "fail2ban_socket": "/var/run/fail2ban/fail2ban.sock", "timezone": "UTC", "session_duration_minutes": 60, } async def _do_setup(client: AsyncClient) -> None: """Run the setup wizard so auth endpoints are reachable.""" resp = await client.post("/api/v1/setup", json=_SETUP_PAYLOAD) assert resp.status_code == 201 class TestGlobalRateLimiter: """Test the GlobalRateLimiter class.""" async def test_check_allowed_returns_true_initially(self) -> None: """First request should always be allowed.""" from app.utils.rate_limiter import GlobalRateLimiter limiter = GlobalRateLimiter(max_requests=5, window_seconds=60) is_allowed, retry_after = limiter.check_allowed("192.168.1.1") assert is_allowed is True assert retry_after == 0.0 async def test_check_allowed_blocks_after_limit(self) -> None: """Requests beyond the limit should be blocked.""" from app.utils.rate_limiter import GlobalRateLimiter limiter = GlobalRateLimiter(max_requests=2, window_seconds=60) # First two requests allowed assert limiter.check_allowed("192.168.1.1")[0] is True assert limiter.check_allowed("192.168.1.1")[0] is True # Third request blocked is_allowed, retry_after = limiter.check_allowed("192.168.1.1") assert is_allowed is False assert retry_after > 0 async def test_check_allowed_per_ip_isolation(self) -> None: """Different IPs should have independent limits.""" from app.utils.rate_limiter import GlobalRateLimiter limiter = GlobalRateLimiter(max_requests=2, window_seconds=60) # IP1 hits limit assert limiter.check_allowed("192.168.1.1")[0] is True assert limiter.check_allowed("192.168.1.1")[0] is True assert limiter.check_allowed("192.168.1.1")[0] is False # IP2 should still have allowance assert limiter.check_allowed("192.168.1.2")[0] is True assert limiter.check_allowed("192.168.1.2")[0] is True assert limiter.check_allowed("192.168.1.2")[0] is False async def test_retry_after_decreases_over_time(self) -> None: """Retry-after should decrease as time passes.""" from app.utils.rate_limiter import GlobalRateLimiter limiter = GlobalRateLimiter(max_requests=2, window_seconds=10) # Hit limit limiter.check_allowed("192.168.1.1") limiter.check_allowed("192.168.1.1") _, retry_after_1 = limiter.check_allowed("192.168.1.1") # Wait and check again await asyncio.sleep(2) _, retry_after_2 = limiter.check_allowed("192.168.1.1") assert retry_after_2 < retry_after_1 async def test_get_state(self) -> None: """get_state should return request counts per IP.""" from app.utils.rate_limiter import GlobalRateLimiter limiter = GlobalRateLimiter(max_requests=5, window_seconds=60) limiter.check_allowed("192.168.1.1") limiter.check_allowed("192.168.1.1") limiter.check_allowed("192.168.1.2") state = limiter.get_state() assert state["192.168.1.1"] == 2 assert state["192.168.1.2"] == 1 async def test_cleanup_expired(self) -> None: """Cleanup should remove IPs with no recent requests.""" from app.utils.rate_limiter import GlobalRateLimiter limiter = GlobalRateLimiter(max_requests=5, window_seconds=1) limiter.check_allowed("192.168.1.1") state_before = limiter.get_state() assert "192.168.1.1" in state_before # Wait for window to expire await asyncio.sleep(1.5) limiter.cleanup_expired() state_after = limiter.get_state() assert "192.168.1.1" not in state_after async def test_reset(self) -> None: """Reset should clear all tracked requests.""" from app.utils.rate_limiter import GlobalRateLimiter limiter = GlobalRateLimiter(max_requests=5, window_seconds=60) limiter.check_allowed("192.168.1.1") limiter.check_allowed("192.168.1.2") limiter.reset() state = limiter.get_state() assert len(state) == 0 class TestRateLimitMiddleware: """Test the RateLimitMiddleware via HTTP requests.""" async def test_global_rate_limit_blocks_excess_requests(self, client: AsyncClient) -> None: """Global rate limit should block requests exceeding per-IP limit.""" await _do_setup(client) limiter = client._transport.app.state.global_rate_limiter limiter.reset() original_max = limiter.max_requests limiter.max_requests = 3 try: for i in range(3): response = await client.get("/api/v1/health") assert response.status_code == 200, f"Request {i + 1} failed" response = await client.get("/api/v1/health") assert response.status_code == 429 assert response.json()["code"] == "rate_limit_exceeded" assert "Retry-After" in response.headers finally: limiter.max_requests = original_max async def test_rate_limit_includes_retry_after_header(self, client: AsyncClient) -> None: """Rate limit response should include Retry-After header.""" await _do_setup(client) limiter = client._transport.app.state.global_rate_limiter limiter.reset() original_max = limiter.max_requests limiter.max_requests = 2 try: response = await client.get("/api/v1/health") assert response.status_code == 200 response = await client.get("/api/v1/health") assert response.status_code == 200 response = await client.get("/api/v1/health") assert response.status_code == 429 assert "Retry-After" in response.headers retry_after = int(response.headers["Retry-After"]) assert retry_after > 0 assert retry_after <= 60 finally: limiter.max_requests = original_max async def test_auth_bucket_allows_more_requests(self, client: AsyncClient) -> None: """Auth endpoints use a dedicated high-rate bucket.""" await _do_setup(client) limiter = client._transport.app.state.global_rate_limiter limiter.reset() # The auth bucket is configured for 1000 req/min; we only need to # verify that it is *not* the global bucket (200 req/min). for _ in range(5): response = await client.post("/api/v1/auth/login", json={"password": "x"}) assert response.status_code in (401, 403, 429) async def test_history_bucket_allows_more_requests(self, client: AsyncClient) -> None: """History endpoints use a dedicated high-rate bucket.""" await _do_setup(client) limiter = client._transport.app.state.global_rate_limiter limiter.reset() for _ in range(5): response = await client.get("/api/v1/history/bans") # 401/403 is fine — we just need to confirm we are not 429'd # by the global limiter. assert response.status_code != 429