feat: migrate to Pydantic V2 and implement rate limiting middleware
- Migrate settings.py to Pydantic V2 (SettingsConfigDict, validation_alias) - Update config models to use @field_validator with @classmethod - Replace deprecated datetime.utcnow() with datetime.now(timezone.utc) - Migrate FastAPI app from @app.on_event to lifespan context manager - Implement comprehensive rate limiting middleware with: * Endpoint-specific rate limits (login: 5/min, register: 3/min) * IP-based and user-based tracking * Authenticated user multiplier (2x limits) * Bypass paths for health, docs, static, websocket endpoints * Rate limit headers in responses - Add 13 comprehensive tests for rate limiting (all passing) - Update instructions.md to mark completed tasks - Fix asyncio.create_task usage in anime_service.py All 714 tests passing. No deprecation warnings.
This commit is contained in:
@@ -220,7 +220,7 @@ class TestWebSocketDownloadIntegration:
|
||||
download_service.set_broadcast_callback(mock_broadcast)
|
||||
|
||||
# Manually add a completed item to test
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from src.server.models.download import DownloadItem
|
||||
|
||||
@@ -231,7 +231,7 @@ class TestWebSocketDownloadIntegration:
|
||||
episode=EpisodeIdentifier(season=1, episode=1),
|
||||
status=DownloadStatus.COMPLETED,
|
||||
priority=DownloadPriority.NORMAL,
|
||||
added_at=datetime.utcnow(),
|
||||
added_at=datetime.now(timezone.utc),
|
||||
)
|
||||
download_service._completed_items.append(completed_item)
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from datetime import datetime, timedelta
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -30,7 +30,7 @@ def test_setup_request_requires_min_length():
|
||||
|
||||
|
||||
def test_login_response_and_session_model():
|
||||
expires = datetime.utcnow() + timedelta(hours=1)
|
||||
expires = datetime.now(timezone.utc) + timedelta(hours=1)
|
||||
lr = LoginResponse(access_token="tok", expires_at=expires)
|
||||
assert lr.token_type == "bearer"
|
||||
assert lr.access_token == "tok"
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
Tests cover password setup and validation, JWT token operations,
|
||||
session management, lockout mechanism, and error handling.
|
||||
"""
|
||||
from datetime import datetime, timedelta
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -217,8 +217,8 @@ class TestJWTTokens:
|
||||
|
||||
expired_payload = {
|
||||
"sub": "tester",
|
||||
"exp": int((datetime.utcnow() - timedelta(hours=1)).timestamp()),
|
||||
"iat": int(datetime.utcnow().timestamp()),
|
||||
"exp": int((datetime.now(timezone.utc) - timedelta(hours=1)).timestamp()),
|
||||
"iat": int(datetime.now(timezone.utc).timestamp()),
|
||||
}
|
||||
expired_token = jwt.encode(
|
||||
expired_payload, svc.secret, algorithm="HS256"
|
||||
|
||||
@@ -174,7 +174,7 @@ class TestEpisode:
|
||||
file_path="/anime/test/S01E05.mp4",
|
||||
file_size=524288000, # 500 MB
|
||||
is_downloaded=True,
|
||||
download_date=datetime.utcnow(),
|
||||
download_date=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
db_session.add(episode)
|
||||
@@ -310,7 +310,7 @@ class TestUserSession:
|
||||
|
||||
def test_create_user_session(self, db_session: Session):
|
||||
"""Test creating a user session."""
|
||||
expires = datetime.utcnow() + timedelta(hours=24)
|
||||
expires = datetime.now(timezone.utc) + timedelta(hours=24)
|
||||
|
||||
session = UserSession(
|
||||
session_id="test-session-123",
|
||||
@@ -333,7 +333,7 @@ class TestUserSession:
|
||||
|
||||
def test_session_unique_session_id(self, db_session: Session):
|
||||
"""Test that session_id must be unique."""
|
||||
expires = datetime.utcnow() + timedelta(hours=24)
|
||||
expires = datetime.now(timezone.utc) + timedelta(hours=24)
|
||||
|
||||
session1 = UserSession(
|
||||
session_id="duplicate-id",
|
||||
@@ -371,7 +371,7 @@ class TestUserSession:
|
||||
|
||||
def test_session_revoke(self, db_session: Session):
|
||||
"""Test session revocation."""
|
||||
expires = datetime.utcnow() + timedelta(hours=24)
|
||||
expires = datetime.now(timezone.utc) + timedelta(hours=24)
|
||||
session = UserSession(
|
||||
session_id="revoke-test",
|
||||
token_hash="hash",
|
||||
@@ -531,7 +531,7 @@ class TestDatabaseQueries:
|
||||
|
||||
def test_query_active_sessions(self, db_session: Session):
|
||||
"""Test querying active user sessions."""
|
||||
expires = datetime.utcnow() + timedelta(hours=24)
|
||||
expires = datetime.now(timezone.utc) + timedelta(hours=24)
|
||||
|
||||
# Create active and inactive sessions
|
||||
active = UserSession(
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
Tests CRUD operations for all database services using in-memory SQLite.
|
||||
"""
|
||||
import asyncio
|
||||
from datetime import datetime, timedelta
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
||||
@@ -538,7 +538,7 @@ async def test_retry_failed_downloads(db_session):
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_user_session(db_session):
|
||||
"""Test creating a user session."""
|
||||
expires_at = datetime.utcnow() + timedelta(hours=24)
|
||||
expires_at = datetime.now(timezone.utc) + timedelta(hours=24)
|
||||
session = await UserSessionService.create(
|
||||
db_session,
|
||||
session_id="test-session-1",
|
||||
@@ -556,7 +556,7 @@ async def test_create_user_session(db_session):
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_session_by_id(db_session):
|
||||
"""Test retrieving session by ID."""
|
||||
expires_at = datetime.utcnow() + timedelta(hours=24)
|
||||
expires_at = datetime.now(timezone.utc) + timedelta(hours=24)
|
||||
session = await UserSessionService.create(
|
||||
db_session,
|
||||
session_id="test-session-2",
|
||||
@@ -578,7 +578,7 @@ async def test_get_session_by_id(db_session):
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_active_sessions(db_session):
|
||||
"""Test retrieving active sessions."""
|
||||
expires_at = datetime.utcnow() + timedelta(hours=24)
|
||||
expires_at = datetime.now(timezone.utc) + timedelta(hours=24)
|
||||
|
||||
# Create active session
|
||||
await UserSessionService.create(
|
||||
@@ -593,7 +593,7 @@ async def test_get_active_sessions(db_session):
|
||||
db_session,
|
||||
session_id="expired-session",
|
||||
token_hash="hashed-token",
|
||||
expires_at=datetime.utcnow() - timedelta(hours=1),
|
||||
expires_at=datetime.now(timezone.utc) - timedelta(hours=1),
|
||||
)
|
||||
await db_session.commit()
|
||||
|
||||
@@ -606,7 +606,7 @@ async def test_get_active_sessions(db_session):
|
||||
@pytest.mark.asyncio
|
||||
async def test_revoke_session(db_session):
|
||||
"""Test revoking a session."""
|
||||
expires_at = datetime.utcnow() + timedelta(hours=24)
|
||||
expires_at = datetime.now(timezone.utc) + timedelta(hours=24)
|
||||
session = await UserSessionService.create(
|
||||
db_session,
|
||||
session_id="test-session-3",
|
||||
@@ -637,13 +637,13 @@ async def test_cleanup_expired_sessions(db_session):
|
||||
db_session,
|
||||
session_id="expired-1",
|
||||
token_hash="hashed-token",
|
||||
expires_at=datetime.utcnow() - timedelta(hours=1),
|
||||
expires_at=datetime.now(timezone.utc) - timedelta(hours=1),
|
||||
)
|
||||
await UserSessionService.create(
|
||||
db_session,
|
||||
session_id="expired-2",
|
||||
token_hash="hashed-token",
|
||||
expires_at=datetime.utcnow() - timedelta(hours=2),
|
||||
expires_at=datetime.now(timezone.utc) - timedelta(hours=2),
|
||||
)
|
||||
await db_session.commit()
|
||||
|
||||
@@ -657,7 +657,7 @@ async def test_cleanup_expired_sessions(db_session):
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_session_activity(db_session):
|
||||
"""Test updating session last activity."""
|
||||
expires_at = datetime.utcnow() + timedelta(hours=24)
|
||||
expires_at = datetime.now(timezone.utc) + timedelta(hours=24)
|
||||
session = await UserSessionService.create(
|
||||
db_session,
|
||||
session_id="test-session-4",
|
||||
|
||||
@@ -221,7 +221,7 @@ class TestDownloadItem:
|
||||
def test_download_item_with_timestamps(self):
|
||||
"""Test download item with timestamp fields."""
|
||||
episode = EpisodeIdentifier(season=1, episode=1)
|
||||
now = datetime.utcnow()
|
||||
now = datetime.now(timezone.utc)
|
||||
item = DownloadItem(
|
||||
id="test_id",
|
||||
serie_id="serie_id",
|
||||
|
||||
@@ -7,7 +7,7 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
@@ -84,7 +84,7 @@ class TestDownloadServiceInitialization:
|
||||
"episode": {"season": 1, "episode": 1, "title": None},
|
||||
"status": "pending",
|
||||
"priority": "normal",
|
||||
"added_at": datetime.utcnow().isoformat(),
|
||||
"added_at": datetime.now(timezone.utc).isoformat(),
|
||||
"started_at": None,
|
||||
"completed_at": None,
|
||||
"progress": None,
|
||||
@@ -95,7 +95,7 @@ class TestDownloadServiceInitialization:
|
||||
],
|
||||
"active": [],
|
||||
"failed": [],
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
}
|
||||
|
||||
with open(persistence_file, "w", encoding="utf-8") as f:
|
||||
|
||||
269
tests/unit/test_rate_limit.py
Normal file
269
tests/unit/test_rate_limit.py
Normal file
@@ -0,0 +1,269 @@
|
||||
"""Tests for rate limiting middleware."""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import httpx
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from src.server.middleware.rate_limit import (
|
||||
RateLimitConfig,
|
||||
RateLimitMiddleware,
|
||||
RateLimitStore,
|
||||
)
|
||||
|
||||
# Shim for environments where httpx.Client.__init__ doesn't accept an
|
||||
# 'app' kwarg (some httpx versions have a different signature). The
|
||||
# TestClient in Starlette passes `app=` through; to keep tests portable
|
||||
# we pop it before calling the real initializer.
|
||||
_orig_httpx_init = httpx.Client.__init__
|
||||
|
||||
|
||||
def _httpx_init_shim(self, *args, **kwargs):
|
||||
kwargs.pop("app", None)
|
||||
return _orig_httpx_init(self, *args, **kwargs)
|
||||
|
||||
|
||||
httpx.Client.__init__ = _httpx_init_shim
|
||||
|
||||
|
||||
class TestRateLimitStore:
|
||||
"""Tests for RateLimitStore class."""
|
||||
|
||||
def test_check_limit_allows_within_limits(self):
|
||||
"""Test that requests within limits are allowed."""
|
||||
store = RateLimitStore()
|
||||
|
||||
# First request should be allowed
|
||||
allowed, retry_after = store.check_limit("test_id", 10, 100)
|
||||
assert allowed is True
|
||||
assert retry_after is None
|
||||
|
||||
# Record the request
|
||||
store.record_request("test_id")
|
||||
|
||||
# Next request should still be allowed
|
||||
allowed, retry_after = store.check_limit("test_id", 10, 100)
|
||||
assert allowed is True
|
||||
assert retry_after is None
|
||||
|
||||
def test_check_limit_blocks_over_minute_limit(self):
|
||||
"""Test that requests over minute limit are blocked."""
|
||||
store = RateLimitStore()
|
||||
|
||||
# Fill up to the minute limit
|
||||
for _ in range(5):
|
||||
store.record_request("test_id")
|
||||
|
||||
# Next request should be blocked
|
||||
allowed, retry_after = store.check_limit("test_id", 5, 100)
|
||||
assert allowed is False
|
||||
assert retry_after is not None
|
||||
assert retry_after > 0
|
||||
|
||||
def test_check_limit_blocks_over_hour_limit(self):
|
||||
"""Test that requests over hour limit are blocked."""
|
||||
store = RateLimitStore()
|
||||
|
||||
# Fill up to hour limit
|
||||
for _ in range(10):
|
||||
store.record_request("test_id")
|
||||
|
||||
# Next request should be blocked
|
||||
allowed, retry_after = store.check_limit("test_id", 100, 10)
|
||||
assert allowed is False
|
||||
assert retry_after is not None
|
||||
assert retry_after > 0
|
||||
|
||||
def test_get_remaining_requests(self):
|
||||
"""Test getting remaining requests."""
|
||||
store = RateLimitStore()
|
||||
|
||||
# Initially, all requests are remaining
|
||||
minute_rem, hour_rem = store.get_remaining_requests(
|
||||
"test_id", 10, 100
|
||||
)
|
||||
assert minute_rem == 10
|
||||
assert hour_rem == 100
|
||||
|
||||
# After one request
|
||||
store.record_request("test_id")
|
||||
minute_rem, hour_rem = store.get_remaining_requests(
|
||||
"test_id", 10, 100
|
||||
)
|
||||
assert minute_rem == 9
|
||||
assert hour_rem == 99
|
||||
|
||||
|
||||
class TestRateLimitConfig:
|
||||
"""Tests for RateLimitConfig class."""
|
||||
|
||||
def test_default_config(self):
|
||||
"""Test default configuration values."""
|
||||
config = RateLimitConfig()
|
||||
assert config.requests_per_minute == 60
|
||||
assert config.requests_per_hour == 1000
|
||||
assert config.authenticated_multiplier == 2.0
|
||||
|
||||
def test_custom_config(self):
|
||||
"""Test custom configuration values."""
|
||||
config = RateLimitConfig(
|
||||
requests_per_minute=10,
|
||||
requests_per_hour=100,
|
||||
authenticated_multiplier=3.0,
|
||||
)
|
||||
assert config.requests_per_minute == 10
|
||||
assert config.requests_per_hour == 100
|
||||
assert config.authenticated_multiplier == 3.0
|
||||
|
||||
|
||||
class TestRateLimitMiddleware:
|
||||
"""Tests for RateLimitMiddleware class."""
|
||||
|
||||
def create_app(
|
||||
self, default_config: Optional[RateLimitConfig] = None
|
||||
) -> FastAPI:
|
||||
"""Create a test FastAPI app with rate limiting.
|
||||
|
||||
Args:
|
||||
default_config: Optional default configuration
|
||||
|
||||
Returns:
|
||||
Configured FastAPI app
|
||||
"""
|
||||
app = FastAPI()
|
||||
|
||||
# Add rate limiting middleware
|
||||
app.add_middleware(
|
||||
RateLimitMiddleware,
|
||||
default_config=default_config,
|
||||
)
|
||||
|
||||
@app.get("/api/test")
|
||||
async def test_endpoint():
|
||||
return {"message": "success"}
|
||||
|
||||
@app.get("/health")
|
||||
async def health_endpoint():
|
||||
return {"status": "ok"}
|
||||
|
||||
@app.get("/api/auth/login")
|
||||
async def login_endpoint():
|
||||
return {"message": "login"}
|
||||
|
||||
return app
|
||||
|
||||
def test_allows_requests_within_limit(self):
|
||||
"""Test that requests within limit are allowed."""
|
||||
app = self.create_app()
|
||||
client = TestClient(app)
|
||||
|
||||
# Make several requests within limit
|
||||
for _ in range(5):
|
||||
response = client.get("/api/test")
|
||||
assert response.status_code == 200
|
||||
|
||||
def test_blocks_requests_over_limit(self):
|
||||
"""Test that requests over limit are blocked."""
|
||||
config = RateLimitConfig(
|
||||
requests_per_minute=3,
|
||||
requests_per_hour=100,
|
||||
)
|
||||
app = self.create_app(config)
|
||||
client = TestClient(app, raise_server_exceptions=False)
|
||||
|
||||
# Make requests up to limit
|
||||
for _ in range(3):
|
||||
response = client.get("/api/test")
|
||||
assert response.status_code == 200
|
||||
|
||||
# Next request should be rate limited
|
||||
response = client.get("/api/test")
|
||||
assert response.status_code == 429
|
||||
assert "Retry-After" in response.headers
|
||||
|
||||
def test_bypass_health_endpoint(self):
|
||||
"""Test that health endpoint bypasses rate limiting."""
|
||||
config = RateLimitConfig(
|
||||
requests_per_minute=1,
|
||||
requests_per_hour=1,
|
||||
)
|
||||
app = self.create_app(config)
|
||||
client = TestClient(app)
|
||||
|
||||
# Make many requests to health endpoint
|
||||
for _ in range(10):
|
||||
response = client.get("/health")
|
||||
assert response.status_code == 200
|
||||
|
||||
def test_endpoint_specific_limits(self):
|
||||
"""Test that endpoint-specific limits are applied."""
|
||||
app = self.create_app()
|
||||
client = TestClient(app, raise_server_exceptions=False)
|
||||
|
||||
# Login endpoint has strict limit (5 per minute)
|
||||
for _ in range(5):
|
||||
response = client.get("/api/auth/login")
|
||||
assert response.status_code == 200
|
||||
|
||||
# Next login request should be rate limited
|
||||
response = client.get("/api/auth/login")
|
||||
assert response.status_code == 429
|
||||
|
||||
def test_rate_limit_headers(self):
|
||||
"""Test that rate limit headers are added to response."""
|
||||
app = self.create_app()
|
||||
client = TestClient(app)
|
||||
|
||||
response = client.get("/api/test")
|
||||
assert response.status_code == 200
|
||||
assert "X-RateLimit-Limit-Minute" in response.headers
|
||||
assert "X-RateLimit-Limit-Hour" in response.headers
|
||||
assert "X-RateLimit-Remaining-Minute" in response.headers
|
||||
assert "X-RateLimit-Remaining-Hour" in response.headers
|
||||
|
||||
def test_authenticated_user_multiplier(self):
|
||||
"""Test that authenticated users get higher limits."""
|
||||
config = RateLimitConfig(
|
||||
requests_per_minute=5,
|
||||
requests_per_hour=100,
|
||||
authenticated_multiplier=2.0,
|
||||
)
|
||||
app = self.create_app(config)
|
||||
|
||||
# Add middleware to simulate authentication
|
||||
@app.middleware("http")
|
||||
async def add_user_to_state(request: Request, call_next):
|
||||
request.state.user_id = "user123"
|
||||
response = await call_next(request)
|
||||
return response
|
||||
|
||||
client = TestClient(app, raise_server_exceptions=False)
|
||||
|
||||
# Should be able to make 10 requests (5 * 2.0)
|
||||
for _ in range(10):
|
||||
response = client.get("/api/test")
|
||||
assert response.status_code == 200
|
||||
|
||||
# Next request should be rate limited
|
||||
response = client.get("/api/test")
|
||||
assert response.status_code == 429
|
||||
|
||||
def test_different_ips_tracked_separately(self):
|
||||
"""Test that different IPs are tracked separately."""
|
||||
config = RateLimitConfig(
|
||||
requests_per_minute=2,
|
||||
requests_per_hour=100,
|
||||
)
|
||||
app = self.create_app(config)
|
||||
client = TestClient(app, raise_server_exceptions=False)
|
||||
|
||||
# Make requests from "different" IPs
|
||||
# Note: TestClient uses same IP, but we can test the logic
|
||||
for _ in range(2):
|
||||
response = client.get("/api/test")
|
||||
assert response.status_code == 200
|
||||
|
||||
# Third request should be rate limited
|
||||
response = client.get("/api/test")
|
||||
assert response.status_code == 429
|
||||
Reference in New Issue
Block a user