"""Tests for rate limiting middleware."""

import functools
from typing import Optional

import httpx
from fastapi import FastAPI, Request
from fastapi.testclient import TestClient

from src.server.middleware.rate_limit import (
    RateLimitConfig,
    RateLimitMiddleware,
    RateLimitStore,
)

# Shim for environments where httpx.Client.__init__ doesn't accept an
# 'app' kwarg (some httpx versions have a different signature). Starlette's
# TestClient passes `app=` through; to keep tests portable we pop it before
# calling the real initializer.
#
# NOTE(review): discarding `app=` assumes the caller also supplies its own
# `transport=` (Starlette's TestClient does). Confirm this before reusing
# the shim for any other client.
_SHIM_MARKER = "_drops_app_kwarg"


def _install_httpx_app_kwarg_shim() -> None:
    """Patch ``httpx.Client.__init__`` so that an ``app`` kwarg is dropped.

    The patch is applied at most once. If the current initializer already
    carries the shim marker (for example, this module was reloaded), nothing
    changes, so the real initializer is never wrapped twice.
    """
    original_init = httpx.Client.__init__
    if getattr(original_init, _SHIM_MARKER, False):
        return

    # functools.wraps keeps the real name/docstring and sets __wrapped__, so
    # introspection (e.g. inspect.signature) still reports httpx's signature.
    @functools.wraps(original_init)
    def _httpx_init_shim(self, *args, **kwargs):
        kwargs.pop("app", None)
        return original_init(self, *args, **kwargs)

    setattr(_httpx_init_shim, _SHIM_MARKER, True)
    httpx.Client.__init__ = _httpx_init_shim


_install_httpx_app_kwarg_shim()


class TestRateLimitStore:
    """Tests for RateLimitStore class."""

    def test_check_limit_allows_within_limits(self):
        """Test that requests within limits are allowed."""
        store = RateLimitStore()

        # First request should be allowed
        allowed, retry_after = store.check_limit("test_id", 10, 100)
        assert allowed is True
        assert retry_after is None

        # Record the request
        store.record_request("test_id")

        # Next request should still be allowed
        allowed, retry_after = store.check_limit("test_id", 10, 100)
        assert allowed is True
        assert retry_after is None

    def test_check_limit_blocks_over_minute_limit(self):
        """Test that requests over minute limit are blocked."""
        store = RateLimitStore()

        # Fill up to the minute limit
        for _ in range(5):
            store.record_request("test_id")

        # Next request should be blocked
        allowed, retry_after = store.check_limit("test_id", 5, 100)
        assert allowed is False
        assert retry_after is not None
        assert retry_after > 0

    def test_check_limit_blocks_over_hour_limit(self):
        """Test that requests over hour limit are blocked."""
        store = RateLimitStore()

        # Fill up to hour limit
        for _ in range(10):
            store.record_request("test_id")

        # Next request should be blocked
        allowed, retry_after = store.check_limit("test_id", 100, 10)
        assert allowed is False
        assert retry_after is not None
        assert retry_after > 0

    def test_identifiers_tracked_separately(self):
        """Test that one identifier's requests don't count against another."""
        store = RateLimitStore()

        # Exhaust the minute limit for one identifier only
        for _ in range(5):
            store.record_request("client_a")

        allowed, _ = store.check_limit("client_a", 5, 100)
        assert allowed is False

        # A different identifier must still have its full budget
        allowed, retry_after = store.check_limit("client_b", 5, 100)
        assert allowed is True
        assert retry_after is None

        minute_rem, hour_rem = store.get_remaining_requests(
            "client_b", 5, 100
        )
        assert minute_rem == 5
        assert hour_rem == 100

    def test_get_remaining_requests(self):
        """Test getting remaining requests."""
        store = RateLimitStore()

        # Initially, all requests are remaining
        minute_rem, hour_rem = store.get_remaining_requests(
            "test_id", 10, 100
        )
        assert minute_rem == 10
        assert hour_rem == 100

        # After one request
        store.record_request("test_id")
        minute_rem, hour_rem = store.get_remaining_requests(
            "test_id", 10, 100
        )
        assert minute_rem == 9
        assert hour_rem == 99


class TestRateLimitConfig:
    """Tests for RateLimitConfig class."""

    def test_default_config(self):
        """Test default configuration values."""
        config = RateLimitConfig()
        assert config.requests_per_minute == 60
        assert config.requests_per_hour == 1000
        assert config.authenticated_multiplier == 2.0

    def test_custom_config(self):
        """Test custom configuration values."""
        config = RateLimitConfig(
            requests_per_minute=10,
            requests_per_hour=100,
            authenticated_multiplier=3.0,
        )
        assert config.requests_per_minute == 10
        assert config.requests_per_hour == 100
        assert config.authenticated_multiplier == 3.0


