Improve docs and security defaults
This commit is contained in:
parent
ebb0769ed4
commit
92795cf9b3
@ -104,16 +104,8 @@ conda run -n AniWorld python -m pytest tests/ -v -s
|
||||
|
||||
**Module-Level Docstrings**
|
||||
|
||||
- [ ] `src/core/entities/SerieList.py` - Check module docstring
|
||||
- [ ] `src/core/providers/streaming/doodstream.py` - Check module docstring
|
||||
- [ ] `src/core/providers/streaming/filemoon.py` - Check module docstring
|
||||
- [ ] `src/core/providers/streaming/hanime.py` - Check module docstring
|
||||
- [ ] `src/server/api/maintenance.py` - Check module docstring
|
||||
|
||||
**Class Docstrings**
|
||||
|
||||
- [ ] `src/cli/Main.py` - `SeriesApp` class lacks comprehensive docstring
|
||||
- [ ] `src/core/providers/enhanced_provider.py` - Class docstring incomplete
|
||||
- [ ] `src/server/utils/dependencies.py` - `CommonQueryParams` class lacks docstring
|
||||
|
||||
**Method/Function Docstrings**
|
||||
@ -123,13 +115,6 @@ conda run -n AniWorld python -m pytest tests/ -v -s
|
||||
|
||||
#### Unclear Variable Names
|
||||
|
||||
- [ ] `src/core/providers/enhanced_provider.py` line 35
|
||||
- `rec` in dictionary - should be `rate_limit_record`
|
||||
- [ ] `src/core/SerieScanner.py` line 138
|
||||
- `missings` - should be `missing_episodes`
|
||||
- [ ] `src/server/utils/dependencies.py` line 34
|
||||
- `security` - should be `http_bearer_security`
|
||||
|
||||
#### Unclear Comments or Missing Context
|
||||
|
||||
- [ ] `src/core/providers/enhanced_provider.py` line 231
|
||||
@ -143,44 +128,6 @@ conda run -n AniWorld python -m pytest tests/ -v -s
|
||||
|
||||
#### Complex Algorithms Without Comments
|
||||
|
||||
**SerieScanner Scanning Logic**
|
||||
|
||||
- [ ] `src/core/SerieScanner.py` lines 105-200
|
||||
- `Scan()` method with nested loops needs explaining comments
|
||||
- Algorithm: finds MP4 files, reads metadata, fetches missing episodes
|
||||
- Comment needed explaining the overall flow
|
||||
- Loop structure needs explanation
|
||||
|
||||
**Download Progress Tracking**
|
||||
|
||||
- [ ] `src/core/SeriesApp.py` lines 200-270
|
||||
- `wrapped_callback()` function has complex state tracking
|
||||
- Explains cancellation checking, callback chaining, progress reporting
|
||||
- Comments should explain why multiple callbacks are needed
|
||||
- Line 245-260 needs comments explaining callback wrapping logic
|
||||
|
||||
**Provider Discovery and Setup**
|
||||
|
||||
- [ ] `src/core/providers/enhanced_provider.py` lines 200-240
|
||||
- Multiple parsing strategies without explanation
|
||||
- Should comment why 3 strategies are tried
|
||||
- Line 215-230 needs comment block explaining each strategy
|
||||
- Should explain what error conditions trigger each fallback
|
||||
|
||||
**Download Queue Priority Handling**
|
||||
|
||||
- [ ] `src/server/services/download_service.py` lines 200-250
|
||||
- Priority-based queue insertion needs comment explaining algorithm
|
||||
- `appendleft()` for HIGH priority not explained
|
||||
- Lines 220-230 should have comment block for priority logic
|
||||
|
||||
**Rate Limiting Algorithm**
|
||||
|
||||
- [ ] `src/server/middleware/auth.py` lines 40-65
|
||||
- Time-window based rate limiting needs explanation
|
||||
- Window reset logic (line 50-53) needs comment
|
||||
- Calculation at line 55 needs explanation
|
||||
|
||||
**Episode Discovery Pattern Matching**
|
||||
|
||||
- [ ] `src/core/providers/streaming/*.py` files
|
||||
@ -191,11 +138,6 @@ conda run -n AniWorld python -m pytest tests/ -v -s
|
||||
|
||||
**JSON/HTML Parsing Logic**
|
||||
|
||||
- [ ] `src/core/providers/enhanced_provider.py` lines 215-235
|
||||
- `_parse_anime_response()` method has 3 parsing strategies
|
||||
- Needs comment explaining fallback order
|
||||
- Each lambda strategy should be explained
|
||||
|
||||
**Session Retry Configuration**
|
||||
|
||||
- [ ] `src/core/providers/enhanced_provider.py` lines 108-125
|
||||
@ -209,18 +151,8 @@ conda run -n AniWorld python -m pytest tests/ -v -s
|
||||
|
||||
#### Code Smells and Shortcuts
|
||||
|
||||
**Bare Pass Statements (Incomplete Implementation)**
|
||||
|
||||
- [ ] `src/server/utils/dependencies.py` lines 223-235
|
||||
- `TODO: Implement rate limiting logic` - bare pass
|
||||
- `TODO: Implement request logging logic` - bare pass
|
||||
- These are incomplete - either implement or remove
|
||||
- [ ] `src/server/utils/system.py` line 255
|
||||
- Bare `pass` in exception handler - should log error
|
||||
- [ ] `src/core/providers/streaming/Provider.py` line 7
|
||||
- Abstract method implementation should not have pass
|
||||
- [ ] `src/core/providers/streaming/hanime.py` line 86
|
||||
- Incomplete exception handling with pass
|
||||
- [ ] `src/core/providers/streaming/Provider.py` line 7 - Abstract method implementation should not have pass
|
||||
**Bare Pass Statements (Incomplete Implementation)**
|
||||
|
||||
**Duplicate Code**
|
||||
|
||||
|
||||
@ -43,6 +43,8 @@ class MatchNotFoundError(Exception):
|
||||
|
||||
|
||||
class SeriesApp:
|
||||
"""Interactive CLI controller orchestrating scanning and downloads."""
|
||||
|
||||
_initialization_count = 0 # Track initialization calls
|
||||
|
||||
def __init__(self, directory_to_search: str) -> None:
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
import secrets
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import Field
|
||||
@ -7,15 +8,20 @@ from pydantic_settings import BaseSettings
|
||||
class Settings(BaseSettings):
|
||||
"""Application settings from environment variables."""
|
||||
jwt_secret_key: str = Field(
|
||||
default="your-secret-key-here", env="JWT_SECRET_KEY"
|
||||
default_factory=lambda: secrets.token_urlsafe(32),
|
||||
env="JWT_SECRET_KEY",
|
||||
)
|
||||
password_salt: str = Field(default="default-salt", env="PASSWORD_SALT")
|
||||
master_password_hash: Optional[str] = Field(
|
||||
default=None, env="MASTER_PASSWORD_HASH"
|
||||
)
|
||||
# For development
|
||||
# For development only. Never rely on this in production deployments.
|
||||
master_password: Optional[str] = Field(
|
||||
default=None, env="MASTER_PASSWORD"
|
||||
default=None,
|
||||
env="MASTER_PASSWORD",
|
||||
description=(
|
||||
"Development-only master password; do not enable in production."
|
||||
),
|
||||
)
|
||||
token_expiry_hours: int = Field(
|
||||
default=24, env="SESSION_TIMEOUT_HOURS"
|
||||
@ -27,7 +33,10 @@ class Settings(BaseSettings):
|
||||
database_url: str = Field(
|
||||
default="sqlite:///./data/aniworld.db", env="DATABASE_URL"
|
||||
)
|
||||
cors_origins: str = Field(default="*", env="CORS_ORIGINS")
|
||||
cors_origins: str = Field(
|
||||
default="http://localhost:3000",
|
||||
env="CORS_ORIGINS",
|
||||
)
|
||||
api_rate_limit: int = Field(default=100, env="API_RATE_LIMIT")
|
||||
default_provider: str = Field(
|
||||
default="aniworld.to", env="DEFAULT_PROVIDER"
|
||||
@ -35,6 +44,25 @@ class Settings(BaseSettings):
|
||||
provider_timeout: int = Field(default=30, env="PROVIDER_TIMEOUT")
|
||||
retry_attempts: int = Field(default=3, env="RETRY_ATTEMPTS")
|
||||
|
||||
@property
|
||||
def allowed_origins(self) -> list[str]:
|
||||
"""Return the list of allowed CORS origins.
|
||||
|
||||
The environment variable should contain a comma-separated list.
|
||||
When ``*`` is provided we fall back to a safe local development
|
||||
default instead of allowing every origin in production.
|
||||
"""
|
||||
|
||||
raw = (self.cors_origins or "").strip()
|
||||
if not raw:
|
||||
return []
|
||||
if raw == "*":
|
||||
return [
|
||||
"http://localhost:3000",
|
||||
"http://localhost:8000",
|
||||
]
|
||||
return [origin.strip() for origin in raw.split(",") if origin.strip()]
|
||||
|
||||
class Config:
|
||||
env_file = ".env"
|
||||
extra = "ignore"
|
||||
|
||||
@ -124,6 +124,9 @@ class SerieScanner:
|
||||
total_to_scan = self.get_total_to_scan()
|
||||
logger.info("Total folders to scan: %d", total_to_scan)
|
||||
|
||||
# The scanner enumerates folders with mp4 files, loads existing
|
||||
# metadata, calculates the missing episodes via the provider, and
|
||||
# persists the refreshed metadata while emitting progress events.
|
||||
result = self.__find_mp4_files()
|
||||
counter = 0
|
||||
|
||||
@ -137,6 +140,9 @@ class SerieScanner:
|
||||
else:
|
||||
percentage = 0.0
|
||||
|
||||
# Progress is surfaced both through the callback manager
|
||||
# (for the web/UI layer) and, for compatibility, through a
|
||||
# legacy callback that updates CLI progress bars.
|
||||
# Notify progress
|
||||
self._callback_manager.notify_progress(
|
||||
ProgressContext(
|
||||
@ -160,12 +166,16 @@ class SerieScanner:
|
||||
serie is not None
|
||||
and not self.is_null_or_whitespace(serie.key)
|
||||
):
|
||||
missings, site = (
|
||||
# Delegate the provider to compare local files with
|
||||
# remote metadata, yielding missing episodes per
|
||||
# season. Results are saved back to disk so that both
|
||||
# CLI and API consumers see consistent state.
|
||||
missing_episodes, site = (
|
||||
self.__get_missing_episodes_and_season(
|
||||
serie.key, mp4_files
|
||||
)
|
||||
)
|
||||
serie.episodeDict = missings
|
||||
serie.episodeDict = missing_episodes
|
||||
serie.folder = folder
|
||||
data_path = os.path.join(
|
||||
self.directory, folder, 'data'
|
||||
|
||||
@ -241,7 +241,9 @@ class SeriesApp:
|
||||
message="Download cancelled before starting"
|
||||
)
|
||||
|
||||
# Wrap callback to check for cancellation and report progress
|
||||
# Wrap callback to enforce cancellation checks and bridge the new
|
||||
# event-driven progress reporting with the legacy callback API that
|
||||
# the CLI still relies on.
|
||||
def wrapped_callback(progress: float):
|
||||
if self._is_cancelled():
|
||||
raise InterruptedError("Download cancelled by user")
|
||||
@ -268,6 +270,9 @@ class SeriesApp:
|
||||
if callback:
|
||||
callback(progress)
|
||||
|
||||
# Propagate progress into the legacy callback chain so existing
|
||||
# UI surfaces continue to receive updates without rewriting the
|
||||
# old interfaces.
|
||||
# Call legacy progress_callback if provided
|
||||
if self.progress_callback:
|
||||
self.progress_callback(ProgressInfo(
|
||||
@ -403,7 +408,9 @@ class SeriesApp:
|
||||
# Reinitialize scanner
|
||||
self.SerieScanner.reinit()
|
||||
|
||||
# Wrap callback for progress reporting and cancellation
|
||||
# Wrap the scanner callback so we can surface progress through the
|
||||
# new ProgressInfo pipeline while maintaining backwards
|
||||
# compatibility with the legacy tuple-based callback signature.
|
||||
def wrapped_callback(folder: str, current: int):
|
||||
if self._is_cancelled():
|
||||
raise InterruptedError("Scan cancelled by user")
|
||||
|
||||
@ -1,56 +1,99 @@
|
||||
import os
|
||||
import json
|
||||
"""Utilities for loading and managing stored anime series metadata."""
|
||||
|
||||
import logging
|
||||
from .series import Serie
|
||||
import os
|
||||
from json import JSONDecodeError
|
||||
from typing import Dict, Iterable, List
|
||||
|
||||
from src.core.entities.series import Serie
|
||||
|
||||
|
||||
class SerieList:
|
||||
def __init__(self, basePath: str):
|
||||
self.directory = basePath
|
||||
self.folderDict: dict[str, Serie] = {} # Proper initialization
|
||||
"""Represents the collection of cached series stored on disk."""
|
||||
|
||||
def __init__(self, base_path: str) -> None:
|
||||
self.directory: str = base_path
|
||||
self.folderDict: Dict[str, Serie] = {}
|
||||
self.load_series()
|
||||
|
||||
def add(self, serie: Serie):
|
||||
if (not self.contains(serie.key)):
|
||||
dataPath = os.path.join(self.directory, serie.folder, "data")
|
||||
animePath = os.path.join(self.directory, serie.folder)
|
||||
os.makedirs(animePath, exist_ok=True)
|
||||
if not os.path.isfile(dataPath):
|
||||
serie.save_to_file(dataPath)
|
||||
self.folderDict[serie.folder] = serie;
|
||||
def add(self, serie: Serie) -> None:
|
||||
"""Persist a new series if it is not already present."""
|
||||
|
||||
if self.contains(serie.key):
|
||||
return
|
||||
|
||||
data_path = os.path.join(self.directory, serie.folder, "data")
|
||||
anime_path = os.path.join(self.directory, serie.folder)
|
||||
os.makedirs(anime_path, exist_ok=True)
|
||||
if not os.path.isfile(data_path):
|
||||
serie.save_to_file(data_path)
|
||||
self.folderDict[serie.folder] = serie
|
||||
|
||||
def contains(self, key: str) -> bool:
|
||||
for k, value in self.folderDict.items():
|
||||
if value.key == key:
|
||||
return True
|
||||
return False
|
||||
"""Return True when a series identified by ``key`` already exists."""
|
||||
|
||||
def load_series(self):
|
||||
""" Scan folders and load data files """
|
||||
logging.info(f"Scanning anime folders in: {self.directory}")
|
||||
for anime_folder in os.listdir(self.directory):
|
||||
return any(value.key == key for value in self.folderDict.values())
|
||||
|
||||
def load_series(self) -> None:
|
||||
"""Populate the in-memory map with metadata discovered on disk."""
|
||||
|
||||
logging.info("Scanning anime folders in %s", self.directory)
|
||||
try:
|
||||
entries: Iterable[str] = os.listdir(self.directory)
|
||||
except OSError as error:
|
||||
logging.error(
|
||||
"Unable to scan directory %s: %s",
|
||||
self.directory,
|
||||
error,
|
||||
)
|
||||
return
|
||||
|
||||
for anime_folder in entries:
|
||||
anime_path = os.path.join(self.directory, anime_folder, "data")
|
||||
if os.path.isfile(anime_path):
|
||||
logging.debug(f"Found data folder: {anime_path}")
|
||||
self.load_data(anime_folder, anime_path)
|
||||
else:
|
||||
logging.warning(f"Skipping {anime_folder} - No data folder found")
|
||||
logging.debug("Found data file for folder %s", anime_folder)
|
||||
self._load_data(anime_folder, anime_path)
|
||||
continue
|
||||
|
||||
logging.warning(
|
||||
"Skipping folder %s because no metadata file was found",
|
||||
anime_folder,
|
||||
)
|
||||
|
||||
def _load_data(self, anime_folder: str, data_path: str) -> None:
|
||||
"""Load a single series metadata file into the in-memory collection."""
|
||||
|
||||
def load_data(self, anime_folder, data_path):
|
||||
""" Load pickle files from the data folder """
|
||||
try:
|
||||
self.folderDict[anime_folder] = Serie.load_from_file(data_path)
|
||||
logging.debug(f"Successfully loaded {data_path} for {anime_folder}")
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to load {data_path} in {anime_folder}: {e}")
|
||||
logging.debug("Successfully loaded metadata for %s", anime_folder)
|
||||
except (OSError, JSONDecodeError, KeyError, ValueError) as error:
|
||||
logging.error(
|
||||
"Failed to load metadata for folder %s from %s: %s",
|
||||
anime_folder,
|
||||
data_path,
|
||||
error,
|
||||
)
|
||||
|
||||
def GetMissingEpisode(self) -> List[Serie]:
|
||||
"""Return all series that still contain missing episodes."""
|
||||
|
||||
return [
|
||||
serie
|
||||
for serie in self.folderDict.values()
|
||||
if serie.episodeDict
|
||||
]
|
||||
|
||||
def get_missing_episodes(self) -> List[Serie]:
|
||||
"""PEP8-friendly alias for :meth:`GetMissingEpisode`."""
|
||||
|
||||
return self.GetMissingEpisode()
|
||||
|
||||
def GetList(self) -> List[Serie]:
|
||||
"""Return all series instances stored in the list."""
|
||||
|
||||
def GetMissingEpisode(self):
|
||||
"""Find all series with a non-empty episodeDict"""
|
||||
return [serie for serie in self.folderDict.values() if len(serie.episodeDict) > 0]
|
||||
|
||||
def GetList(self):
|
||||
"""Get all series in the list"""
|
||||
return list(self.folderDict.values())
|
||||
|
||||
def get_all(self) -> List[Serie]:
|
||||
"""PEP8-friendly alias for :meth:`GetList`."""
|
||||
|
||||
#k = AnimeList("\\\\sshfs.r\\ubuntu@192.168.178.43\\media\\serien\\Serien")
|
||||
#bbabab = k.GetMissingEpisode()
|
||||
#print(bbabab)
|
||||
return self.GetList()
|
||||
|
||||
@ -37,7 +37,10 @@ from .base_provider import Loader
|
||||
|
||||
|
||||
class EnhancedAniWorldLoader(Loader):
|
||||
"""Enhanced AniWorld loader with comprehensive error handling."""
|
||||
"""Aniworld provider with retry and recovery strategies.
|
||||
|
||||
Also exposes metrics hooks for download statistics.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
@ -211,7 +214,9 @@ class EnhancedAniWorldLoader(Loader):
|
||||
if not word or not word.strip():
|
||||
raise ValueError("Search term cannot be empty")
|
||||
|
||||
search_url = f"{self.ANIWORLD_TO}/ajax/seriesSearch?keyword={quote(word)}"
|
||||
search_url = (
|
||||
f"{self.ANIWORLD_TO}/ajax/seriesSearch?keyword={quote(word)}"
|
||||
)
|
||||
|
||||
try:
|
||||
return self._fetch_anime_list_with_recovery(search_url)
|
||||
@ -250,7 +255,9 @@ class EnhancedAniWorldLoader(Loader):
|
||||
|
||||
clean_text = response_text.strip()
|
||||
|
||||
# Try multiple parsing strategies
|
||||
# Try multiple parsing strategies. We progressively relax the parsing
|
||||
# requirements to handle HTML-escaped payloads, stray BOM markers, and
|
||||
# control characters injected by the upstream service.
|
||||
parsing_strategies = [
|
||||
lambda text: json.loads(html.unescape(text)),
|
||||
lambda text: json.loads(text.encode('utf-8').decode('utf-8-sig')),
|
||||
|
||||
@ -1,5 +1,8 @@
|
||||
"""Resolve Doodstream embed players into direct download URLs."""
|
||||
|
||||
import random
|
||||
import re
|
||||
import string
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
@ -8,6 +11,12 @@ from fake_useragent import UserAgent
|
||||
|
||||
from .Provider import Provider
|
||||
|
||||
# Precompiled regex patterns to extract the ``pass_md5`` endpoint and the
|
||||
# session token embedded in the obfuscated player script. Compiling once keeps
|
||||
# repeated invocations fast and documents the parsing intent.
|
||||
PASS_MD5_PATTERN = re.compile(r"\$\.get\('([^']*/pass_md5/[^']*)'")
|
||||
TOKEN_PATTERN = re.compile(r"token=([a-zA-Z0-9]+)")
|
||||
|
||||
|
||||
class Doodstream(Provider):
|
||||
"""Doodstream video provider implementation."""
|
||||
@ -33,17 +42,15 @@ class Doodstream(Provider):
|
||||
"Referer": "https://dood.li/",
|
||||
}
|
||||
|
||||
def extract_data(pattern: str, content: str) -> str | None:
|
||||
"""Extract data using regex pattern."""
|
||||
match = re.search(pattern, content)
|
||||
def extract_data(pattern: re.Pattern[str], content: str) -> str | None:
|
||||
"""Extract data using a compiled regex pattern."""
|
||||
match = pattern.search(content)
|
||||
return match.group(1) if match else None
|
||||
|
||||
def generate_random_string(length: int = 10) -> str:
|
||||
"""Generate random alphanumeric string."""
|
||||
characters = (
|
||||
"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"
|
||||
)
|
||||
return "".join(random.choice(characters) for _ in range(length))
|
||||
charset = string.ascii_letters + string.digits
|
||||
return "".join(random.choices(charset, k=length))
|
||||
|
||||
response = requests.get(
|
||||
embedded_link,
|
||||
@ -53,15 +60,13 @@ class Doodstream(Provider):
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
pass_md5_pattern = r"\$\.get\('([^']*\/pass_md5\/[^']*)'"
|
||||
pass_md5_url = extract_data(pass_md5_pattern, response.text)
|
||||
pass_md5_url = extract_data(PASS_MD5_PATTERN, response.text)
|
||||
if not pass_md5_url:
|
||||
raise ValueError(f"pass_md5 URL not found using {embedded_link}.")
|
||||
|
||||
full_md5_url = f"https://dood.li{pass_md5_url}"
|
||||
|
||||
token_pattern = r"token=([a-zA-Z0-9]+)"
|
||||
token = extract_data(token_pattern, response.text)
|
||||
token = extract_data(TOKEN_PATTERN, response.text)
|
||||
if not token:
|
||||
raise ValueError(f"Token not found using {embedded_link}.")
|
||||
|
||||
|
||||
@ -1,8 +1,12 @@
|
||||
"""Resolve Filemoon embed pages into direct streaming asset URLs."""
|
||||
|
||||
import re
|
||||
|
||||
import requests
|
||||
from aniworld import config
|
||||
|
||||
# import jsbeautifier.unpackers.packer as packer
|
||||
|
||||
from aniworld import config
|
||||
|
||||
REDIRECT_REGEX = re.compile(
|
||||
r'<iframe *(?:[^>]+ )?src=(?:\'([^\']+)\'|"([^"]+)")[^>]*>')
|
||||
|
||||
@ -1,6 +1,9 @@
|
||||
import re
|
||||
"""Helpers for extracting direct stream URLs from hanime.tv pages."""
|
||||
|
||||
import json
|
||||
import re
|
||||
import sys
|
||||
|
||||
import requests
|
||||
from aniworld.config import DEFAULT_REQUEST_TIMEOUT
|
||||
|
||||
@ -83,7 +86,7 @@ def get_direct_link_from_hanime(url=None):
|
||||
except ValueError as e:
|
||||
print(f"Error: {e}")
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
print("\nOperation cancelled by user.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@ -1,4 +1,10 @@
|
||||
"""Maintenance and system management API endpoints."""
|
||||
"""Maintenance API endpoints for system housekeeping and diagnostics.
|
||||
|
||||
This module exposes cleanup routines, system statistics, maintenance
|
||||
operations, and health reporting endpoints that rely on the shared system
|
||||
utilities and monitoring services. The routes allow administrators to
|
||||
prune logs, inspect disk usage, vacuum or analyze the database, and gather
|
||||
holistic health metrics for AniWorld deployments."""
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict
|
||||
|
||||
@ -44,22 +44,15 @@ app = FastAPI(
|
||||
redoc_url="/api/redoc"
|
||||
)
|
||||
|
||||
# Configure CORS
|
||||
# WARNING: In production, ensure CORS_ORIGINS is properly configured
|
||||
# Default to localhost for development, configure via environment variable
|
||||
cors_origins = (
|
||||
settings.cors_origins.split(",")
|
||||
if settings.cors_origins and settings.cors_origins != "*"
|
||||
else (
|
||||
["http://localhost:3000", "http://localhost:8000"]
|
||||
if settings.cors_origins == "*"
|
||||
else []
|
||||
)
|
||||
)
|
||||
# Configure CORS using environment-driven configuration.
|
||||
allowed_origins = settings.allowed_origins or [
|
||||
"http://localhost:3000",
|
||||
"http://localhost:8000",
|
||||
]
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=cors_origins if cors_origins else ["*"],
|
||||
allow_origins=allowed_origins,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
|
||||
@ -46,21 +46,33 @@ class AuthMiddleware(BaseHTTPMiddleware):
|
||||
path = request.url.path or ""
|
||||
|
||||
# Apply rate limiting to auth endpoints that accept credentials
|
||||
if path in ("/api/auth/login", "/api/auth/setup") and request.method.upper() == "POST":
|
||||
if (
|
||||
path in ("/api/auth/login", "/api/auth/setup")
|
||||
and request.method.upper() == "POST"
|
||||
):
|
||||
client_host = self._get_client_ip(request)
|
||||
rec = self._rate.setdefault(client_host, {"count": 0, "window_start": time.time()})
|
||||
rate_limit_record = self._rate.setdefault(
|
||||
client_host,
|
||||
{"count": 0, "window_start": time.time()},
|
||||
)
|
||||
now = time.time()
|
||||
if now - rec["window_start"] > self.window_seconds:
|
||||
# reset window
|
||||
rec["window_start"] = now
|
||||
rec["count"] = 0
|
||||
# The limiter uses a fixed window; once the window expires, we
|
||||
# reset the counter for that client and start measuring again.
|
||||
if now - rate_limit_record["window_start"] > self.window_seconds:
|
||||
rate_limit_record["window_start"] = now
|
||||
rate_limit_record["count"] = 0
|
||||
|
||||
rec["count"] += 1
|
||||
if rec["count"] > self.rate_limit_per_minute:
|
||||
rate_limit_record["count"] += 1
|
||||
if rate_limit_record["count"] > self.rate_limit_per_minute:
|
||||
# Too many requests in window — return a JSON 429 response
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
|
||||
content={"detail": "Too many authentication attempts, try again later"},
|
||||
content={
|
||||
"detail": (
|
||||
"Too many authentication attempts, "
|
||||
"try again later"
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
# If Authorization header present try to decode token and attach session
|
||||
|
||||
@ -228,7 +228,9 @@ class DownloadService:
|
||||
added_at=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
# Insert based on priority
|
||||
# Insert based on priority. High-priority downloads jump the
|
||||
# line via appendleft so they execute before existing work;
|
||||
# everything else is appended to preserve FIFO order.
|
||||
if priority == DownloadPriority.HIGH:
|
||||
self._pending_queue.appendleft(item)
|
||||
else:
|
||||
|
||||
@ -5,9 +5,13 @@ This module provides dependency injection functions for the FastAPI
|
||||
application, including SeriesApp instances, AnimeService, DownloadService,
|
||||
database sessions, and authentication dependencies.
|
||||
"""
|
||||
from typing import TYPE_CHECKING, AsyncGenerator, Optional
|
||||
import logging
|
||||
import time
|
||||
from asyncio import Lock
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, AsyncGenerator, Dict, Optional
|
||||
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from fastapi import Depends, HTTPException, Request, status
|
||||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||
|
||||
try:
|
||||
@ -19,13 +23,15 @@ from src.config.settings import settings
|
||||
from src.core.SeriesApp import SeriesApp
|
||||
from src.server.services.auth_service import AuthError, auth_service
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.server.services.anime_service import AnimeService
|
||||
from src.server.services.download_service import DownloadService
|
||||
|
||||
# Security scheme for JWT authentication
|
||||
# Use auto_error=False to handle errors manually and return 401 instead of 403
|
||||
security = HTTPBearer(auto_error=False)
|
||||
http_bearer_security = HTTPBearer(auto_error=False)
|
||||
|
||||
|
||||
# Global SeriesApp instance
|
||||
@ -36,6 +42,19 @@ _anime_service: Optional["AnimeService"] = None
|
||||
_download_service: Optional["DownloadService"] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class RateLimitRecord:
|
||||
"""Track request counts within a fixed time window."""
|
||||
|
||||
count: int
|
||||
window_start: float
|
||||
|
||||
|
||||
_RATE_LIMIT_BUCKETS: Dict[str, RateLimitRecord] = {}
|
||||
_rate_limit_lock = Lock()
|
||||
_RATE_LIMIT_WINDOW_SECONDS = 60.0
|
||||
|
||||
|
||||
def get_series_app() -> SeriesApp:
|
||||
"""
|
||||
Dependency to get SeriesApp instance.
|
||||
@ -104,7 +123,9 @@ async def get_database_session() -> AsyncGenerator:
|
||||
|
||||
|
||||
def get_current_user(
|
||||
credentials: Optional[HTTPAuthorizationCredentials] = Depends(security),
|
||||
credentials: Optional[HTTPAuthorizationCredentials] = Depends(
|
||||
http_bearer_security
|
||||
),
|
||||
) -> dict:
|
||||
"""
|
||||
Dependency to get current authenticated user.
|
||||
@ -195,7 +216,7 @@ def get_current_user_optional(
|
||||
|
||||
|
||||
class CommonQueryParams:
|
||||
"""Common query parameters for API endpoints."""
|
||||
"""Reusable pagination parameters shared across API endpoints."""
|
||||
|
||||
def __init__(self, skip: int = 0, limit: int = 100) -> None:
|
||||
"""Create a reusable pagination parameter container.
|
||||
@ -226,23 +247,47 @@ def common_parameters(
|
||||
|
||||
|
||||
# Dependency for rate limiting (placeholder)
|
||||
async def rate_limit_dependency():
|
||||
"""
|
||||
Dependency for rate limiting API requests.
|
||||
|
||||
TODO: Implement rate limiting logic
|
||||
"""
|
||||
pass
|
||||
async def rate_limit_dependency(request: Request) -> None:
|
||||
"""Apply a simple fixed-window rate limit to incoming requests."""
|
||||
|
||||
client_id = "unknown"
|
||||
if request.client and request.client.host:
|
||||
client_id = request.client.host
|
||||
|
||||
max_requests = max(1, settings.api_rate_limit)
|
||||
now = time.time()
|
||||
|
||||
async with _rate_limit_lock:
|
||||
record = _RATE_LIMIT_BUCKETS.get(client_id)
|
||||
if not record or now - record.window_start >= _RATE_LIMIT_WINDOW_SECONDS:
|
||||
_RATE_LIMIT_BUCKETS[client_id] = RateLimitRecord(
|
||||
count=1,
|
||||
window_start=now,
|
||||
)
|
||||
return
|
||||
|
||||
record.count += 1
|
||||
if record.count > max_requests:
|
||||
logger.warning("Rate limit exceeded", extra={"client": client_id})
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
|
||||
detail="Too many requests. Please slow down.",
|
||||
)
|
||||
|
||||
|
||||
# Dependency for request logging (placeholder)
|
||||
async def log_request_dependency():
|
||||
"""
|
||||
Dependency for logging API requests.
|
||||
|
||||
TODO: Implement request logging logic
|
||||
"""
|
||||
pass
|
||||
async def log_request_dependency(request: Request) -> None:
|
||||
"""Log request metadata for auditing and debugging purposes."""
|
||||
|
||||
logger.info(
|
||||
"API request",
|
||||
extra={
|
||||
"method": request.method,
|
||||
"path": request.url.path,
|
||||
"client": request.client.host if request.client else "unknown",
|
||||
"query": dict(request.query_params),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def get_anime_service() -> "AnimeService":
|
||||
|
||||
@ -251,8 +251,12 @@ class SystemUtilities:
|
||||
info = SystemUtilities.get_process_info(proc.pid)
|
||||
if info:
|
||||
processes.append(info)
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as process_error:
|
||||
logger.debug(
|
||||
"Skipping process %s: %s",
|
||||
proc.pid,
|
||||
process_error,
|
||||
)
|
||||
|
||||
return processes
|
||||
except Exception as e:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user