Aniworld/tests/unit/test_rate_limit.py
Lukas 17e5a551e1 feat: migrate to Pydantic V2 and implement rate limiting middleware
- Migrate settings.py to Pydantic V2 (SettingsConfigDict, validation_alias)
- Update config models to use @field_validator with @classmethod
- Replace deprecated datetime.utcnow() with datetime.now(timezone.utc)
- Migrate FastAPI app from @app.on_event to lifespan context manager
- Implement comprehensive rate limiting middleware with:
  * Endpoint-specific rate limits (login: 5/min, register: 3/min)
  * IP-based and user-based tracking
  * Authenticated user multiplier (2x limits)
  * Bypass paths for health, docs, static, websocket endpoints
  * Rate limit headers in responses
- Add 13 comprehensive tests for rate limiting (all passing)
- Update instructions.md to mark completed tasks
- Fix asyncio.create_task usage in anime_service.py

All 714 tests passing. No deprecation warnings.
2025-10-23 22:03:15 +02:00

270 lines
8.5 KiB
Python

"""Tests for rate limiting middleware."""
from typing import Optional
import httpx
from fastapi import FastAPI, Request
from fastapi.testclient import TestClient
from src.server.middleware.rate_limit import (
RateLimitConfig,
RateLimitMiddleware,
RateLimitStore,
)
# Compatibility shim. Starlette's TestClient forwards an ``app=`` keyword to
# ``httpx.Client.__init__``, and some httpx versions do not accept that
# keyword. Dropping it before delegating keeps these tests portable.
_orig_httpx_init = httpx.Client.__init__


def _httpx_init_shim(self, *args, **kwargs):
    """Delegate to the saved ``httpx.Client.__init__`` minus any ``app``.

    Args:
        self: The ``httpx.Client`` instance being initialised.
        *args: Positional arguments, forwarded unchanged.
        **kwargs: Keyword arguments; an ``app`` entry, if present, is
            discarded and everything else is forwarded unchanged.

    Returns:
        Whatever the original initializer returns.
    """
    forwarded = {
        key: value for key, value in kwargs.items() if key != "app"
    }
    return _orig_httpx_init(self, *args, **forwarded)


# Installed at import time, so every httpx.Client constructed after this
# module has been loaded goes through the shim.
httpx.Client.__init__ = _httpx_init_shim
class TestRateLimitStore:
    """Checks of the request accounting done by ``RateLimitStore``."""

    def test_check_limit_allows_within_limits(self):
        """An identifier stays allowed before and after one request."""
        tracker = RateLimitStore()

        # No history yet: the check passes and reports no retry delay.
        ok, wait = tracker.check_limit("test_id", 10, 100)
        assert ok is True
        assert wait is None

        # One recorded request is still well under both limits.
        tracker.record_request("test_id")
        ok, wait = tracker.check_limit("test_id", 10, 100)
        assert ok is True
        assert wait is None

    def test_check_limit_blocks_over_minute_limit(self):
        """Reaching the per-minute limit blocks the next check."""
        tracker = RateLimitStore()
        minute_limit = 5

        # Record exactly as many requests as the minute limit permits.
        for _ in range(minute_limit):
            tracker.record_request("test_id")

        # The hour limit (100) is far away, so only the minute limit
        # can be responsible for the rejection.
        ok, wait = tracker.check_limit("test_id", minute_limit, 100)
        assert ok is False
        assert wait is not None
        assert wait > 0

    def test_check_limit_blocks_over_hour_limit(self):
        """Reaching the per-hour limit blocks the next check."""
        tracker = RateLimitStore()
        hour_limit = 10

        # Record exactly as many requests as the hour limit permits.
        for _ in range(hour_limit):
            tracker.record_request("test_id")

        # The minute limit (100) is far away, so only the hour limit
        # can be responsible for the rejection.
        ok, wait = tracker.check_limit("test_id", 100, hour_limit)
        assert ok is False
        assert wait is not None
        assert wait > 0

    def test_get_remaining_requests(self):
        """Remaining counts start at the limits and drop by one per hit."""
        tracker = RateLimitStore()
        minute_limit, hour_limit = 10, 100

        # With nothing recorded the full allowance is reported.
        left_minute, left_hour = tracker.get_remaining_requests(
            "test_id", minute_limit, hour_limit
        )
        assert left_minute == 10
        assert left_hour == 100

        # One recorded request lowers both counters by one.
        tracker.record_request("test_id")
        left_minute, left_hour = tracker.get_remaining_requests(
            "test_id", minute_limit, hour_limit
        )
        assert left_minute == 9
        assert left_hour == 99