class TestRateLimitMiddleware:
    """Tests for RateLimitMiddleware class."""

    def create_app(
        self, default_config: Optional[RateLimitConfig] = None
    ) -> FastAPI:
        """Create a test FastAPI app with rate limiting.

        Each call builds a fresh app (and so a fresh middleware instance),
        which keeps rate-limit state from leaking between tests.

        Args:
            default_config: Optional default configuration

        Returns:
            Configured FastAPI app
        """
        app = FastAPI()

        # Add rate limiting middleware
        app.add_middleware(
            RateLimitMiddleware,
            default_config=default_config,
        )

        @app.get("/api/test")
        async def test_endpoint():
            return {"message": "success"}

        @app.get("/health")
        async def health_endpoint():
            return {"status": "ok"}

        @app.get("/api/auth/login")
        async def login_endpoint():
            return {"message": "login"}

        return app

    def test_allows_requests_within_limit(self):
        """Test that requests within limit are allowed."""
        app = self.create_app()
        client = TestClient(app)

        # Make several requests within limit
        for _ in range(5):
            response = client.get("/api/test")
            assert response.status_code == 200

    def test_blocks_requests_over_limit(self):
        """Test that requests over limit are blocked."""
        config = RateLimitConfig(
            requests_per_minute=3,
            requests_per_hour=100,
        )
        app = self.create_app(config)
        client = TestClient(app, raise_server_exceptions=False)

        # Make requests up to limit
        for _ in range(3):
            response = client.get("/api/test")
            assert response.status_code == 200

        # Next request should be rate limited
        response = client.get("/api/test")
        assert response.status_code == 429
        assert "Retry-After" in response.headers

    def test_bypass_health_endpoint(self):
        """Test that health endpoint bypasses rate limiting."""
        config = RateLimitConfig(
            requests_per_minute=1,
            requests_per_hour=1,
        )
        app = self.create_app(config)
        client = TestClient(app)

        # Make many requests to health endpoint
        for _ in range(10):
            response = client.get("/health")
            assert response.status_code == 200

    def test_endpoint_specific_limits(self):
        """Test that endpoint-specific limits are applied."""
        app = self.create_app()
        client = TestClient(app, raise_server_exceptions=False)

        # Login endpoint has strict limit (5 per minute)
        for _ in range(5):
            response = client.get("/api/auth/login")
            assert response.status_code == 200

        # Next login request should be rate limited
        response = client.get("/api/auth/login")
        assert response.status_code == 429

    def test_rate_limit_headers(self):
        """Test that rate limit headers are added to response."""
        app = self.create_app()
        client = TestClient(app)

        response = client.get("/api/test")
        assert response.status_code == 200
        assert "X-RateLimit-Limit-Minute" in response.headers
        assert "X-RateLimit-Limit-Hour" in response.headers
        assert "X-RateLimit-Remaining-Minute" in response.headers
        assert "X-RateLimit-Remaining-Hour" in response.headers

    def test_authenticated_user_multiplier(self):
        """Test that authenticated users get higher limits."""
        config = RateLimitConfig(
            requests_per_minute=5,
            requests_per_hour=100,
            authenticated_multiplier=2.0,
        )
        app = self.create_app(config)

        # Add middleware to simulate authentication. It is registered after
        # RateLimitMiddleware, so Starlette places it further out in the
        # stack: it runs first, and user_id is already set when the rate
        # limiter sees the request. (Presumably the limiter detects
        # authenticated clients via request.state.user_id; verify against
        # the middleware.)
        @app.middleware("http")
        async def add_user_to_state(request: Request, call_next):
            request.state.user_id = "user123"
            response = await call_next(request)
            return response

        client = TestClient(app, raise_server_exceptions=False)

        # Should be able to make 10 requests (5 * 2.0)
        for _ in range(10):
            response = client.get("/api/test")
            assert response.status_code == 200

        # Next request should be rate limited
        response = client.get("/api/test")
        assert response.status_code == 429

    def test_different_ips_tracked_separately(self):
        """Test that a single client address is limited on its own budget.

        TestClient sends every request from the same client address, so this
        test only covers the single-client path at the middleware level.
        Isolation between different identifiers is checked directly against
        the store in
        ``TestRateLimitStore.test_identifiers_tracked_separately``.
        """
        config = RateLimitConfig(
            requests_per_minute=2,
            requests_per_hour=100,
        )
        app = self.create_app(config)
        client = TestClient(app, raise_server_exceptions=False)

        # Exhaust the budget for TestClient's single client address
        for _ in range(2):
            response = client.get("/api/test")
            assert response.status_code == 200

        # Third request should be rate limited
        response = client.get("/api/test")
        assert response.status_code == 429