"""Tests for the global rate limiter and rate limit middleware.""" from __future__ import annotations import asyncio from httpx import AsyncClient _SETUP_PAYLOAD = { "master_password": "Mysecretpass1!", "database_path": "bangui.db", "fail2ban_socket": "/var/run/fail2ban/fail2ban.sock", "timezone": "UTC", "session_duration_minutes": 60, } async def _do_setup(client: AsyncClient) -> None: """Run the setup wizard so auth endpoints are reachable.""" resp = await client.post("/api/setup", json=_SETUP_PAYLOAD) assert resp.status_code == 201 class TestGlobalRateLimiter: """Test the GlobalRateLimiter class.""" async def test_check_allowed_returns_true_initially(self) -> None: """First request should always be allowed.""" from app.utils.rate_limiter import GlobalRateLimiter limiter = GlobalRateLimiter(max_requests=5, window_seconds=60) is_allowed, retry_after = limiter.check_allowed("192.168.1.1") assert is_allowed is True assert retry_after == 0.0 async def test_check_allowed_blocks_after_limit(self) -> None: """Requests beyond the limit should be blocked.""" from app.utils.rate_limiter import GlobalRateLimiter limiter = GlobalRateLimiter(max_requests=2, window_seconds=60) # First two requests allowed assert limiter.check_allowed("192.168.1.1")[0] is True assert limiter.check_allowed("192.168.1.1")[0] is True # Third request blocked is_allowed, retry_after = limiter.check_allowed("192.168.1.1") assert is_allowed is False assert retry_after > 0 async def test_check_allowed_per_ip_isolation(self) -> None: """Different IPs should have independent limits.""" from app.utils.rate_limiter import GlobalRateLimiter limiter = GlobalRateLimiter(max_requests=2, window_seconds=60) # IP1 hits limit assert limiter.check_allowed("192.168.1.1")[0] is True assert limiter.check_allowed("192.168.1.1")[0] is True assert limiter.check_allowed("192.168.1.1")[0] is False # IP2 should still have allowance assert limiter.check_allowed("192.168.1.2")[0] is True assert limiter.check_allowed("192.168.1.2")[0] is True assert limiter.check_allowed("192.168.1.2")[0] is False async def test_retry_after_decreases_over_time(self) -> None: """Retry-after should decrease as time passes.""" from app.utils.rate_limiter import GlobalRateLimiter limiter = GlobalRateLimiter(max_requests=2, window_seconds=10) # Hit limit limiter.check_allowed("192.168.1.1") limiter.check_allowed("192.168.1.1") _, retry_after_1 = limiter.check_allowed("192.168.1.1") # Wait and check again await asyncio.sleep(2) _, retry_after_2 = limiter.check_allowed("192.168.1.1") assert retry_after_2 < retry_after_1 async def test_get_state(self) -> None: """get_state should return request counts per IP.""" from app.utils.rate_limiter import GlobalRateLimiter limiter = GlobalRateLimiter(max_requests=5, window_seconds=60) limiter.check_allowed("192.168.1.1") limiter.check_allowed("192.168.1.1") limiter.check_allowed("192.168.1.2") state = limiter.get_state() assert state["192.168.1.1"] == 2 assert state["192.168.1.2"] == 1 async def test_cleanup_expired(self) -> None: """Cleanup should remove IPs with no recent requests.""" from app.utils.rate_limiter import GlobalRateLimiter limiter = GlobalRateLimiter(max_requests=5, window_seconds=1) limiter.check_allowed("192.168.1.1") state_before = limiter.get_state() assert "192.168.1.1" in state_before # Wait for window to expire await asyncio.sleep(1.5) limiter.cleanup_expired() state_after = limiter.get_state() assert "192.168.1.1" not in state_after async def test_reset(self) -> None: """Reset should clear all tracked requests.""" from app.utils.rate_limiter import GlobalRateLimiter limiter = GlobalRateLimiter(max_requests=5, window_seconds=60) limiter.check_allowed("192.168.1.1") limiter.check_allowed("192.168.1.2") limiter.reset() state = limiter.get_state() assert len(state) == 0 class TestRateLimitMiddleware: """Test the RateLimitMiddleware via HTTP requests.""" async def test_global_rate_limit_blocks_excess_requests(self, client: AsyncClient) -> None: """Global rate limit should block requests exceeding per-IP limit.""" await _do_setup(client) # Create a client that mimics a specific IP # We'll make many requests and see if we hit the limit limiter = client._transport.app.state.global_rate_limiter limiter.reset() # Reduce limit temporarily for testing original_max = limiter.max_requests limiter.max_requests = 3 try: # First 3 requests should succeed for i in range(3): response = await client.get("/api/health") assert response.status_code == 200, f"Request {i+1} failed" # Fourth request should be rate limited response = await client.get("/api/health") assert response.status_code == 429 assert response.json()["code"] == "rate_limit_exceeded" assert "Retry-After" in response.headers finally: limiter.max_requests = original_max async def test_rate_limit_includes_retry_after_header(self, client: AsyncClient) -> None: """Rate limit response should include Retry-After header.""" await _do_setup(client) limiter = client._transport.app.state.global_rate_limiter limiter.reset() original_max = limiter.max_requests limiter.max_requests = 1 try: # First request succeeds response = await client.get("/api/health") assert response.status_code == 200 # Second request is rate limited response = await client.get("/api/health") assert response.status_code == 429 assert "Retry-After" in response.headers retry_after = int(response.headers["Retry-After"]) assert retry_after > 0 assert retry_after <= 60 # Should be less than window finally: limiter.max_requests = original_max