class TestRateLimitConfig:
    """Checks of the values held by ``RateLimitConfig``."""

    def test_default_config(self):
        """A config built without arguments exposes the default values."""
        defaults = RateLimitConfig()

        assert defaults.requests_per_minute == 60
        assert defaults.requests_per_hour == 1000
        assert defaults.authenticated_multiplier == 2.0

    def test_custom_config(self):
        """Values passed to the constructor are stored unchanged."""
        overrides = {
            "requests_per_minute": 10,
            "requests_per_hour": 100,
            "authenticated_multiplier": 3.0,
        }
        custom = RateLimitConfig(**overrides)

        assert custom.requests_per_minute == 10
        assert custom.requests_per_hour == 100
        assert custom.authenticated_multiplier == 3.0
class TestRateLimitMiddleware:
    """End-to-end checks of ``RateLimitMiddleware`` through a TestClient."""

    def create_app(
        self, default_config: Optional[RateLimitConfig] = None
    ) -> FastAPI:
        """Build a small FastAPI app wrapped in the rate limiting middleware.

        The app exposes ``/api/test``, ``/health`` and ``/api/auth/login``.

        Args:
            default_config: Optional default configuration handed to the
                middleware.

        Returns:
            The configured FastAPI app.
        """
        app = FastAPI()

        @app.get("/api/test")
        async def test_endpoint():
            return {"message": "success"}

        @app.get("/health")
        async def health_endpoint():
            return {"status": "ok"}

        @app.get("/api/auth/login")
        async def login_endpoint():
            return {"message": "login"}

        # Route registration and middleware registration are independent,
        # so the middleware can be attached once the routes exist.
        app.add_middleware(
            RateLimitMiddleware,
            default_config=default_config,
        )
        return app

    @staticmethod
    def _get_ok(client, path, times):
        """GET ``path`` ``times`` times, asserting a 200 after each call."""
        for _ in range(times):
            reply = client.get(path)
            assert reply.status_code == 200

    def test_allows_requests_within_limit(self):
        """A handful of requests under the default limits all succeed."""
        client = TestClient(self.create_app())

        self._get_ok(client, "/api/test", 5)

    def test_blocks_requests_over_limit(self):
        """The request after the minute limit gets 429 and Retry-After."""
        limits = RateLimitConfig(
            requests_per_minute=3,
            requests_per_hour=100,
        )
        client = TestClient(
            self.create_app(limits), raise_server_exceptions=False
        )

        # Use up the three requests the minute limit allows.
        self._get_ok(client, "/api/test", 3)

        # The fourth request must be rejected.
        rejected = client.get("/api/test")
        assert rejected.status_code == 429
        assert "Retry-After" in rejected.headers

    def test_bypass_health_endpoint(self):
        """``/health`` keeps answering 200 even under a 1-request limit."""
        limits = RateLimitConfig(
            requests_per_minute=1,
            requests_per_hour=1,
        )
        client = TestClient(self.create_app(limits))

        # Ten requests would exceed both limits if they were counted.
        self._get_ok(client, "/health", 10)

    def test_endpoint_specific_limits(self):
        """The login endpoint is limited more strictly than the default."""
        client = TestClient(
            self.create_app(), raise_server_exceptions=False
        )

        # Login has a strict limit of 5 requests per minute.
        self._get_ok(client, "/api/auth/login", 5)

        # The sixth login request must be rejected.
        rejected = client.get("/api/auth/login")
        assert rejected.status_code == 429

    def test_rate_limit_headers(self):
        """A successful response carries the four rate limit headers."""
        client = TestClient(self.create_app())

        reply = client.get("/api/test")
        assert reply.status_code == 200

        expected_headers = (
            "X-RateLimit-Limit-Minute",
            "X-RateLimit-Limit-Hour",
            "X-RateLimit-Remaining-Minute",
            "X-RateLimit-Remaining-Hour",
        )
        for header_name in expected_headers:
            assert header_name in reply.headers

    def test_authenticated_user_multiplier(self):
        """A request carrying ``state.user_id`` gets the multiplied limit."""
        limits = RateLimitConfig(
            requests_per_minute=5,
            requests_per_hour=100,
            authenticated_multiplier=2.0,
        )
        app = self.create_app(limits)

        # Simulate authentication by tagging every request with a user id.
        @app.middleware("http")
        async def add_user_to_state(request: Request, call_next):
            request.state.user_id = "user123"
            return await call_next(request)

        client = TestClient(app, raise_server_exceptions=False)

        # 5 per minute * 2.0 multiplier = 10 permitted requests.
        self._get_ok(client, "/api/test", 10)

        # The eleventh request must be rejected.
        rejected = client.get("/api/test")
        assert rejected.status_code == 429

    def test_different_ips_tracked_separately(self):
        """Check that one client address is limited on its own counter.

        NOTE(review): every request here comes from the same TestClient,
        so only a single address is exercised; the test does not prove
        that a second address would still be allowed.
        """
        limits = RateLimitConfig(
            requests_per_minute=2,
            requests_per_hour=100,
        )
        client = TestClient(
            self.create_app(limits), raise_server_exceptions=False
        )

        # Two requests fit within the minute limit.
        self._get_ok(client, "/api/test", 2)

        # The third request from the same address must be rejected.
        rejected = client.get("/api/test")
        assert rejected.status_code == 429