- Migrate settings.py to Pydantic V2 (SettingsConfigDict, validation_alias) - Update config models to use @field_validator with @classmethod - Replace deprecated datetime.utcnow() with datetime.now(timezone.utc) - Migrate FastAPI app from @app.on_event to lifespan context manager - Implement comprehensive rate limiting middleware with: * Endpoint-specific rate limits (login: 5/min, register: 3/min) * IP-based and user-based tracking * Authenticated user multiplier (2x limits) * Bypass paths for health, docs, static, websocket endpoints * Rate limit headers in responses - Add 13 comprehensive tests for rate limiting (all passing) - Update instructions.md to mark completed tasks - Fix asyncio.create_task usage in anime_service.py All 714 tests passing. No deprecation warnings.
270 lines
8.5 KiB
Python
270 lines
8.5 KiB
Python
"""Tests for rate limiting middleware."""
|
|
|
|
import functools
import inspect
from typing import Optional

import httpx
from fastapi import FastAPI, Request
from fastapi.testclient import TestClient

from src.server.middleware.rate_limit import (
    RateLimitConfig,
    RateLimitMiddleware,
    RateLimitStore,
)
|
|
|
|
# Compatibility shim for environments where ``httpx.Client.__init__`` does
# not accept an ``app`` keyword (httpx 0.28 removed it). The Starlette
# TestClient in use here passes ``app=`` through to ``httpx.Client`` next to
# its own transport, so on such httpx versions we drop the keyword before
# delegating to the real initializer.
#
# The patch is process-wide, so it is installed only when the installed httpx
# actually rejects ``app``, and only once (a marker attribute on the shim
# prevents re-wrapping if this module is imported again). On httpx versions
# that still accept ``app``, the original initializer is left untouched, so
# other code calling ``httpx.Client(app=...)`` keeps working.
_orig_httpx_init = httpx.Client.__init__


@functools.wraps(_orig_httpx_init)
def _httpx_init_shim(self, *args, **kwargs):
    """Call the original ``httpx.Client.__init__`` without an ``app`` kwarg."""
    kwargs.pop("app", None)
    return _orig_httpx_init(self, *args, **kwargs)


# Marks the shim so a second import does not wrap it again.
_httpx_init_shim._drops_app_kwarg = True

if not getattr(_orig_httpx_init, "_drops_app_kwarg", False) and (
    "app" not in inspect.signature(_orig_httpx_init).parameters
):
    httpx.Client.__init__ = _httpx_init_shim
|
|
|
|
|
|
class TestRateLimitStore:
    """Unit tests that drive ``RateLimitStore`` directly, without HTTP."""

    def test_check_limit_allows_within_limits(self):
        """An identifier with zero or one recorded hit is still allowed."""
        limiter = RateLimitStore()
        key = "test_id"

        # Nothing recorded yet: allowed, and no retry hint is given.
        ok, wait = limiter.check_limit(key, 10, 100)
        assert ok is True
        assert wait is None

        # A single recorded hit keeps us well under both limits.
        limiter.record_request(key)
        ok, wait = limiter.check_limit(key, 10, 100)
        assert ok is True
        assert wait is None

    def test_check_limit_blocks_over_minute_limit(self):
        """Reaching the per-minute cap blocks and yields a positive retry."""
        limiter = RateLimitStore()
        key = "test_id"
        per_minute, per_hour = 5, 100

        # Consume the whole per-minute allowance.
        for _ in range(per_minute):
            limiter.record_request(key)

        ok, wait = limiter.check_limit(key, per_minute, per_hour)
        assert ok is False
        assert wait is not None
        assert wait > 0

    def test_check_limit_blocks_over_hour_limit(self):
        """Reaching the per-hour cap blocks and yields a positive retry."""
        limiter = RateLimitStore()
        key = "test_id"
        per_minute, per_hour = 100, 10

        # Consume the whole per-hour allowance (minute cap is out of reach).
        for _ in range(per_hour):
            limiter.record_request(key)

        ok, wait = limiter.check_limit(key, per_minute, per_hour)
        assert ok is False
        assert wait is not None
        assert wait > 0

    def test_get_remaining_requests(self):
        """Remaining counts start at the limits and drop by one per hit."""
        limiter = RateLimitStore()
        key = "test_id"

        # Fresh identifier: the full allowance remains in both windows.
        by_minute, by_hour = limiter.get_remaining_requests(key, 10, 100)
        assert by_minute == 10
        assert by_hour == 100

        # One recorded hit is subtracted from both windows.
        limiter.record_request(key)
        by_minute, by_hour = limiter.get_remaining_requests(key, 10, 100)
        assert by_minute == 9
        assert by_hour == 99
|
|
|
|
|
|
class TestRateLimitConfig:
    """Tests covering construction of ``RateLimitConfig``."""

    def test_default_config(self):
        """Defaults are 60/minute, 1000/hour and a 2x authenticated bonus."""
        cfg = RateLimitConfig()
        expected = {
            "requests_per_minute": 60,
            "requests_per_hour": 1000,
            "authenticated_multiplier": 2.0,
        }
        for attr, value in expected.items():
            assert getattr(cfg, attr) == value

    def test_custom_config(self):
        """Explicit constructor arguments override every default."""
        overrides = {
            "requests_per_minute": 10,
            "requests_per_hour": 100,
            "authenticated_multiplier": 3.0,
        }
        cfg = RateLimitConfig(**overrides)
        for attr, value in overrides.items():
            assert getattr(cfg, attr) == value
