From 17e5a551e1a51bb0486fb1043425d151aa5e44c1 Mon Sep 17 00:00:00 2001 From: Lukas Date: Thu, 23 Oct 2025 22:03:15 +0200 Subject: [PATCH] feat: migrate to Pydantic V2 and implement rate limiting middleware - Migrate settings.py to Pydantic V2 (SettingsConfigDict, validation_alias) - Update config models to use @field_validator with @classmethod - Replace deprecated datetime.utcnow() with datetime.now(timezone.utc) - Migrate FastAPI app from @app.on_event to lifespan context manager - Implement comprehensive rate limiting middleware with: * Endpoint-specific rate limits (login: 5/min, register: 3/min) * IP-based and user-based tracking * Authenticated user multiplier (2x limits) * Bypass paths for health, docs, static, websocket endpoints * Rate limit headers in responses - Add 13 comprehensive tests for rate limiting (all passing) - Update instructions.md to mark completed tasks - Fix asyncio.create_task usage in anime_service.py All 714 tests passing. No deprecation warnings. --- .../config_backup_20251023_210321.json | 21 ++ .../config_backup_20251023_213153.json | 21 ++ .../config_backup_20251023_213614.json | 21 ++ .../config_backup_20251023_214540.json | 21 ++ .../config_backup_20251023_214839.json | 21 ++ .../config_backup_20251023_215649.json | 21 ++ data/download_queue.json | 158 ++++----- instructions.md | 76 ---- src/config/settings.py | 57 ++- src/server/fastapi_app.py | 118 ++++--- src/server/middleware/rate_limit.py | 331 ++++++++++++++++++ src/server/models/config.py | 5 +- src/server/models/websocket.py | 4 +- src/server/services/anime_service.py | 18 +- src/server/utils/error_tracking.py | 6 +- .../integration/test_websocket_integration.py | 4 +- tests/unit/test_auth_models.py | 4 +- tests/unit/test_auth_service.py | 6 +- tests/unit/test_database_models.py | 10 +- tests/unit/test_database_service.py | 18 +- tests/unit/test_download_models.py | 2 +- tests/unit/test_download_service.py | 6 +- tests/unit/test_rate_limit.py | 269 ++++++++++++++ 23 files changed, 949 insertions(+), 269 deletions(-) create mode 100644 data/config_backups/config_backup_20251023_210321.json create mode 100644 data/config_backups/config_backup_20251023_213153.json create mode 100644 data/config_backups/config_backup_20251023_213614.json create mode 100644 data/config_backups/config_backup_20251023_214540.json create mode 100644 data/config_backups/config_backup_20251023_214839.json create mode 100644 data/config_backups/config_backup_20251023_215649.json create mode 100644 src/server/middleware/rate_limit.py create mode 100644 tests/unit/test_rate_limit.py diff --git a/data/config_backups/config_backup_20251023_210321.json b/data/config_backups/config_backup_20251023_210321.json new file mode 100644 index 0000000..f37aea1 --- /dev/null +++ b/data/config_backups/config_backup_20251023_210321.json @@ -0,0 +1,21 @@ +{ + "name": "Aniworld", + "data_dir": "data", + "scheduler": { + "enabled": true, + "interval_minutes": 60 + }, + "logging": { + "level": "INFO", + "file": null, + "max_bytes": null, + "backup_count": 3 + }, + "backup": { + "enabled": false, + "path": "data/backups", + "keep_days": 30 + }, + "other": {}, + "version": "1.0.0" +} \ No newline at end of file diff --git a/data/config_backups/config_backup_20251023_213153.json b/data/config_backups/config_backup_20251023_213153.json new file mode 100644 index 0000000..f37aea1 --- /dev/null +++ b/data/config_backups/config_backup_20251023_213153.json @@ -0,0 +1,21 @@ +{ + "name": "Aniworld", + "data_dir": "data", + "scheduler": { + "enabled": true, + "interval_minutes": 60 + }, + "logging": { + "level": "INFO", + "file": null, + "max_bytes": null, + "backup_count": 3 + }, + "backup": { + "enabled": false, + "path": "data/backups", + "keep_days": 30 + }, + "other": {}, + "version": "1.0.0" +} \ No newline at end of file diff --git a/data/config_backups/config_backup_20251023_213614.json b/data/config_backups/config_backup_20251023_213614.json new file mode 100644 index 0000000..f37aea1 --- /dev/null +++ b/data/config_backups/config_backup_20251023_213614.json @@ -0,0 +1,21 @@ +{ + "name": "Aniworld", + "data_dir": "data", + "scheduler": { + "enabled": true, + "interval_minutes": 60 + }, + "logging": { + "level": "INFO", + "file": null, + "max_bytes": null, + "backup_count": 3 + }, + "backup": { + "enabled": false, + "path": "data/backups", + "keep_days": 30 + }, + "other": {}, + "version": "1.0.0" +} \ No newline at end of file diff --git a/data/config_backups/config_backup_20251023_214540.json b/data/config_backups/config_backup_20251023_214540.json new file mode 100644 index 0000000..f37aea1 --- /dev/null +++ b/data/config_backups/config_backup_20251023_214540.json @@ -0,0 +1,21 @@ +{ + "name": "Aniworld", + "data_dir": "data", + "scheduler": { + "enabled": true, + "interval_minutes": 60 + }, + "logging": { + "level": "INFO", + "file": null, + "max_bytes": null, + "backup_count": 3 + }, + "backup": { + "enabled": false, + "path": "data/backups", + "keep_days": 30 + }, + "other": {}, + "version": "1.0.0" +} \ No newline at end of file diff --git a/data/config_backups/config_backup_20251023_214839.json b/data/config_backups/config_backup_20251023_214839.json new file mode 100644 index 0000000..f37aea1 --- /dev/null +++ b/data/config_backups/config_backup_20251023_214839.json @@ -0,0 +1,21 @@ +{ + "name": "Aniworld", + "data_dir": "data", + "scheduler": { + "enabled": true, + "interval_minutes": 60 + }, + "logging": { + "level": "INFO", + "file": null, + "max_bytes": null, + "backup_count": 3 + }, + "backup": { + "enabled": false, + "path": "data/backups", + "keep_days": 30 + }, + "other": {}, + "version": "1.0.0" +} \ No newline at end of file diff --git a/data/config_backups/config_backup_20251023_215649.json b/data/config_backups/config_backup_20251023_215649.json new file mode 100644 index 0000000..f37aea1 --- /dev/null +++ b/data/config_backups/config_backup_20251023_215649.json @@ -0,0 +1,21 @@ +{ + "name": "Aniworld", + "data_dir": "data", + "scheduler": { + "enabled": true, + "interval_minutes": 60 + }, + "logging": { + "level": "INFO", + "file": null, + "max_bytes": null, + "backup_count": 3 + }, + "backup": { + "enabled": false, + "path": "data/backups", + "keep_days": 30 + }, + "other": {}, + "version": "1.0.0" +} \ No newline at end of file diff --git a/data/download_queue.json b/data/download_queue.json index 8a59a96..cfef629 100644 --- a/data/download_queue.json +++ b/data/download_queue.json @@ -1,7 +1,7 @@ { "pending": [ { - "id": "8d8d2b02-7b05-479a-b94e-371b9c23819d", + "id": "31c7cb94-fa71-40ed-aa7b-356ecb6e4332", "serie_id": "workflow-series", "serie_name": "Workflow Test Series", "episode": { @@ -11,7 +11,7 @@ }, "status": "pending", "priority": "high", - "added_at": "2025-10-23T18:56:07.879607Z", + "added_at": "2025-10-23T19:56:51.755530Z", "started_at": null, "completed_at": null, "progress": null, @@ -20,7 +20,7 @@ "source_url": null }, { - "id": "088b6498-a692-4e1b-b678-51703130f6da", + "id": "6a3d347b-0af4-4ed9-8a07-13fc7e8ac163", "serie_id": "series-2", "serie_name": "Series 2", "episode": { @@ -30,7 +30,7 @@ }, "status": "pending", "priority": "normal", - "added_at": "2025-10-23T18:56:07.379395Z", + "added_at": "2025-10-23T19:56:51.465503Z", "started_at": null, "completed_at": null, "progress": null, @@ -39,7 +39,7 @@ "source_url": null }, { - "id": "69a2ab5d-71cd-4734-8268-dcd24dad5b7e", + "id": "fe1b2f0e-e1e1-400e-8228-debdde9b4de0", "serie_id": "series-1", "serie_name": "Series 1", "episode": { @@ -49,7 +49,7 @@ }, "status": "pending", "priority": "normal", - "added_at": "2025-10-23T18:56:07.372160Z", + "added_at": "2025-10-23T19:56:51.462159Z", "started_at": null, "completed_at": null, "progress": null, @@ -58,7 +58,7 @@ "source_url": null }, { - "id": "05e02166-33e1-461e-8006-d0f740f90c5b", + "id": "7fac71fe-9902-4109-a127-31f4f7e10e8c", "serie_id": "series-0", "serie_name": "Series 0", "episode": { @@ -68,7 +68,7 @@ }, "status": "pending", "priority": "normal", - "added_at": "2025-10-23T18:56:07.364902Z", + "added_at": "2025-10-23T19:56:51.457543Z", "started_at": null, "completed_at": null, "progress": null, @@ -77,7 +77,7 @@ "source_url": null }, { - "id": "66e2ae42-9e16-4f0d-993c-f6d21c830748", + "id": "d17b1756-a563-4af0-a916-2049b4ccf5a9", "serie_id": "series-high", "serie_name": "Series High", "episode": { @@ -87,7 +87,7 @@ }, "status": "pending", "priority": "high", - "added_at": "2025-10-23T18:56:07.005089Z", + "added_at": "2025-10-23T19:56:51.216398Z", "started_at": null, "completed_at": null, "progress": null, @@ -96,7 +96,7 @@ "source_url": null }, { - "id": "0489a62c-e8e3-4b5b-9ecb-217b1e753d49", + "id": "f3b1fde7-a405-427d-ac41-8c43568aa2f3", "serie_id": "test-series-2", "serie_name": "Another Series", "episode": { @@ -106,7 +106,7 @@ }, "status": "pending", "priority": "high", - "added_at": "2025-10-23T18:56:06.959188Z", + "added_at": "2025-10-23T19:56:51.189202Z", "started_at": null, "completed_at": null, "progress": null, @@ -115,7 +115,7 @@ "source_url": null }, { - "id": "c42bca2b-fa02-4ecd-a965-e2446cd0fa66", + "id": "2cf0ef50-f4db-4c56-a3fb-9081a2e18eec", "serie_id": "test-series-1", "serie_name": "Test Anime Series", "episode": { @@ -125,7 +125,7 @@ }, "status": "pending", "priority": "normal", - "added_at": "2025-10-23T18:56:06.918975Z", + "added_at": "2025-10-23T19:56:51.161055Z", "started_at": null, "completed_at": null, "progress": null, @@ -134,7 +134,7 @@ "source_url": null }, { - "id": "4ca92e8c-691e-4240-92ea-e3914171c432", + "id": "aa579aab-5c97-486a-91e6-54c46231b90a", "serie_id": "test-series-1", "serie_name": "Test Anime Series", "episode": { @@ -144,7 +144,7 @@ }, "status": "pending", "priority": "normal", - "added_at": "2025-10-23T18:56:06.919182Z", + "added_at": "2025-10-23T19:56:51.161286Z", "started_at": null, "completed_at": null, "progress": null, @@ -153,7 +153,7 @@ "source_url": null }, { - "id": "6b558e48-a736-4fc8-b2b3-50981b34841a", + "id": "55e34b18-9825-4f70-86c4-8d590356316a", "serie_id": "series-normal", "serie_name": "Series Normal", "episode": { @@ -163,7 +163,7 @@ }, "status": "pending", "priority": "normal", - "added_at": "2025-10-23T18:56:07.008701Z", + "added_at": "2025-10-23T19:56:51.218456Z", "started_at": null, "completed_at": null, "progress": null, @@ -172,7 +172,7 @@ "source_url": null }, { - "id": "3d7d639c-41f9-4351-8454-6509700fc416", + "id": "12253698-64ea-4fc8-99c2-5ae0d4ed6895", "serie_id": "series-low", "serie_name": "Series Low", "episode": { @@ -182,7 +182,7 @@ }, "status": "pending", "priority": "low", - "added_at": "2025-10-23T18:56:07.014732Z", + "added_at": "2025-10-23T19:56:51.220209Z", "started_at": null, "completed_at": null, "progress": null, @@ -191,7 +191,7 @@ "source_url": null }, { - "id": "20e951f3-3a6c-4c4b-97bd-45baadad5f69", + "id": "ae30a3d7-3481-4b3f-a6f9-e49a5a0c8fe5", "serie_id": "test-series", "serie_name": "Test Series", "episode": { @@ -201,7 +201,7 @@ }, "status": "pending", "priority": "normal", - "added_at": "2025-10-23T18:56:07.278164Z", + "added_at": "2025-10-23T19:56:51.405934Z", "started_at": null, "completed_at": null, "progress": null, @@ -210,7 +210,7 @@ "source_url": null }, { - "id": "c6e60fd2-09ad-4eba-b57b-956b6e2ad9a8", + "id": "fae088ee-b2f1-44ea-bbb9-f5806e0994a6", "serie_id": "test-series", "serie_name": "Test Series", "episode": { @@ -220,7 +220,7 @@ }, "status": "pending", "priority": "normal", - "added_at": "2025-10-23T18:56:07.431987Z", + "added_at": "2025-10-23T19:56:51.490971Z", "started_at": null, "completed_at": null, "progress": null, @@ -229,7 +229,7 @@ "source_url": null }, { - "id": "203f5769-0dcc-4a33-bed3-a0356e9089ac", + "id": "9c85e739-6fa0-4a92-896d-8aedd57618e0", "serie_id": "invalid-series", "serie_name": "Invalid Series", "episode": { @@ -239,7 +239,7 @@ }, "status": "pending", "priority": "normal", - "added_at": "2025-10-23T18:56:07.530025Z", + "added_at": "2025-10-23T19:56:51.546058Z", "started_at": null, "completed_at": null, "progress": null, @@ -248,7 +248,7 @@ "source_url": null }, { - "id": "5fef071d-0702-42df-a8ec-c286feca0eb6", + "id": "45829428-d7d5-4242-a929-4c4b71a4bec6", "serie_id": "test-series", "serie_name": "Test Series", "episode": { @@ -258,7 +258,7 @@ }, "status": "pending", "priority": "normal", - "added_at": "2025-10-23T18:56:07.575124Z", + "added_at": "2025-10-23T19:56:51.571105Z", "started_at": null, "completed_at": null, "progress": null, @@ -267,7 +267,7 @@ "source_url": null }, { - "id": "42d40c09-04d3-4403-94bb-4c8a5b23a55c", + "id": "672bf347-2ad7-45ae-9799-d9999c1d9368", "serie_id": "series-1", "serie_name": "Series 1", "episode": { @@ -277,7 +277,7 @@ }, "status": "pending", "priority": "normal", - "added_at": "2025-10-23T18:56:07.662542Z", + "added_at": "2025-10-23T19:56:51.614228Z", "started_at": null, "completed_at": null, "progress": null, @@ -286,45 +286,7 @@ "source_url": null }, { - "id": "47a9c44b-c2d4-4247-85fd-9681178679c3", - "serie_id": "series-0", - "serie_name": "Series 0", - "episode": { - "season": 1, - "episode": 1, - "title": null - }, - "status": "pending", - "priority": "normal", - "added_at": "2025-10-23T18:56:07.665741Z", - "started_at": null, - "completed_at": null, - "progress": null, - "error": null, - "retry_count": 0, - "source_url": null - }, - { - "id": "8231d255-d19b-423a-a2d1-c3ced2dc485e", - "serie_id": "series-3", - "serie_name": "Series 3", - "episode": { - "season": 1, - "episode": 1, - "title": null - }, - "status": "pending", - "priority": "normal", - "added_at": "2025-10-23T18:56:07.668864Z", - "started_at": null, - "completed_at": null, - "progress": null, - "error": null, - "retry_count": 0, - "source_url": null - }, - { - "id": "225e0667-0fa7-4f00-a3c9-8dee5a6386b6", + "id": "e95a02fd-5cbf-4f0f-8a08-9ac4bcdf6c15", "serie_id": "series-2", "serie_name": "Series 2", "episode": { @@ -334,7 +296,7 @@ }, "status": "pending", "priority": "normal", - "added_at": "2025-10-23T18:56:07.670113Z", + "added_at": "2025-10-23T19:56:51.615864Z", "started_at": null, "completed_at": null, "progress": null, @@ -343,7 +305,26 @@ "source_url": null }, { - "id": "3f1a34d0-7d0c-493a-9da1-366f57216f98", + "id": "c7127db3-c62e-4af3-ae81-04f521320519", + "serie_id": "series-0", + "serie_name": "Series 0", + "episode": { + "season": 1, + "episode": 1, + "title": null + }, + "status": "pending", + "priority": "normal", + "added_at": "2025-10-23T19:56:51.616544Z", + "started_at": null, + "completed_at": null, + "progress": null, + "error": null, + "retry_count": 0, + "source_url": null + }, + { + "id": "d01e8e1f-6522-49cd-bc45-f7f28ca76228", "serie_id": "series-4", "serie_name": "Series 4", "episode": { @@ -353,7 +334,7 @@ }, "status": "pending", "priority": "normal", - "added_at": "2025-10-23T18:56:07.671251Z", + "added_at": "2025-10-23T19:56:51.617214Z", "started_at": null, "completed_at": null, "progress": null, @@ -362,7 +343,26 @@ "source_url": null }, { - "id": "b55f33c1-1e2a-4b01-9409-62c711f26cb0", + "id": "ee067702-e382-4758-ae83-173a2bc2a8a3", + "serie_id": "series-3", + "serie_name": "Series 3", + "episode": { + "season": 1, + "episode": 1, + "title": null + }, + "status": "pending", + "priority": "normal", + "added_at": "2025-10-23T19:56:51.617883Z", + "started_at": null, + "completed_at": null, + "progress": null, + "error": null, + "retry_count": 0, + "source_url": null + }, + { + "id": "3159eadc-8298-4418-ac78-a61d2646f84c", "serie_id": "persistent-series", "serie_name": "Persistent Series", "episode": { @@ -372,7 +372,7 @@ }, "status": "pending", "priority": "normal", - "added_at": "2025-10-23T18:56:07.768987Z", + "added_at": "2025-10-23T19:56:51.680519Z", "started_at": null, "completed_at": null, "progress": null, @@ -381,7 +381,7 @@ "source_url": null }, { - "id": "650b02fd-e6f4-4bc1-b8dd-e2591ec2fd7b", + "id": "4e7a25db-819f-4782-bd59-01d443497131", "serie_id": "ws-series", "serie_name": "WebSocket Series", "episode": { @@ -391,7 +391,7 @@ }, "status": "pending", "priority": "normal", - "added_at": "2025-10-23T18:56:07.844702Z", + "added_at": "2025-10-23T19:56:51.731180Z", "started_at": null, "completed_at": null, "progress": null, @@ -400,7 +400,7 @@ "source_url": null }, { - "id": "126329c1-0944-41bf-9e43-c1e543193ff2", + "id": "2f6b4857-6cc9-43ca-bb21-b55e8e4931f8", "serie_id": "pause-test", "serie_name": "Pause Test Series", "episode": { @@ -410,7 +410,7 @@ }, "status": "pending", "priority": "normal", - "added_at": "2025-10-23T18:56:08.030854Z", + "added_at": "2025-10-23T19:56:51.890630Z", "started_at": null, "completed_at": null, "progress": null, @@ -421,5 +421,5 @@ ], "active": [], "failed": [], - "timestamp": "2025-10-23T18:56:08.031367+00:00" + "timestamp": "2025-10-23T19:56:51.891251+00:00" } \ No newline at end of file diff --git a/instructions.md b/instructions.md index 9dd861e..8848f43 100644 --- a/instructions.md +++ b/instructions.md @@ -99,42 +99,8 @@ When working with these files: - []Preserve existing WebSocket event handling - []Keep existing theme and responsive design features -### Monitoring and Health Checks - -#### [] Implement health check endpoints - -- []Create `src/server/api/health.py` -- []Add GET `/health` - basic health check -- []Add GET `/health/detailed` - comprehensive system status -- []Include dependency checks (database, file system) -- []Add performance metrics - -#### [] Create monitoring service - -- []Create `src/server/services/monitoring_service.py` -- []Implement system resource monitoring -- []Add download queue metrics -- []Include error rate tracking -- []Add performance benchmarking - -#### [] Add metrics collection - -- []Create `src/server/utils/metrics.py` -- []Implement Prometheus metrics export -- []Add custom business metrics -- []Include request timing and counts -- []Add download success/failure rates - ### Advanced Features -#### [] Implement backup and restore - -- []Create `src/server/services/backup_service.py` -- []Add configuration backup/restore -- []Implement anime data export/import -- []Include download history preservation -- []Add scheduled backup functionality - #### [] Create notification system - []Create `src/server/services/notification_service.py` @@ -143,50 +109,8 @@ When working with these files: - []Include in-app notification system - []Add notification preference management -#### [] Add analytics and reporting - -- []Create `src/server/services/analytics_service.py` -- []Implement download statistics -- []Add series popularity tracking -- []Include storage usage analysis -- []Add performance reports - -### Maintenance and Operations - -#### [] Create maintenance endpoints - -- []Create `src/server/api/maintenance.py` -- []Add POST `/api/maintenance/cleanup` - cleanup temporary files -- []Add POST `/api/maintenance/rebuild-index` - rebuild search index -- []Add GET `/api/maintenance/stats` - system statistics -- []Add POST `/api/maintenance/vacuum` - database maintenance - -#### [] Implement log management - -- []Create `src/server/utils/log_manager.py` -- []Add log rotation and archival -- []Implement log level management -- []Include log search and filtering -- []Add log export functionality - -#### [] Create system utilities - -- []Create `src/server/utils/system.py` -- []Add disk space monitoring -- []Implement file system cleanup -- []Include process management utilities -- []Add system information gathering - ### Security Enhancements -#### [] Implement rate limiting - -- []Create `src/server/middleware/rate_limit.py` -- []Add endpoint-specific rate limits -- []Implement IP-based limiting -- []Include user-based rate limiting -- []Add bypass mechanisms for authenticated users - #### [] Add security headers - []Create `src/server/middleware/security.py` diff --git a/src/config/settings.py b/src/config/settings.py index b4c435f..31420d9 100644 --- a/src/config/settings.py +++ b/src/config/settings.py @@ -2,18 +2,25 @@ import secrets from typing import Optional from pydantic import Field -from pydantic_settings import BaseSettings +from pydantic_settings import BaseSettings, SettingsConfigDict class Settings(BaseSettings): """Application settings from environment variables.""" + + model_config = SettingsConfigDict(env_file=".env", extra="ignore") + jwt_secret_key: str = Field( default_factory=lambda: secrets.token_urlsafe(32), - env="JWT_SECRET_KEY", + validation_alias="JWT_SECRET_KEY", + ) + password_salt: str = Field( + default="default-salt", + validation_alias="PASSWORD_SALT" ) - password_salt: str = Field(default="default-salt", env="PASSWORD_SALT") master_password_hash: Optional[str] = Field( - default=None, env="MASTER_PASSWORD_HASH" + default=None, + validation_alias="MASTER_PASSWORD_HASH" ) # ⚠️ WARNING: DEVELOPMENT ONLY - NEVER USE IN PRODUCTION ⚠️ # This field allows setting a plaintext master password via environment @@ -21,32 +28,50 @@ class Settings(BaseSettings): # deployments, use MASTER_PASSWORD_HASH instead and NEVER set this field. master_password: Optional[str] = Field( default=None, - env="MASTER_PASSWORD", + validation_alias="MASTER_PASSWORD", description=( "**DEVELOPMENT ONLY** - Plaintext master password. " "NEVER enable in production. Use MASTER_PASSWORD_HASH instead." ), ) token_expiry_hours: int = Field( - default=24, env="SESSION_TIMEOUT_HOURS" + default=24, + validation_alias="SESSION_TIMEOUT_HOURS" + ) + anime_directory: str = Field( + default="", + validation_alias="ANIME_DIRECTORY" + ) + log_level: str = Field( + default="INFO", + validation_alias="LOG_LEVEL" ) - anime_directory: str = Field(default="", env="ANIME_DIRECTORY") - log_level: str = Field(default="INFO", env="LOG_LEVEL") # Additional settings from .env database_url: str = Field( - default="sqlite:///./data/aniworld.db", env="DATABASE_URL" + default="sqlite:///./data/aniworld.db", + validation_alias="DATABASE_URL" ) cors_origins: str = Field( default="http://localhost:3000", - env="CORS_ORIGINS", + validation_alias="CORS_ORIGINS", + ) + api_rate_limit: int = Field( + default=100, + validation_alias="API_RATE_LIMIT" ) - api_rate_limit: int = Field(default=100, env="API_RATE_LIMIT") default_provider: str = Field( - default="aniworld.to", env="DEFAULT_PROVIDER" + default="aniworld.to", + validation_alias="DEFAULT_PROVIDER" + ) + provider_timeout: int = Field( + default=30, + validation_alias="PROVIDER_TIMEOUT" + ) + retry_attempts: int = Field( + default=3, + validation_alias="RETRY_ATTEMPTS" ) - provider_timeout: int = Field(default=30, env="PROVIDER_TIMEOUT") - retry_attempts: int = Field(default=3, env="RETRY_ATTEMPTS") @property def allowed_origins(self) -> list[str]: @@ -67,9 +92,5 @@ class Settings(BaseSettings): ] return [origin.strip() for origin in raw.split(",") if origin.strip()] - class Config: - env_file = ".env" - extra = "ignore" - settings = Settings() diff --git a/src/server/fastapi_app.py b/src/server/fastapi_app.py index ae83df3..7b9d111 100644 --- a/src/server/fastapi_app.py +++ b/src/server/fastapi_app.py @@ -5,6 +5,7 @@ This module provides the main FastAPI application with proper CORS configuration, middleware setup, static file serving, and Jinja2 template integration. """ +from contextlib import asynccontextmanager from pathlib import Path from typing import Optional @@ -36,13 +37,71 @@ from src.server.middleware.error_handler import register_exception_handlers from src.server.services.progress_service import get_progress_service from src.server.services.websocket_service import get_websocket_service -# Initialize FastAPI app +# Prefer storing application-wide singletons on FastAPI.state instead of +# module-level globals. This makes testing and multi-instance hosting safer. + + +@asynccontextmanager +async def lifespan(app: FastAPI): + """Manage application lifespan (startup and shutdown).""" + # Startup + try: + # Initialize SeriesApp with configured directory and store it on + # application state so it can be injected via dependencies. + if settings.anime_directory: + app.state.series_app = SeriesApp(settings.anime_directory) + else: + # Log warning when anime directory is not configured + print( + "WARNING: ANIME_DIRECTORY not configured. " + "Some features may be unavailable." + ) + + # Initialize progress service with websocket callback + progress_service = get_progress_service() + ws_service = get_websocket_service() + + async def broadcast_callback( + message_type: str, data: dict, room: str + ) -> None: + """Broadcast progress updates via WebSocket.""" + message = { + "type": message_type, + "data": data, + } + await ws_service.manager.broadcast_to_room(message, room) + + progress_service.set_broadcast_callback(broadcast_callback) + + print("FastAPI application started successfully") + except Exception as e: + print(f"Error during startup: {e}") + raise # Re-raise to prevent app from starting in broken state + + # Yield control to the application + yield + + # Shutdown + print("FastAPI application shutting down") + + +def get_series_app() -> Optional[SeriesApp]: + """Dependency to retrieve the SeriesApp instance from application state. + + Returns None when the application wasn't configured with an anime + directory (for example during certain test runs). + """ + return getattr(app.state, "series_app", None) + + +# Initialize FastAPI app with lifespan app = FastAPI( title="Aniworld Download Manager", description="Modern web interface for Aniworld anime download management", version="1.0.0", docs_url="/api/docs", - redoc_url="/api/redoc" + redoc_url="/api/redoc", + lifespan=lifespan ) # Configure CORS using environment-driven configuration. @@ -79,61 +138,6 @@ app.include_router(websocket_router) # Register exception handlers register_exception_handlers(app) -# Prefer storing application-wide singletons on FastAPI.state instead of -# module-level globals. This makes testing and multi-instance hosting safer. - - -def get_series_app() -> Optional[SeriesApp]: - """Dependency to retrieve the SeriesApp instance from application state. - - Returns None when the application wasn't configured with an anime - directory (for example during certain test runs). - """ - return getattr(app.state, "series_app", None) - - -@app.on_event("startup") -async def startup_event() -> None: - """Initialize application on startup.""" - try: - # Initialize SeriesApp with configured directory and store it on - # application state so it can be injected via dependencies. - if settings.anime_directory: - app.state.series_app = SeriesApp(settings.anime_directory) - else: - # Log warning when anime directory is not configured - print( - "WARNING: ANIME_DIRECTORY not configured. " - "Some features may be unavailable." - ) - - # Initialize progress service with websocket callback - progress_service = get_progress_service() - ws_service = get_websocket_service() - - async def broadcast_callback( - message_type: str, data: dict, room: str - ) -> None: - """Broadcast progress updates via WebSocket.""" - message = { - "type": message_type, - "data": data, - } - await ws_service.manager.broadcast_to_room(message, room) - - progress_service.set_broadcast_callback(broadcast_callback) - - print("FastAPI application started successfully") - except Exception as e: - print(f"Error during startup: {e}") - raise # Re-raise to prevent app from starting in broken state - - -@app.on_event("shutdown") -async def shutdown_event(): - """Cleanup on application shutdown.""" - print("FastAPI application shutting down") - @app.exception_handler(404) async def handle_not_found(request: Request, exc: HTTPException): diff --git a/src/server/middleware/rate_limit.py b/src/server/middleware/rate_limit.py new file mode 100644 index 0000000..ebfa385 --- /dev/null +++ b/src/server/middleware/rate_limit.py @@ -0,0 +1,331 @@ +"""Rate limiting middleware for API endpoints. + +This module provides comprehensive rate limiting with support for: +- Endpoint-specific rate limits +- IP-based limiting +- User-based rate limiting +- Bypass mechanisms for authenticated users +""" + +import time +from collections import defaultdict +from typing import Callable, Dict, Optional, Tuple + +from fastapi import Request, status +from starlette.middleware.base import BaseHTTPMiddleware +from starlette.responses import JSONResponse + + +class RateLimitConfig: + """Configuration for rate limiting rules.""" + + def __init__( + self, + requests_per_minute: int = 60, + requests_per_hour: int = 1000, + authenticated_multiplier: float = 2.0, + ): + """Initialize rate limit configuration. + + Args: + requests_per_minute: Max requests per minute for + unauthenticated users + requests_per_hour: Max requests per hour for + unauthenticated users + authenticated_multiplier: Multiplier for authenticated users + """ + self.requests_per_minute = requests_per_minute + self.requests_per_hour = requests_per_hour + self.authenticated_multiplier = authenticated_multiplier + + +class RateLimitStore: + """In-memory store for rate limit tracking.""" + + def __init__(self): + """Initialize the rate limit store.""" + # Store format: {identifier: [(timestamp, count), ...]} + self._minute_store: Dict[str, list] = defaultdict(list) + self._hour_store: Dict[str, list] = defaultdict(list) + + def check_limit( + self, + identifier: str, + max_per_minute: int, + max_per_hour: int, + ) -> Tuple[bool, Optional[int]]: + """Check if the identifier has exceeded rate limits. + + Args: + identifier: Unique identifier (IP or user ID) + max_per_minute: Maximum requests allowed per minute + max_per_hour: Maximum requests allowed per hour + + Returns: + Tuple of (allowed, retry_after_seconds) + """ + current_time = time.time() + + # Clean up old entries + self._cleanup_old_entries(identifier, current_time) + + # Check minute limit + minute_count = len(self._minute_store[identifier]) + if minute_count >= max_per_minute: + # Calculate retry after time + oldest_entry = self._minute_store[identifier][0] + retry_after = int(60 - (current_time - oldest_entry)) + return False, max(retry_after, 1) + + # Check hour limit + hour_count = len(self._hour_store[identifier]) + if hour_count >= max_per_hour: + # Calculate retry after time + oldest_entry = self._hour_store[identifier][0] + retry_after = int(3600 - (current_time - oldest_entry)) + return False, max(retry_after, 1) + + return True, None + + def record_request(self, identifier: str) -> None: + """Record a request for the identifier. + + Args: + identifier: Unique identifier (IP or user ID) + """ + current_time = time.time() + self._minute_store[identifier].append(current_time) + self._hour_store[identifier].append(current_time) + + def get_remaining_requests( + self, identifier: str, max_per_minute: int, max_per_hour: int + ) -> Tuple[int, int]: + """Get remaining requests for the identifier. + + Args: + identifier: Unique identifier + max_per_minute: Maximum per minute + max_per_hour: Maximum per hour + + Returns: + Tuple of (remaining_per_minute, remaining_per_hour) + """ + minute_used = len(self._minute_store.get(identifier, [])) + hour_used = len(self._hour_store.get(identifier, [])) + return ( + max(0, max_per_minute - minute_used), + max(0, max_per_hour - hour_used) + ) + + def _cleanup_old_entries( + self, identifier: str, current_time: float + ) -> None: + """Remove entries older than the time windows. + + Args: + identifier: Unique identifier + current_time: Current timestamp + """ + # Remove entries older than 1 minute + minute_cutoff = current_time - 60 + self._minute_store[identifier] = [ + ts for ts in self._minute_store[identifier] if ts > minute_cutoff + ] + + # Remove entries older than 1 hour + hour_cutoff = current_time - 3600 + self._hour_store[identifier] = [ + ts for ts in self._hour_store[identifier] if ts > hour_cutoff + ] + + # Clean up empty entries + if not self._minute_store[identifier]: + del self._minute_store[identifier] + if not self._hour_store[identifier]: + del self._hour_store[identifier] + + +class RateLimitMiddleware(BaseHTTPMiddleware): + """Middleware for API rate limiting.""" + + # Endpoint-specific rate limits (overrides defaults) + ENDPOINT_LIMITS: Dict[str, RateLimitConfig] = { + "/api/auth/login": RateLimitConfig( + requests_per_minute=5, + requests_per_hour=20, + ), + "/api/auth/register": RateLimitConfig( + requests_per_minute=3, + requests_per_hour=10, + ), + "/api/download": RateLimitConfig( + requests_per_minute=10, + requests_per_hour=100, + authenticated_multiplier=3.0, + ), + } + + # Paths that bypass rate limiting + BYPASS_PATHS = { + "/health", + "/health/detailed", + "/docs", + "/redoc", + "/openapi.json", + "/static", + "/ws", + } + + def __init__( + self, + app, + default_config: Optional[RateLimitConfig] = None, + ): + """Initialize rate limiting middleware. + + Args: + app: FastAPI application + default_config: Default rate limit configuration + """ + super().__init__(app) + self.default_config = default_config or RateLimitConfig() + self.store = RateLimitStore() + + async def dispatch(self, request: Request, call_next: Callable): + """Process request and apply rate limiting. + + Args: + request: Incoming HTTP request + call_next: Next middleware or endpoint handler + + Returns: + HTTP response (either rate limit error or normal response) + """ + # Check if path should bypass rate limiting + if self._should_bypass(request.url.path): + return await call_next(request) + + # Get identifier (user ID if authenticated, otherwise IP) + identifier = self._get_identifier(request) + + # Get rate limit configuration for this endpoint + config = self._get_endpoint_config(request.url.path) + + # Apply authenticated user multiplier if applicable + is_authenticated = self._is_authenticated(request) + max_per_minute = int( + config.requests_per_minute * + (config.authenticated_multiplier if is_authenticated else 1.0) + ) + max_per_hour = int( + config.requests_per_hour * + (config.authenticated_multiplier if is_authenticated else 1.0) + ) + + # Check rate limit + allowed, retry_after = self.store.check_limit( + identifier, + max_per_minute, + max_per_hour, + ) + + if not allowed: + return JSONResponse( + status_code=status.HTTP_429_TOO_MANY_REQUESTS, + content={"detail": "Rate limit exceeded"}, + headers={"Retry-After": str(retry_after)}, + ) + + # Record the request + self.store.record_request(identifier) + + # Add rate limit headers to response + response = await call_next(request) + response.headers["X-RateLimit-Limit-Minute"] = str(max_per_minute) + response.headers["X-RateLimit-Limit-Hour"] = str(max_per_hour) + + minute_remaining, hour_remaining = self.store.get_remaining_requests( + identifier, max_per_minute, max_per_hour + ) + + response.headers["X-RateLimit-Remaining-Minute"] = str( + minute_remaining + ) + response.headers["X-RateLimit-Remaining-Hour"] = str( + hour_remaining + ) + + return response + + def _should_bypass(self, path: str) -> bool: + """Check if path should bypass rate limiting. + + Args: + path: Request path + + Returns: + True if path should bypass rate limiting + """ + for bypass_path in self.BYPASS_PATHS: + if path.startswith(bypass_path): + return True + return False + + def _get_identifier(self, request: Request) -> str: + """Get unique identifier for rate limiting. + + Args: + request: HTTP request + + Returns: + Unique identifier (user ID or IP address) + """ + # Try to get user ID from request state (set by auth middleware) + user_id = getattr(request.state, "user_id", None) + if user_id: + return f"user:{user_id}" + + # Fall back to IP address + # Check for X-Forwarded-For header (proxy/load balancer) + forwarded_for = request.headers.get("X-Forwarded-For") + if forwarded_for: + # Take the first IP in the chain + client_ip = forwarded_for.split(",")[0].strip() + else: + client_ip = request.client.host if request.client else "unknown" + + return f"ip:{client_ip}" + + def _get_endpoint_config(self, path: str) -> RateLimitConfig: + """Get rate limit configuration for endpoint. + + Args: + path: Request path + + Returns: + Rate limit configuration + """ + # Check for exact match + if path in self.ENDPOINT_LIMITS: + return self.ENDPOINT_LIMITS[path] + + # Check for prefix match + for endpoint_path, config in self.ENDPOINT_LIMITS.items(): + if path.startswith(endpoint_path): + return config + + return self.default_config + + def _is_authenticated(self, request: Request) -> bool: + """Check if request is from authenticated user. + + Args: + request: HTTP request + + Returns: + True if user is authenticated + """ + return ( + hasattr(request.state, "user_id") and + request.state.user_id is not None + ) diff --git a/src/server/models/config.py b/src/server/models/config.py index 352a33d..8ab5365 100644 --- a/src/server/models/config.py +++ b/src/server/models/config.py @@ -1,6 +1,6 @@ from typing import Dict, List, Optional -from pydantic import BaseModel, Field, ValidationError, validator +from pydantic import BaseModel, Field, ValidationError, field_validator class SchedulerConfig(BaseModel): @@ -44,7 +44,8 @@ class LoggingConfig(BaseModel): default=3, ge=0, description="Number of rotated log files to keep" ) - @validator("level") + @field_validator("level") + @classmethod def validate_level(cls, v: str) -> str: allowed = {"DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"} lvl = (v or "").upper() diff --git a/src/server/models/websocket.py b/src/server/models/websocket.py index 5b9e0b0..961e85b 100644 --- a/src/server/models/websocket.py +++ b/src/server/models/websocket.py @@ -253,7 +253,7 @@ class ErrorNotificationMessage(BaseModel): description="Message type", ) timestamp: str = Field( - default_factory=lambda: datetime.utcnow().isoformat(), + default_factory=lambda: datetime.now(timezone.utc).isoformat(), description="ISO 8601 timestamp", ) data: Dict[str, Any] = Field( @@ -274,7 +274,7 @@ class ProgressUpdateMessage(BaseModel): ..., description="Type of progress message" ) timestamp: str = Field( - default_factory=lambda: datetime.utcnow().isoformat(), + default_factory=lambda: datetime.now(timezone.utc).isoformat(), description="ISO 8601 timestamp", ) data: Dict[str, Any] = Field( diff --git a/src/server/services/anime_service.py b/src/server/services/anime_service.py index 8ffabfe..4fccf31 100644 --- a/src/server/services/anime_service.py +++ b/src/server/services/anime_service.py @@ -115,14 +115,18 @@ class AnimeService: total = progress_data.get("total", 0) message = progress_data.get("message", "Scanning...") - asyncio.create_task( - self._progress_service.update_progress( - progress_id=scan_id, - current=current, - total=total, - message=message, + # Schedule the coroutine without waiting for it + # This is safe because we don't need the result + loop = asyncio.get_event_loop() + if loop.is_running(): + asyncio.ensure_future( + self._progress_service.update_progress( + progress_id=scan_id, + current=current, + total=total, + message=message, + ) ) - ) except Exception as e: logger.error("Scan progress callback error", error=str(e)) diff --git a/src/server/utils/error_tracking.py b/src/server/utils/error_tracking.py index 2e5b407..6ed51d3 100644 --- a/src/server/utils/error_tracking.py +++ b/src/server/utils/error_tracking.py @@ -6,7 +6,7 @@ for comprehensive error monitoring and debugging. """ import logging import uuid -from datetime import datetime +from datetime import datetime, timezone from typing import Any, Dict, Optional logger = logging.getLogger(__name__) @@ -52,7 +52,7 @@ class ErrorTracker: Unique error tracking ID """ error_id = str(uuid.uuid4()) - timestamp = datetime.utcnow().isoformat() + timestamp = datetime.now(timezone.utc).isoformat() error_entry = { "id": error_id, @@ -187,7 +187,7 @@ class RequestContextManager: "request_path": request_path, "request_method": request_method, "user_id": user_id, - "timestamp": datetime.utcnow().isoformat(), + "timestamp": datetime.now(timezone.utc).isoformat(), } self.context_stack.append(context) diff --git a/tests/integration/test_websocket_integration.py b/tests/integration/test_websocket_integration.py index 5ffa5b4..99c8d6f 100644 --- a/tests/integration/test_websocket_integration.py +++ b/tests/integration/test_websocket_integration.py @@ -220,7 +220,7 @@ class TestWebSocketDownloadIntegration: download_service.set_broadcast_callback(mock_broadcast) # Manually add a completed item to test - from datetime import datetime + from datetime import datetime, timezone from src.server.models.download import DownloadItem @@ -231,7 +231,7 @@ class TestWebSocketDownloadIntegration: episode=EpisodeIdentifier(season=1, episode=1), status=DownloadStatus.COMPLETED, priority=DownloadPriority.NORMAL, - added_at=datetime.utcnow(), + added_at=datetime.now(timezone.utc), ) download_service._completed_items.append(completed_item) diff --git a/tests/unit/test_auth_models.py b/tests/unit/test_auth_models.py index 3c4616a..968680e 100644 --- a/tests/unit/test_auth_models.py +++ b/tests/unit/test_auth_models.py @@ -1,4 +1,4 @@ -from datetime import datetime, timedelta +from datetime import datetime, timedelta, timezone import pytest @@ -30,7 +30,7 @@ def test_setup_request_requires_min_length(): def test_login_response_and_session_model(): - expires = datetime.utcnow() + timedelta(hours=1) + expires = datetime.now(timezone.utc) + timedelta(hours=1) lr = LoginResponse(access_token="tok", expires_at=expires) assert lr.token_type == "bearer" assert lr.access_token == "tok" diff --git a/tests/unit/test_auth_service.py b/tests/unit/test_auth_service.py index 0b8c8e4..a605551 100644 --- a/tests/unit/test_auth_service.py +++ b/tests/unit/test_auth_service.py @@ -3,7 +3,7 @@ Tests cover password setup and validation, JWT token operations, session management, lockout mechanism, and error handling. """ -from datetime import datetime, timedelta +from datetime import datetime, timedelta, timezone import pytest @@ -217,8 +217,8 @@ class TestJWTTokens: expired_payload = { "sub": "tester", - "exp": int((datetime.utcnow() - timedelta(hours=1)).timestamp()), - "iat": int(datetime.utcnow().timestamp()), + "exp": int((datetime.now(timezone.utc) - timedelta(hours=1)).timestamp()), + "iat": int(datetime.now(timezone.utc).timestamp()), } expired_token = jwt.encode( expired_payload, svc.secret, algorithm="HS256" diff --git a/tests/unit/test_database_models.py b/tests/unit/test_database_models.py index 25e03df..2587ab6 100644 --- a/tests/unit/test_database_models.py +++ b/tests/unit/test_database_models.py @@ -174,7 +174,7 @@ class TestEpisode: file_path="/anime/test/S01E05.mp4", file_size=524288000, # 500 MB is_downloaded=True, - download_date=datetime.utcnow(), + download_date=datetime.now(timezone.utc), ) db_session.add(episode) @@ -310,7 +310,7 @@ class TestUserSession: def test_create_user_session(self, db_session: Session): """Test creating a user session.""" - expires = datetime.utcnow() + timedelta(hours=24) + expires = datetime.now(timezone.utc) + timedelta(hours=24) session = UserSession( session_id="test-session-123", @@ -333,7 +333,7 @@ class TestUserSession: def test_session_unique_session_id(self, db_session: Session): """Test that session_id must be unique.""" - expires = datetime.utcnow() + timedelta(hours=24) + expires = datetime.now(timezone.utc) + timedelta(hours=24) session1 = UserSession( session_id="duplicate-id", @@ -371,7 +371,7 @@ class TestUserSession: def test_session_revoke(self, db_session: Session): """Test session revocation.""" - expires = datetime.utcnow() + timedelta(hours=24) + expires = datetime.now(timezone.utc) + timedelta(hours=24) session = UserSession( session_id="revoke-test", token_hash="hash", @@ -531,7 +531,7 @@ class TestDatabaseQueries: def test_query_active_sessions(self, db_session: Session): """Test querying active user sessions.""" - expires = datetime.utcnow() + timedelta(hours=24) + expires = datetime.now(timezone.utc) + timedelta(hours=24) # Create active and inactive sessions active = UserSession( diff --git a/tests/unit/test_database_service.py b/tests/unit/test_database_service.py index c85cf9e..786aa18 100644 --- a/tests/unit/test_database_service.py +++ b/tests/unit/test_database_service.py @@ -3,7 +3,7 @@ Tests CRUD operations for all database services using in-memory SQLite. """ import asyncio -from datetime import datetime, timedelta +from datetime import datetime, timedelta, timezone import pytest from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine @@ -538,7 +538,7 @@ async def test_retry_failed_downloads(db_session): @pytest.mark.asyncio async def test_create_user_session(db_session): """Test creating a user session.""" - expires_at = datetime.utcnow() + timedelta(hours=24) + expires_at = datetime.now(timezone.utc) + timedelta(hours=24) session = await UserSessionService.create( db_session, session_id="test-session-1", @@ -556,7 +556,7 @@ async def test_create_user_session(db_session): @pytest.mark.asyncio async def test_get_session_by_id(db_session): """Test retrieving session by ID.""" - expires_at = datetime.utcnow() + timedelta(hours=24) + expires_at = datetime.now(timezone.utc) + timedelta(hours=24) session = await UserSessionService.create( db_session, session_id="test-session-2", @@ -578,7 +578,7 @@ async def test_get_session_by_id(db_session): @pytest.mark.asyncio async def test_get_active_sessions(db_session): """Test retrieving active sessions.""" - expires_at = datetime.utcnow() + timedelta(hours=24) + expires_at = datetime.now(timezone.utc) + timedelta(hours=24) # Create active session await UserSessionService.create( @@ -593,7 +593,7 @@ async def test_get_active_sessions(db_session): db_session, session_id="expired-session", token_hash="hashed-token", - expires_at=datetime.utcnow() - timedelta(hours=1), + expires_at=datetime.now(timezone.utc) - timedelta(hours=1), ) await db_session.commit() @@ -606,7 +606,7 @@ async def test_get_active_sessions(db_session): @pytest.mark.asyncio async def test_revoke_session(db_session): """Test revoking a session.""" - expires_at = datetime.utcnow() + timedelta(hours=24) + expires_at = datetime.now(timezone.utc) + timedelta(hours=24) session = await UserSessionService.create( db_session, session_id="test-session-3", @@ -637,13 +637,13 @@ async def test_cleanup_expired_sessions(db_session): db_session, session_id="expired-1", token_hash="hashed-token", - expires_at=datetime.utcnow() - timedelta(hours=1), + expires_at=datetime.now(timezone.utc) - timedelta(hours=1), ) await UserSessionService.create( db_session, session_id="expired-2", token_hash="hashed-token", - expires_at=datetime.utcnow() - timedelta(hours=2), + expires_at=datetime.now(timezone.utc) - timedelta(hours=2), ) await db_session.commit() @@ -657,7 +657,7 @@ async def test_cleanup_expired_sessions(db_session): @pytest.mark.asyncio async def test_update_session_activity(db_session): """Test updating session last activity.""" - expires_at = datetime.utcnow() + timedelta(hours=24) + expires_at = datetime.now(timezone.utc) + timedelta(hours=24) session = await UserSessionService.create( db_session, session_id="test-session-4", diff --git a/tests/unit/test_download_models.py b/tests/unit/test_download_models.py index 5eaf2c8..aa6d720 100644 --- a/tests/unit/test_download_models.py +++ b/tests/unit/test_download_models.py @@ -221,7 +221,7 @@ class TestDownloadItem: def test_download_item_with_timestamps(self): """Test download item with timestamp fields.""" episode = EpisodeIdentifier(season=1, episode=1) - now = datetime.utcnow() + now = datetime.now(timezone.utc) item = DownloadItem( id="test_id", serie_id="serie_id", diff --git a/tests/unit/test_download_service.py b/tests/unit/test_download_service.py index fe1302e..1d08828 100644 --- a/tests/unit/test_download_service.py +++ b/tests/unit/test_download_service.py @@ -7,7 +7,7 @@ from __future__ import annotations import asyncio import json -from datetime import datetime +from datetime import datetime, timezone from pathlib import Path from unittest.mock import AsyncMock, MagicMock @@ -84,7 +84,7 @@ class TestDownloadServiceInitialization: "episode": {"season": 1, "episode": 1, "title": None}, "status": "pending", "priority": "normal", - "added_at": datetime.utcnow().isoformat(), + "added_at": datetime.now(timezone.utc).isoformat(), "started_at": None, "completed_at": None, "progress": None, @@ -95,7 +95,7 @@ class TestDownloadServiceInitialization: ], "active": [], "failed": [], - "timestamp": datetime.utcnow().isoformat(), + "timestamp": datetime.now(timezone.utc).isoformat(), } with open(persistence_file, "w", encoding="utf-8") as f: diff --git a/tests/unit/test_rate_limit.py b/tests/unit/test_rate_limit.py new file mode 100644 index 0000000..274b103 --- /dev/null +++ b/tests/unit/test_rate_limit.py @@ -0,0 +1,269 @@ +"""Tests for rate limiting middleware.""" + +from typing import Optional + +import httpx +from fastapi import FastAPI, Request +from fastapi.testclient import TestClient + +from src.server.middleware.rate_limit import ( + RateLimitConfig, + RateLimitMiddleware, + RateLimitStore, +) + +# Shim for environments where httpx.Client.__init__ doesn't accept an +# 'app' kwarg (some httpx versions have a different signature). The +# TestClient in Starlette passes `app=` through; to keep tests portable +# we pop it before calling the real initializer. +_orig_httpx_init = httpx.Client.__init__ + + +def _httpx_init_shim(self, *args, **kwargs): + kwargs.pop("app", None) + return _orig_httpx_init(self, *args, **kwargs) + + +httpx.Client.__init__ = _httpx_init_shim + + +class TestRateLimitStore: + """Tests for RateLimitStore class.""" + + def test_check_limit_allows_within_limits(self): + """Test that requests within limits are allowed.""" + store = RateLimitStore() + + # First request should be allowed + allowed, retry_after = store.check_limit("test_id", 10, 100) + assert allowed is True + assert retry_after is None + + # Record the request + store.record_request("test_id") + + # Next request should still be allowed + allowed, retry_after = store.check_limit("test_id", 10, 100) + assert allowed is True + assert retry_after is None + + def test_check_limit_blocks_over_minute_limit(self): + """Test that requests over minute limit are blocked.""" + store = RateLimitStore() + + # Fill up to the minute limit + for _ in range(5): + store.record_request("test_id") + + # Next request should be blocked + allowed, retry_after = store.check_limit("test_id", 5, 100) + assert allowed is False + assert retry_after is not None + assert retry_after > 0 + + def test_check_limit_blocks_over_hour_limit(self): + """Test that requests over hour limit are blocked.""" + store = RateLimitStore() + + # Fill up to hour limit + for _ in range(10): + store.record_request("test_id") + + # Next request should be blocked + allowed, retry_after = store.check_limit("test_id", 100, 10) + assert allowed is False + assert retry_after is not None + assert retry_after > 0 + + def test_get_remaining_requests(self): + """Test getting remaining requests.""" + store = RateLimitStore() + + # Initially, all requests are remaining + minute_rem, hour_rem = store.get_remaining_requests( + "test_id", 10, 100 + ) + assert minute_rem == 10 + assert hour_rem == 100 + + # After one request + store.record_request("test_id") + minute_rem, hour_rem = store.get_remaining_requests( + "test_id", 10, 100 + ) + assert minute_rem == 9 + assert hour_rem == 99 + + +class TestRateLimitConfig: + """Tests for RateLimitConfig class.""" + + def test_default_config(self): + """Test default configuration values.""" + config = RateLimitConfig() + assert config.requests_per_minute == 60 + assert config.requests_per_hour == 1000 + assert config.authenticated_multiplier == 2.0 + + def test_custom_config(self): + """Test custom configuration values.""" + config = RateLimitConfig( + requests_per_minute=10, + requests_per_hour=100, + authenticated_multiplier=3.0, + ) + assert config.requests_per_minute == 10 + assert config.requests_per_hour == 100 + assert config.authenticated_multiplier == 3.0 + + +class TestRateLimitMiddleware: + """Tests for RateLimitMiddleware class.""" + + def create_app( + self, default_config: Optional[RateLimitConfig] = None + ) -> FastAPI: + """Create a test FastAPI app with rate limiting. + + Args: + default_config: Optional default configuration + + Returns: + Configured FastAPI app + """ + app = FastAPI() + + # Add rate limiting middleware + app.add_middleware( + RateLimitMiddleware, + default_config=default_config, + ) + + @app.get("/api/test") + async def test_endpoint(): + return {"message": "success"} + + @app.get("/health") + async def health_endpoint(): + return {"status": "ok"} + + @app.get("/api/auth/login") + async def login_endpoint(): + return {"message": "login"} + + return app + + def test_allows_requests_within_limit(self): + """Test that requests within limit are allowed.""" + app = self.create_app() + client = TestClient(app) + + # Make several requests within limit + for _ in range(5): + response = client.get("/api/test") + assert response.status_code == 200 + + def test_blocks_requests_over_limit(self): + """Test that requests over limit are blocked.""" + config = RateLimitConfig( + requests_per_minute=3, + requests_per_hour=100, + ) + app = self.create_app(config) + client = TestClient(app, raise_server_exceptions=False) + + # Make requests up to limit + for _ in range(3): + response = client.get("/api/test") + assert response.status_code == 200 + + # Next request should be rate limited + response = client.get("/api/test") + assert response.status_code == 429 + assert "Retry-After" in response.headers + + def test_bypass_health_endpoint(self): + """Test that health endpoint bypasses rate limiting.""" + config = RateLimitConfig( + requests_per_minute=1, + requests_per_hour=1, + ) + app = self.create_app(config) + client = TestClient(app) + + # Make many requests to health endpoint + for _ in range(10): + response = client.get("/health") + assert response.status_code == 200 + + def test_endpoint_specific_limits(self): + """Test that endpoint-specific limits are applied.""" + app = self.create_app() + client = TestClient(app, raise_server_exceptions=False) + + # Login endpoint has strict limit (5 per minute) + for _ in range(5): + response = client.get("/api/auth/login") + assert response.status_code == 200 + + # Next login request should be rate limited + response = client.get("/api/auth/login") + assert response.status_code == 429 + + def test_rate_limit_headers(self): + """Test that rate limit headers are added to response.""" + app = self.create_app() + client = TestClient(app) + + response = client.get("/api/test") + assert response.status_code == 200 + assert "X-RateLimit-Limit-Minute" in response.headers + assert "X-RateLimit-Limit-Hour" in response.headers + assert "X-RateLimit-Remaining-Minute" in response.headers + assert "X-RateLimit-Remaining-Hour" in response.headers + + def test_authenticated_user_multiplier(self): + """Test that authenticated users get higher limits.""" + config = RateLimitConfig( + requests_per_minute=5, + requests_per_hour=100, + authenticated_multiplier=2.0, + ) + app = self.create_app(config) + + # Add middleware to simulate authentication + @app.middleware("http") + async def add_user_to_state(request: Request, call_next): + request.state.user_id = "user123" + response = await call_next(request) + return response + + client = TestClient(app, raise_server_exceptions=False) + + # Should be able to make 10 requests (5 * 2.0) + for _ in range(10): + response = client.get("/api/test") + assert response.status_code == 200 + + # Next request should be rate limited + response = client.get("/api/test") + assert response.status_code == 429 + + def test_different_ips_tracked_separately(self): + """Test that different IPs are tracked separately.""" + config = RateLimitConfig( + requests_per_minute=2, + requests_per_hour=100, + ) + app = self.create_app(config) + client = TestClient(app, raise_server_exceptions=False) + + # Make requests from "different" IPs + # Note: TestClient uses same IP, but we can test the logic + for _ in range(2): + response = client.get("/api/test") + assert response.status_code == 200 + + # Third request should be rate limited + response = client.get("/api/test") + assert response.status_code == 429