|
|
|
|
|
|
class TestRateLimitMiddleware:
    """Tests for RateLimitMiddleware class."""

    def create_app(
        self, default_config: Optional[RateLimitConfig] = None
    ) -> FastAPI:
        """Create a test FastAPI app with rate limiting.

        Args:
            default_config: Optional default configuration, passed through
                unchanged to ``RateLimitMiddleware``.

        Returns:
            Configured FastAPI app exposing ``/api/test``, ``/health`` and
            ``/api/auth/login``.
        """
        app = FastAPI()

        # Add rate limiting middleware
        app.add_middleware(
            RateLimitMiddleware,
            default_config=default_config,
        )

        @app.get("/api/test")
        async def test_endpoint():
            return {"message": "success"}

        @app.get("/health")
        async def health_endpoint():
            return {"status": "ok"}

        @app.get("/api/auth/login")
        async def login_endpoint():
            return {"message": "login"}

        return app

    def test_allows_requests_within_limit(self):
        """Test that requests within limit are allowed."""
        app = self.create_app()
        client = TestClient(app)

        # Make several requests within limit
        for _ in range(5):
            response = client.get("/api/test")
            assert response.status_code == 200

    def test_blocks_requests_over_limit(self):
        """Test that requests over limit are blocked with a Retry-After."""
        config = RateLimitConfig(
            requests_per_minute=3,
            requests_per_hour=100,
        )
        app = self.create_app(config)
        client = TestClient(app, raise_server_exceptions=False)

        # Make requests up to limit
        for _ in range(3):
            response = client.get("/api/test")
            assert response.status_code == 200

        # Next request should be rate limited
        response = client.get("/api/test")
        assert response.status_code == 429
        assert "Retry-After" in response.headers

    def test_bypass_health_endpoint(self):
        """Test that health endpoint bypasses rate limiting."""
        config = RateLimitConfig(
            requests_per_minute=1,
            requests_per_hour=1,
        )
        app = self.create_app(config)
        client = TestClient(app)

        # Far more requests than the 1/minute limit would allow.
        for _ in range(10):
            response = client.get("/health")
            assert response.status_code == 200

    def test_endpoint_specific_limits(self):
        """Test that endpoint-specific limits are applied."""
        app = self.create_app()
        client = TestClient(app, raise_server_exceptions=False)

        # Login endpoint has strict limit (5 per minute)
        for _ in range(5):
            response = client.get("/api/auth/login")
            assert response.status_code == 200

        # Next login request should be rate limited
        response = client.get("/api/auth/login")
        assert response.status_code == 429

    def test_rate_limit_headers(self):
        """Test that rate limit headers are added to response."""
        app = self.create_app()
        client = TestClient(app)

        response = client.get("/api/test")
        assert response.status_code == 200
        assert "X-RateLimit-Limit-Minute" in response.headers
        assert "X-RateLimit-Limit-Hour" in response.headers
        assert "X-RateLimit-Remaining-Minute" in response.headers
        assert "X-RateLimit-Remaining-Hour" in response.headers

    def test_authenticated_user_multiplier(self):
        """Test that authenticated users get higher limits."""
        config = RateLimitConfig(
            requests_per_minute=5,
            requests_per_hour=100,
            authenticated_multiplier=2.0,
        )
        app = self.create_app(config)

        # Simulate authentication. Registered after the rate limiter, this
        # middleware wraps it, so user_id is set before the limiter runs.
        @app.middleware("http")
        async def add_user_to_state(request: Request, call_next):
            request.state.user_id = "user123"
            response = await call_next(request)
            return response

        client = TestClient(app, raise_server_exceptions=False)

        # Should be able to make 10 requests (5 * 2.0)
        for _ in range(10):
            response = client.get("/api/test")
            assert response.status_code == 200

        # Next request should be rate limited
        response = client.get("/api/test")
        assert response.status_code == 429

    def test_different_ips_tracked_separately(self):
        """Test that different client IPs get independent rate limits."""
        config = RateLimitConfig(
            requests_per_minute=2,
            requests_per_hour=100,
        )
        app = self.create_app(config)

        # TestClient always reports the same client address, so rewrite the
        # ASGI ``client`` entry from a test-only header. Like the auth
        # simulation above, this middleware wraps the rate limiter and runs
        # first, so the limiter sees the rewritten address.
        @app.middleware("http")
        async def set_client_ip(request: Request, call_next):
            client_ip = request.headers.get("X-Test-Client-IP")
            if client_ip:
                request.scope["client"] = (client_ip, 50000)
            return await call_next(request)

        client = TestClient(app, raise_server_exceptions=False)
        first_ip = {"X-Test-Client-IP": "10.0.0.1"}
        second_ip = {"X-Test-Client-IP": "10.0.0.2"}

        # Exhaust the first IP's allowance.
        for _ in range(2):
            response = client.get("/api/test", headers=first_ip)
            assert response.status_code == 200
        response = client.get("/api/test", headers=first_ip)
        assert response.status_code == 429

        # The second IP must be unaffected by the first IP's usage.
        for _ in range(2):
            response = client.get("/api/test", headers=second_ip)
            assert response.status_code == 200
        response = client.get("/api/test", headers=second_ip)
        assert response.status_code == 429
|