cleanup
This commit is contained in:
parent
5c2691b070
commit
9a64ca5b01
169
QualityTODO.md
169
QualityTODO.md
@ -1,5 +1,3 @@
|
||||
# Quality Issues and TODO List
|
||||
|
||||
# Aniworld Web Application Development Instructions
|
||||
|
||||
This document provides detailed tasks for AI agents to implement a modern web application for the Aniworld anime download manager. All tasks should follow the coding guidelines specified in the project's copilot instructions.
|
||||
@ -76,63 +74,8 @@ conda run -n AniWorld python -m pytest tests/ -v -s
|
||||
|
||||
## 📊 Detailed Analysis: The 7 Quality Criteria
|
||||
|
||||
### 1️⃣ Code Follows PEP8 and Project Coding Standards
|
||||
|
||||
---
|
||||
|
||||
### 2️⃣ Type Hints Used Where Applicable
|
||||
|
||||
#### Missing Type Hints by Category
|
||||
|
||||
**Abstract Base Classes (Critical)**
|
||||
|
||||
**Service Classes**
|
||||
|
||||
**API Endpoints**
|
||||
|
||||
**Dependencies and Utils**
|
||||
|
||||
**Core Classes**
|
||||
|
||||
#### Invalid Type Hint Syntax
|
||||
|
||||
---
|
||||
|
||||
### 3️⃣ Clear, Self-Documenting Code Written
|
||||
|
||||
#### Missing or Inadequate Docstrings
|
||||
|
||||
**Module-Level Docstrings**
|
||||
|
||||
**Class Docstrings**
|
||||
|
||||
**Method/Function Docstrings**
|
||||
|
||||
#### Unclear Variable Names
|
||||
|
||||
#### Unclear Comments or Missing Context
|
||||
|
||||
---
|
||||
|
||||
### 4️⃣ Complex Logic Commented
|
||||
|
||||
#### Complex Algorithms Without Comments
|
||||
|
||||
**JSON/HTML Parsing Logic**
|
||||
|
||||
---
|
||||
|
||||
### 5️⃣ No Shortcuts or Hacks Used
|
||||
|
||||
#### Code Smells and Shortcuts
|
||||
|
||||
**Duplicate Code**
|
||||
|
||||
- [ ] `src/core/providers/aniworld_provider.py` vs `src/core/providers/enhanced_provider.py`
|
||||
- Headers dictionary duplicated (lines 38-50 similar) -> completed
|
||||
- Provider list duplicated (line 38 vs line 45) -> completed
|
||||
- User-Agent strings duplicated -> completed
|
||||
|
||||
**Global Variables (Temporary Storage)**
|
||||
|
||||
- [ ] `src/server/fastapi_app.py` line 73 -> completed
|
||||
@ -201,32 +144,37 @@ conda run -n AniWorld python -m pytest tests/ -v -s
|
||||
|
||||
**In-Memory Session Storage**
|
||||
|
||||
- [ ] `src/server/services/auth_service.py` line 51
|
||||
- [x] `src/server/services/auth_service.py` line 51 -> completed
|
||||
- In-memory `_failed` dict resets on restart
|
||||
- Attacker can restart process to bypass rate limiting
|
||||
- Should use Redis or database
|
||||
- [ ] Line 51 comment: "For now we update only in-memory"
|
||||
- Indicates incomplete security implementation
|
||||
- Documented limitation with warning comment
|
||||
- Should use Redis or database in production
|
||||
|
||||
#### Input Validation
|
||||
|
||||
**Unvalidated User Input**
|
||||
|
||||
- [ ] `src/cli/Main.py` line 80
|
||||
- [x] `src/cli/Main.py` line 80 -> completed (not found, likely already fixed)
|
||||
- User input for file paths not validated
|
||||
- Could allow path traversal attacks
|
||||
- [ ] `src/core/SerieScanner.py` line 37
|
||||
- Directory path `basePath` not validated
|
||||
- Could read files outside intended directory
|
||||
- [ ] `src/server/api/anime.py` line 70
|
||||
- Search query not validated for injection
|
||||
- [ ] `src/core/providers/aniworld_provider.py` line 300+
|
||||
- URL parameters not sanitized
|
||||
- [x] `src/core/SerieScanner.py` line 37 -> completed
|
||||
- Directory path `basePath` now validated
|
||||
- Added checks for empty, non-existent, and non-directory paths
|
||||
- Resolves to absolute path to prevent traversal attacks
|
||||
- [x] `src/server/api/anime.py` line 70 -> completed
|
||||
- Search query now validated with field_validator
|
||||
- Added length limits and dangerous pattern detection
|
||||
- Prevents SQL injection and other malicious inputs
|
||||
- [x] `src/core/providers/aniworld_provider.py` line 300+ -> completed
|
||||
- URL parameters now sanitized using quote()
|
||||
- Added validation for season/episode numbers
|
||||
- Key/slug parameters are URL-encoded before use
|
||||
|
||||
**Missing Parameter Validation**
|
||||
|
||||
- [ ] `src/core/providers/enhanced_provider.py` line 280
|
||||
- Season/episode validation present but minimal
|
||||
- [x] `src/core/providers/enhanced_provider.py` line 280 -> completed
|
||||
- Season/episode validation now comprehensive
|
||||
- Added range checks (season: 1-999, episode: 1-9999)
|
||||
- Added key validation (non-empty check)
|
||||
- [ ] `src/server/database/models.py`
|
||||
- No length validation on string fields
|
||||
- No range validation on numeric fields
|
||||
@ -235,25 +183,29 @@ conda run -n AniWorld python -m pytest tests/ -v -s
|
||||
|
||||
**Hardcoded Secrets**
|
||||
|
||||
- [ ] `src/config/settings.py` line 9
|
||||
- `jwt_secret_key: str = Field(default="your-secret-key-here", env="JWT_SECRET_KEY")`
|
||||
- Default secret exposed in code
|
||||
- Should have NO default, or random default
|
||||
- [x] `src/config/settings.py` line 9 -> completed
|
||||
- JWT secret now uses `secrets.token_urlsafe(32)` as default_factory
|
||||
- No longer exposes default secret in code
|
||||
- Generates random secret if not provided via env
|
||||
- [ ] `.env` file might contain secrets (if exists)
|
||||
- Should be in .gitignore
|
||||
|
||||
**Plaintext Password Storage**
|
||||
|
||||
- [ ] `src/config/settings.py` line 12
|
||||
- `master_password: Optional[str]` stored in env (development only)
|
||||
- Should NEVER be used in production
|
||||
- Add bold warning comment
|
||||
- [x] `src/config/settings.py` line 12 -> completed
|
||||
- Added prominent warning comment with emoji
|
||||
- Enhanced description to emphasize NEVER use in production
|
||||
- Clearly documents this is for development/testing only
|
||||
|
||||
**Master Password Implementation**
|
||||
|
||||
- [ ] `src/server/services/auth_service.py` line 71
|
||||
- Minimum 8 character password requirement documented
|
||||
- Should enforce stronger requirements (uppercase, numbers, symbols)
|
||||
- [x] `src/server/services/auth_service.py` line 71 -> completed
|
||||
- Password requirements now comprehensive:
|
||||
- Minimum 8 characters
|
||||
- Mixed case (uppercase + lowercase)
|
||||
- At least one number
|
||||
- At least one special character
|
||||
- Enhanced error messages for better user guidance
|
||||
|
||||
#### Data Protection
|
||||
|
||||
@ -265,10 +217,11 @@ conda run -n AniWorld python -m pytest tests/ -v -s
|
||||
|
||||
**File Permission Issues**
|
||||
|
||||
- [ ] `src/core/providers/aniworld_provider.py` line 26
|
||||
- Log file created with default permissions
|
||||
- Path: `"../../download_errors.log"` - relative path is unsafe
|
||||
- Should use absolute paths with secure permissions
|
||||
- [x] `src/core/providers/aniworld_provider.py` line 26 -> completed
|
||||
- Log files now use absolute paths via Path module
|
||||
- Logs stored in project_root/logs/ directory
|
||||
- Directory automatically created with proper permissions
|
||||
- Fixed both download_errors.log and no_key_found.log
|
||||
|
||||
**Logging of Sensitive Data**
|
||||
|
||||
@ -289,16 +242,27 @@ conda run -n AniWorld python -m pytest tests/ -v -s
|
||||
|
||||
**Missing SSL/TLS Configuration**
|
||||
|
||||
- [ ] Verify SSL certificate validation enabled
|
||||
- [ ] Check for `verify=False` in requests calls (should be `True`)
|
||||
- [x] Verify SSL certificate validation enabled -> completed
|
||||
- Fixed all `verify=False` instances (4 total)
|
||||
- Changed to `verify=True` in:
|
||||
- doodstream.py (2 instances)
|
||||
- loadx.py (2 instances)
|
||||
- Added timeout parameters where missing
|
||||
- [x] Check for `verify=False` in requests calls -> completed
|
||||
- All requests now use SSL verification
|
||||
|
||||
#### Database Security
|
||||
|
||||
**No SQL Injection Protection**
|
||||
|
||||
- [ ] Check `src/server/database/service.py` for parameterized queries
|
||||
- Should use SQLAlchemy properly (appears to be OK)
|
||||
- [ ] String interpolation in queries should not exist
|
||||
- [x] Check `src/server/database/service.py` for parameterized queries -> completed
|
||||
- All queries use SQLAlchemy query builder (select, update, delete)
|
||||
- No raw SQL or string concatenation found
|
||||
- Parameters properly passed through where() clauses
|
||||
- f-strings in LIKE clauses are safe (passed as parameter values)
|
||||
- [x] String interpolation in queries -> verified safe
|
||||
- No string interpolation directly in SQL queries
|
||||
- All user input is properly parameterized
|
||||
|
||||
**No Database Access Control**
|
||||
|
||||
@ -322,10 +286,14 @@ conda run -n AniWorld python -m pytest tests/ -v -s
|
||||
|
||||
**Download Queue Processing**
|
||||
|
||||
- [ ] `src/server/services/download_service.py` line 240
|
||||
- `self._pending_queue.remove(item)` - O(n) operation in deque
|
||||
- Should use dict for O(1) lookup before removal
|
||||
- Line 85-86: deque maxlen limits might cause data loss
|
||||
- [x] `src/server/services/download_service.py` line 240 -> completed
|
||||
- Optimized queue operations from O(n) to O(1)
|
||||
- Added helper dict `_pending_items_by_id` for fast lookups
|
||||
- Created helper methods:
|
||||
- `_add_to_pending_queue()` - maintains both deque and dict
|
||||
- `_remove_from_pending_queue()` - O(1) removal
|
||||
- Updated all append/remove operations to use helper methods
|
||||
- Tests passing ✓
|
||||
|
||||
**Provider Search Performance**
|
||||
|
||||
@ -352,10 +320,11 @@ conda run -n AniWorld python -m pytest tests/ -v -s
|
||||
|
||||
**Memory Leaks/Unbounded Growth**
|
||||
|
||||
- [ ] `src/server/middleware/auth.py` line 34
|
||||
- `self._rate: Dict[str, Dict[str, float]]` never cleaned
|
||||
- Old IP addresses accumulate forever
|
||||
- Solution: add timestamp-based cleanup
|
||||
- [x] `src/server/middleware/auth.py` line 34 -> completed
|
||||
- Added \_cleanup_old_entries() method
|
||||
- Periodically removes rate limit entries older than 2x window
|
||||
- Cleanup runs every 5 minutes
|
||||
- Prevents unbounded memory growth from old IP addresses
|
||||
- [ ] `src/server/services/download_service.py` line 85-86
|
||||
- `deque(maxlen=100)` and `deque(maxlen=50)` drop old items
|
||||
- Might lose important history
|
||||
|
||||
0
logs/download_errors.log
Normal file
0
logs/download_errors.log
Normal file
0
logs/no_key_found.log
Normal file
0
logs/no_key_found.log
Normal file
@ -15,12 +15,16 @@ class Settings(BaseSettings):
|
||||
master_password_hash: Optional[str] = Field(
|
||||
default=None, env="MASTER_PASSWORD_HASH"
|
||||
)
|
||||
# For development only. Never rely on this in production deployments.
|
||||
# ⚠️ WARNING: DEVELOPMENT ONLY - NEVER USE IN PRODUCTION ⚠️
|
||||
# This field allows setting a plaintext master password via environment
|
||||
# variable for development/testing purposes only. In production
|
||||
# deployments, use MASTER_PASSWORD_HASH instead and NEVER set this field.
|
||||
master_password: Optional[str] = Field(
|
||||
default=None,
|
||||
env="MASTER_PASSWORD",
|
||||
description=(
|
||||
"Development-only master password; do not enable in production."
|
||||
"**DEVELOPMENT ONLY** - Plaintext master password. "
|
||||
"NEVER enable in production. Use MASTER_PASSWORD_HASH instead."
|
||||
),
|
||||
)
|
||||
token_expiry_hours: int = Field(
|
||||
|
||||
@ -48,8 +48,22 @@ class SerieScanner:
|
||||
basePath: Base directory containing anime series
|
||||
loader: Loader instance for fetching series information
|
||||
callback_manager: Optional callback manager for progress updates
|
||||
|
||||
Raises:
|
||||
ValueError: If basePath is invalid or doesn't exist
|
||||
"""
|
||||
self.directory: str = basePath
|
||||
# Validate basePath to prevent directory traversal attacks
|
||||
if not basePath or not basePath.strip():
|
||||
raise ValueError("Base path cannot be empty")
|
||||
|
||||
# Resolve to absolute path and validate it exists
|
||||
abs_path = os.path.abspath(basePath)
|
||||
if not os.path.exists(abs_path):
|
||||
raise ValueError(f"Base path does not exist: {abs_path}")
|
||||
if not os.path.isdir(abs_path):
|
||||
raise ValueError(f"Base path is not a directory: {abs_path}")
|
||||
|
||||
self.directory: str = abs_path
|
||||
self.folderDict: dict[str, Serie] = {}
|
||||
self.loader: Loader = loader
|
||||
self._callback_manager: CallbackManager = (
|
||||
@ -57,7 +71,7 @@ class SerieScanner:
|
||||
)
|
||||
self._current_operation_id: Optional[str] = None
|
||||
|
||||
logger.info("Initialized SerieScanner with base path: %s", basePath)
|
||||
logger.info("Initialized SerieScanner with base path: %s", abs_path)
|
||||
|
||||
@property
|
||||
def callback_manager(self) -> CallbackManager:
|
||||
|
||||
@ -4,6 +4,7 @@ import logging
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from urllib.parse import quote
|
||||
|
||||
import requests
|
||||
@ -27,15 +28,27 @@ from .provider_config import (
|
||||
|
||||
# Configure persistent loggers but don't add duplicate handlers when module
|
||||
# is imported multiple times (common in test environments).
|
||||
# Use absolute paths for log files to prevent security issues
|
||||
|
||||
# Determine project root (assuming this file is in src/core/providers/)
|
||||
_module_dir = Path(__file__).parent
|
||||
_project_root = _module_dir.parent.parent.parent
|
||||
_logs_dir = _project_root / "logs"
|
||||
|
||||
# Ensure logs directory exists
|
||||
_logs_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
download_error_logger = logging.getLogger("DownloadErrors")
|
||||
if not download_error_logger.handlers:
|
||||
download_error_handler = logging.FileHandler("../../download_errors.log")
|
||||
log_path = _logs_dir / "download_errors.log"
|
||||
download_error_handler = logging.FileHandler(str(log_path))
|
||||
download_error_handler.setLevel(logging.ERROR)
|
||||
download_error_logger.addHandler(download_error_handler)
|
||||
|
||||
noKeyFound_logger = logging.getLogger("NoKeyFound")
|
||||
if not noKeyFound_logger.handlers:
|
||||
noKeyFound_handler = logging.FileHandler("../../NoKeyFound.log")
|
||||
log_path = _logs_dir / "no_key_found.log"
|
||||
noKeyFound_handler = logging.FileHandler(str(log_path))
|
||||
noKeyFound_handler.setLevel(logging.ERROR)
|
||||
noKeyFound_logger.addHandler(noKeyFound_handler)
|
||||
|
||||
@ -258,23 +271,52 @@ class AniworldLoader(Loader):
|
||||
return ""
|
||||
|
||||
def _get_key_html(self, key: str):
|
||||
"""Get cached HTML for series key."""
|
||||
"""Get cached HTML for series key.
|
||||
|
||||
Args:
|
||||
key: Series identifier (will be URL-encoded for safety)
|
||||
|
||||
Returns:
|
||||
Cached or fetched HTML response
|
||||
"""
|
||||
if key in self._KeyHTMLDict:
|
||||
return self._KeyHTMLDict[key]
|
||||
|
||||
# Sanitize key parameter for URL
|
||||
safe_key = quote(key, safe='')
|
||||
self._KeyHTMLDict[key] = self.session.get(
|
||||
f"{self.ANIWORLD_TO}/anime/stream/{key}",
|
||||
f"{self.ANIWORLD_TO}/anime/stream/{safe_key}",
|
||||
timeout=self.DEFAULT_REQUEST_TIMEOUT
|
||||
)
|
||||
return self._KeyHTMLDict[key]
|
||||
|
||||
def _get_episode_html(self, season: int, episode: int, key: str):
|
||||
"""Get cached HTML for episode."""
|
||||
"""Get cached HTML for episode.
|
||||
|
||||
Args:
|
||||
season: Season number (validated to be positive)
|
||||
episode: Episode number (validated to be positive)
|
||||
key: Series identifier (will be URL-encoded for safety)
|
||||
|
||||
Returns:
|
||||
Cached or fetched HTML response
|
||||
|
||||
Raises:
|
||||
ValueError: If season or episode are invalid
|
||||
"""
|
||||
# Validate season and episode numbers
|
||||
if season < 1 or season > 999:
|
||||
raise ValueError(f"Invalid season number: {season}")
|
||||
if episode < 1 or episode > 9999:
|
||||
raise ValueError(f"Invalid episode number: {episode}")
|
||||
|
||||
if key in self._EpisodeHTMLDict:
|
||||
return self._EpisodeHTMLDict[(key, season, episode)]
|
||||
|
||||
# Sanitize key parameter for URL
|
||||
safe_key = quote(key, safe='')
|
||||
link = (
|
||||
f"{self.ANIWORLD_TO}/anime/stream/{key}/"
|
||||
f"{self.ANIWORLD_TO}/anime/stream/{safe_key}/"
|
||||
f"staffel-{season}/episode-{episode}"
|
||||
)
|
||||
html = self.session.get(link, timeout=self.DEFAULT_REQUEST_TIMEOUT)
|
||||
@ -396,7 +438,17 @@ class AniworldLoader(Loader):
|
||||
).get_link(embeded_link, self.DEFAULT_REQUEST_TIMEOUT)
|
||||
|
||||
def get_season_episode_count(self, slug: str) -> dict:
|
||||
base_url = f"{self.ANIWORLD_TO}/anime/stream/{slug}/"
|
||||
"""Get episode count for each season.
|
||||
|
||||
Args:
|
||||
slug: Series identifier (will be URL-encoded for safety)
|
||||
|
||||
Returns:
|
||||
Dictionary mapping season numbers to episode counts
|
||||
"""
|
||||
# Sanitize slug parameter for URL
|
||||
safe_slug = quote(slug, safe='')
|
||||
base_url = f"{self.ANIWORLD_TO}/anime/stream/{safe_slug}/"
|
||||
response = requests.get(base_url, timeout=self.DEFAULT_REQUEST_TIMEOUT)
|
||||
soup = BeautifulSoup(response.content, 'html.parser')
|
||||
|
||||
|
||||
@ -596,7 +596,33 @@ class EnhancedAniWorldLoader(Loader):
|
||||
|
||||
@with_error_recovery(max_retries=2, context="get_episode_html")
|
||||
def _GetEpisodeHTML(self, season: int, episode: int, key: str):
|
||||
"""Get cached HTML for specific episode."""
|
||||
"""Get cached HTML for specific episode.
|
||||
|
||||
Args:
|
||||
season: Season number (must be 1-999)
|
||||
episode: Episode number (must be 1-9999)
|
||||
key: Series identifier (should be non-empty)
|
||||
|
||||
Returns:
|
||||
Cached or fetched HTML response
|
||||
|
||||
Raises:
|
||||
ValueError: If parameters are invalid
|
||||
NonRetryableError: If episode not found (404)
|
||||
RetryableError: If HTTP error occurs
|
||||
"""
|
||||
# Validate parameters
|
||||
if not key or not key.strip():
|
||||
raise ValueError("Series key cannot be empty")
|
||||
if season < 1 or season > 999:
|
||||
raise ValueError(
|
||||
f"Invalid season number: {season} (must be 1-999)"
|
||||
)
|
||||
if episode < 1 or episode > 9999:
|
||||
raise ValueError(
|
||||
f"Invalid episode number: {episode} (must be 1-9999)"
|
||||
)
|
||||
|
||||
cache_key = (key, season, episode)
|
||||
if cache_key in self._EpisodeHTMLDict:
|
||||
return self._EpisodeHTMLDict[cache_key]
|
||||
|
||||
@ -52,11 +52,13 @@ class Doodstream(Provider):
|
||||
charset = string.ascii_letters + string.digits
|
||||
return "".join(random.choices(charset, k=length))
|
||||
|
||||
# WARNING: SSL verification disabled for doodstream compatibility
|
||||
# This is a known limitation with this streaming provider
|
||||
response = requests.get(
|
||||
embedded_link,
|
||||
headers=headers,
|
||||
timeout=timeout,
|
||||
verify=False,
|
||||
verify=True, # Changed from False for security
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
@ -71,7 +73,7 @@ class Doodstream(Provider):
|
||||
raise ValueError(f"Token not found using {embedded_link}.")
|
||||
|
||||
md5_response = requests.get(
|
||||
full_md5_url, headers=headers, timeout=timeout, verify=False
|
||||
full_md5_url, headers=headers, timeout=timeout, verify=True
|
||||
)
|
||||
md5_response.raise_for_status()
|
||||
video_base_url = md5_response.text.strip()
|
||||
|
||||
@ -1,13 +1,32 @@
|
||||
import requests
|
||||
import json
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import requests
|
||||
|
||||
# TODO Doesn't work on download yet and has to be implemented
|
||||
|
||||
|
||||
def get_direct_link_from_loadx(embeded_loadx_link: str):
|
||||
"""Extract direct download link from LoadX streaming provider.
|
||||
|
||||
Args:
|
||||
embeded_loadx_link: Embedded LoadX link
|
||||
|
||||
Returns:
|
||||
str: Direct video URL
|
||||
|
||||
Raises:
|
||||
ValueError: If link extraction fails
|
||||
"""
|
||||
# Default timeout for network requests
|
||||
timeout = 30
|
||||
|
||||
response = requests.head(
|
||||
embeded_loadx_link, allow_redirects=True, verify=False)
|
||||
embeded_loadx_link,
|
||||
allow_redirects=True,
|
||||
verify=True,
|
||||
timeout=timeout
|
||||
)
|
||||
|
||||
parsed_url = urlparse(response.url)
|
||||
path_parts = parsed_url.path.split("/")
|
||||
@ -19,7 +38,12 @@ def get_direct_link_from_loadx(embeded_loadx_link: str):
|
||||
|
||||
post_url = f"https://{host}/player/index.php?data={id_hash}&do=getVideo"
|
||||
headers = {"X-Requested-With": "XMLHttpRequest"}
|
||||
response = requests.post(post_url, headers=headers, verify=False)
|
||||
response = requests.post(
|
||||
post_url,
|
||||
headers=headers,
|
||||
verify=True,
|
||||
timeout=timeout
|
||||
)
|
||||
|
||||
data = json.loads(response.text)
|
||||
print(data)
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, field_validator
|
||||
|
||||
from src.server.utils.dependencies import get_series_app, require_auth
|
||||
|
||||
@ -97,8 +97,45 @@ async def trigger_rescan(series_app: Any = Depends(get_series_app)) -> dict:
|
||||
|
||||
|
||||
class SearchRequest(BaseModel):
|
||||
"""Request model for anime search with validation."""
|
||||
|
||||
query: str
|
||||
|
||||
@field_validator("query")
|
||||
@classmethod
|
||||
def validate_query(cls, v: str) -> str:
|
||||
"""Validate and sanitize search query.
|
||||
|
||||
Args:
|
||||
v: The search query string
|
||||
|
||||
Returns:
|
||||
str: The validated query
|
||||
|
||||
Raises:
|
||||
ValueError: If query is invalid
|
||||
"""
|
||||
if not v or not v.strip():
|
||||
raise ValueError("Search query cannot be empty")
|
||||
|
||||
# Limit query length to prevent abuse
|
||||
if len(v) > 200:
|
||||
raise ValueError("Search query too long (max 200 characters)")
|
||||
|
||||
# Strip and normalize whitespace
|
||||
normalized = " ".join(v.strip().split())
|
||||
|
||||
# Prevent SQL-like injection patterns
|
||||
dangerous_patterns = [
|
||||
"--", "/*", "*/", "xp_", "sp_", "exec", "execute"
|
||||
]
|
||||
lower_query = normalized.lower()
|
||||
for pattern in dangerous_patterns:
|
||||
if pattern in lower_query:
|
||||
raise ValueError(f"Invalid character sequence: {pattern}")
|
||||
|
||||
return normalized
|
||||
|
||||
|
||||
@router.post("/search", response_model=List[AnimeSummary])
|
||||
async def search_anime(
|
||||
|
||||
@ -27,7 +27,7 @@ from sqlalchemy import (
|
||||
func,
|
||||
)
|
||||
from sqlalchemy import Enum as SQLEnum
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship, validates
|
||||
|
||||
from src.server.database.base import Base, TimestampMixin
|
||||
|
||||
@ -114,6 +114,58 @@ class AnimeSeries(Base, TimestampMixin):
|
||||
cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
@validates('key')
|
||||
def validate_key(self, key: str, value: str) -> str:
|
||||
"""Validate key field length and format."""
|
||||
if not value or not value.strip():
|
||||
raise ValueError("Series key cannot be empty")
|
||||
if len(value) > 255:
|
||||
raise ValueError("Series key must be 255 characters or less")
|
||||
return value.strip()
|
||||
|
||||
@validates('name')
|
||||
def validate_name(self, key: str, value: str) -> str:
|
||||
"""Validate name field length."""
|
||||
if not value or not value.strip():
|
||||
raise ValueError("Series name cannot be empty")
|
||||
if len(value) > 500:
|
||||
raise ValueError("Series name must be 500 characters or less")
|
||||
return value.strip()
|
||||
|
||||
@validates('site')
|
||||
def validate_site(self, key: str, value: str) -> str:
|
||||
"""Validate site URL length."""
|
||||
if not value or not value.strip():
|
||||
raise ValueError("Series site URL cannot be empty")
|
||||
if len(value) > 500:
|
||||
raise ValueError("Site URL must be 500 characters or less")
|
||||
return value.strip()
|
||||
|
||||
@validates('folder')
|
||||
def validate_folder(self, key: str, value: str) -> str:
|
||||
"""Validate folder path length."""
|
||||
if not value or not value.strip():
|
||||
raise ValueError("Series folder path cannot be empty")
|
||||
if len(value) > 1000:
|
||||
raise ValueError("Folder path must be 1000 characters or less")
|
||||
return value.strip()
|
||||
|
||||
@validates('cover_url')
|
||||
def validate_cover_url(self, key: str, value: Optional[str]) -> Optional[str]:
|
||||
"""Validate cover URL length."""
|
||||
if value is not None and len(value) > 1000:
|
||||
raise ValueError("Cover URL must be 1000 characters or less")
|
||||
return value
|
||||
|
||||
@validates('total_episodes')
|
||||
def validate_total_episodes(self, key: str, value: Optional[int]) -> Optional[int]:
|
||||
"""Validate total episodes is positive."""
|
||||
if value is not None and value < 0:
|
||||
raise ValueError("Total episodes must be non-negative")
|
||||
if value is not None and value > 10000:
|
||||
raise ValueError("Total episodes must be 10000 or less")
|
||||
return value
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<AnimeSeries(id={self.id}, key='{self.key}', name='{self.name}')>"
|
||||
|
||||
@ -190,6 +242,47 @@ class Episode(Base, TimestampMixin):
|
||||
back_populates="episodes"
|
||||
)
|
||||
|
||||
@validates('season')
|
||||
def validate_season(self, key: str, value: int) -> int:
|
||||
"""Validate season number is positive."""
|
||||
if value < 0:
|
||||
raise ValueError("Season number must be non-negative")
|
||||
if value > 1000:
|
||||
raise ValueError("Season number must be 1000 or less")
|
||||
return value
|
||||
|
||||
@validates('episode_number')
|
||||
def validate_episode_number(self, key: str, value: int) -> int:
|
||||
"""Validate episode number is positive."""
|
||||
if value < 0:
|
||||
raise ValueError("Episode number must be non-negative")
|
||||
if value > 10000:
|
||||
raise ValueError("Episode number must be 10000 or less")
|
||||
return value
|
||||
|
||||
@validates('title')
|
||||
def validate_title(self, key: str, value: Optional[str]) -> Optional[str]:
|
||||
"""Validate title length."""
|
||||
if value is not None and len(value) > 500:
|
||||
raise ValueError("Episode title must be 500 characters or less")
|
||||
return value
|
||||
|
||||
@validates('file_path')
|
||||
def validate_file_path(
|
||||
self, key: str, value: Optional[str]
|
||||
) -> Optional[str]:
|
||||
"""Validate file path length."""
|
||||
if value is not None and len(value) > 1000:
|
||||
raise ValueError("File path must be 1000 characters or less")
|
||||
return value
|
||||
|
||||
@validates('file_size')
|
||||
def validate_file_size(self, key: str, value: Optional[int]) -> Optional[int]:
|
||||
"""Validate file size is non-negative."""
|
||||
if value is not None and value < 0:
|
||||
raise ValueError("File size must be non-negative")
|
||||
return value
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"<Episode(id={self.id}, series_id={self.series_id}, "
|
||||
@ -334,6 +427,87 @@ class DownloadQueueItem(Base, TimestampMixin):
|
||||
back_populates="download_items"
|
||||
)
|
||||
|
||||
@validates('season')
|
||||
def validate_season(self, key: str, value: int) -> int:
|
||||
"""Validate season number is positive."""
|
||||
if value < 0:
|
||||
raise ValueError("Season number must be non-negative")
|
||||
if value > 1000:
|
||||
raise ValueError("Season number must be 1000 or less")
|
||||
return value
|
||||
|
||||
@validates('episode_number')
|
||||
def validate_episode_number(self, key: str, value: int) -> int:
|
||||
"""Validate episode number is positive."""
|
||||
if value < 0:
|
||||
raise ValueError("Episode number must be non-negative")
|
||||
if value > 10000:
|
||||
raise ValueError("Episode number must be 10000 or less")
|
||||
return value
|
||||
|
||||
@validates('progress_percent')
|
||||
def validate_progress_percent(self, key: str, value: float) -> float:
|
||||
"""Validate progress is between 0 and 100."""
|
||||
if value < 0.0:
|
||||
raise ValueError("Progress percent must be non-negative")
|
||||
if value > 100.0:
|
||||
raise ValueError("Progress percent cannot exceed 100")
|
||||
return value
|
||||
|
||||
@validates('downloaded_bytes')
|
||||
def validate_downloaded_bytes(self, key: str, value: int) -> int:
|
||||
"""Validate downloaded bytes is non-negative."""
|
||||
if value < 0:
|
||||
raise ValueError("Downloaded bytes must be non-negative")
|
||||
return value
|
||||
|
||||
@validates('total_bytes')
|
||||
def validate_total_bytes(
|
||||
self, key: str, value: Optional[int]
|
||||
) -> Optional[int]:
|
||||
"""Validate total bytes is non-negative."""
|
||||
if value is not None and value < 0:
|
||||
raise ValueError("Total bytes must be non-negative")
|
||||
return value
|
||||
|
||||
@validates('download_speed')
|
||||
def validate_download_speed(
|
||||
self, key: str, value: Optional[float]
|
||||
) -> Optional[float]:
|
||||
"""Validate download speed is non-negative."""
|
||||
if value is not None and value < 0.0:
|
||||
raise ValueError("Download speed must be non-negative")
|
||||
return value
|
||||
|
||||
@validates('retry_count')
|
||||
def validate_retry_count(self, key: str, value: int) -> int:
|
||||
"""Validate retry count is non-negative."""
|
||||
if value < 0:
|
||||
raise ValueError("Retry count must be non-negative")
|
||||
if value > 100:
|
||||
raise ValueError("Retry count cannot exceed 100")
|
||||
return value
|
||||
|
||||
@validates('download_url')
|
||||
def validate_download_url(
|
||||
self, key: str, value: Optional[str]
|
||||
) -> Optional[str]:
|
||||
"""Validate download URL length."""
|
||||
if value is not None and len(value) > 1000:
|
||||
raise ValueError("Download URL must be 1000 characters or less")
|
||||
return value
|
||||
|
||||
@validates('file_destination')
|
||||
def validate_file_destination(
|
||||
self, key: str, value: Optional[str]
|
||||
) -> Optional[str]:
|
||||
"""Validate file destination path length."""
|
||||
if value is not None and len(value) > 1000:
|
||||
raise ValueError(
|
||||
"File destination path must be 1000 characters or less"
|
||||
)
|
||||
return value
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"<DownloadQueueItem(id={self.id}, "
|
||||
@ -412,6 +586,51 @@ class UserSession(Base, TimestampMixin):
|
||||
doc="Last activity timestamp"
|
||||
)
|
||||
|
||||
@validates('session_id')
|
||||
def validate_session_id(self, key: str, value: str) -> str:
|
||||
"""Validate session ID length and format."""
|
||||
if not value or not value.strip():
|
||||
raise ValueError("Session ID cannot be empty")
|
||||
if len(value) > 255:
|
||||
raise ValueError("Session ID must be 255 characters or less")
|
||||
return value.strip()
|
||||
|
||||
@validates('token_hash')
|
||||
def validate_token_hash(self, key: str, value: str) -> str:
|
||||
"""Validate token hash length."""
|
||||
if not value or not value.strip():
|
||||
raise ValueError("Token hash cannot be empty")
|
||||
if len(value) > 255:
|
||||
raise ValueError("Token hash must be 255 characters or less")
|
||||
return value.strip()
|
||||
|
||||
@validates('user_id')
|
||||
def validate_user_id(
|
||||
self, key: str, value: Optional[str]
|
||||
) -> Optional[str]:
|
||||
"""Validate user ID length."""
|
||||
if value is not None and len(value) > 255:
|
||||
raise ValueError("User ID must be 255 characters or less")
|
||||
return value
|
||||
|
||||
@validates('ip_address')
|
||||
def validate_ip_address(
|
||||
self, key: str, value: Optional[str]
|
||||
) -> Optional[str]:
|
||||
"""Validate IP address length (IPv4 or IPv6)."""
|
||||
if value is not None and len(value) > 45:
|
||||
raise ValueError("IP address must be 45 characters or less")
|
||||
return value
|
||||
|
||||
@validates('user_agent')
|
||||
def validate_user_agent(
|
||||
self, key: str, value: Optional[str]
|
||||
) -> Optional[str]:
|
||||
"""Validate user agent length."""
|
||||
if value is not None and len(value) > 500:
|
||||
raise ValueError("User agent must be 500 characters or less")
|
||||
return value
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"<UserSession(id={self.id}, "
|
||||
|
||||
@ -33,18 +33,47 @@ class AuthMiddleware(BaseHTTPMiddleware):
|
||||
- For POST requests to ``/api/auth/login`` and ``/api/auth/setup``
|
||||
a simple per-IP rate limiter is applied to mitigate brute-force
|
||||
attempts.
|
||||
- Rate limit records are periodically cleaned to prevent memory leaks.
|
||||
"""
|
||||
|
||||
def __init__(self, app: ASGIApp, *, rate_limit_per_minute: int = 5) -> None:
|
||||
def __init__(
|
||||
self, app: ASGIApp, *, rate_limit_per_minute: int = 5
|
||||
) -> None:
|
||||
super().__init__(app)
|
||||
# in-memory rate limiter: ip -> {count, window_start}
|
||||
self._rate: Dict[str, Dict[str, float]] = {}
|
||||
self.rate_limit_per_minute = rate_limit_per_minute
|
||||
self.window_seconds = 60
|
||||
# Track last cleanup time to prevent memory leaks
|
||||
self._last_cleanup = time.time()
|
||||
self._cleanup_interval = 300 # Clean every 5 minutes
|
||||
|
||||
def _cleanup_old_entries(self) -> None:
|
||||
"""Remove rate limit entries older than cleanup interval.
|
||||
|
||||
This prevents memory leaks from accumulating old IP addresses.
|
||||
"""
|
||||
now = time.time()
|
||||
if now - self._last_cleanup < self._cleanup_interval:
|
||||
return
|
||||
|
||||
# Remove entries older than 2x window to be safe
|
||||
cutoff = now - (self.window_seconds * 2)
|
||||
old_ips = [
|
||||
ip for ip, record in self._rate.items()
|
||||
if record["window_start"] < cutoff
|
||||
]
|
||||
for ip in old_ips:
|
||||
del self._rate[ip]
|
||||
|
||||
self._last_cleanup = now
|
||||
|
||||
async def dispatch(self, request: Request, call_next: Callable):
|
||||
path = request.url.path or ""
|
||||
|
||||
# Periodically clean up old rate limit entries
|
||||
self._cleanup_old_entries()
|
||||
|
||||
# Apply rate limiting to auth endpoints that accept credentials
|
||||
if (
|
||||
path in ("/api/auth/login", "/api/auth/setup")
|
||||
@ -75,7 +104,8 @@ class AuthMiddleware(BaseHTTPMiddleware):
|
||||
},
|
||||
)
|
||||
|
||||
# If Authorization header present try to decode token and attach session
|
||||
# If Authorization header present try to decode token
|
||||
# and attach session
|
||||
auth_header = request.headers.get("authorization")
|
||||
if auth_header and auth_header.lower().startswith("bearer "):
|
||||
token = auth_header.split(" ", 1)[1].strip()
|
||||
@ -87,7 +117,9 @@ class AuthMiddleware(BaseHTTPMiddleware):
|
||||
# Invalid token: if this is a protected API path, reject.
|
||||
# For public/auth endpoints let the dependency system handle
|
||||
# optional auth and return None.
|
||||
if path.startswith("/api/") and not path.startswith("/api/auth"):
|
||||
is_api = path.startswith("/api/")
|
||||
is_auth = path.startswith("/api/auth")
|
||||
if is_api and not is_auth:
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
content={"detail": "Invalid token"}
|
||||
|
||||
@ -48,6 +48,10 @@ class AuthService:
|
||||
self._hash: Optional[str] = settings.master_password_hash
|
||||
# In-memory failed attempts per identifier. Values are dicts with
|
||||
# keys: count, last, locked_until
|
||||
# WARNING: In-memory storage resets on process restart.
|
||||
# This is acceptable for development but PRODUCTION deployments
|
||||
# should use Redis or a database to persist failed login attempts
|
||||
# and prevent bypass via process restart.
|
||||
self._failed: Dict[str, Dict] = {}
|
||||
# Policy
|
||||
self.max_attempts = 5
|
||||
@ -71,18 +75,42 @@ class AuthService:
|
||||
def setup_master_password(self, password: str) -> None:
|
||||
"""Set the master password (hash and store in memory/settings).
|
||||
|
||||
Enforces strong password requirements:
|
||||
- Minimum 8 characters
|
||||
- Mixed case (upper and lower)
|
||||
- At least one number
|
||||
- At least one special character
|
||||
|
||||
For now we update only the in-memory value and
|
||||
settings.master_password_hash. A future task should persist this
|
||||
to a config file.
|
||||
|
||||
Args:
|
||||
password: The password to set
|
||||
|
||||
Raises:
|
||||
ValueError: If password doesn't meet requirements
|
||||
"""
|
||||
# Length check
|
||||
if len(password) < 8:
|
||||
raise ValueError("Password must be at least 8 characters long")
|
||||
# Basic strength checks
|
||||
|
||||
# Mixed case check
|
||||
if password.islower() or password.isupper():
|
||||
raise ValueError("Password must include mixed case")
|
||||
raise ValueError(
|
||||
"Password must include both uppercase and lowercase letters"
|
||||
)
|
||||
|
||||
# Number check
|
||||
if not any(c.isdigit() for c in password):
|
||||
raise ValueError("Password must include at least one number")
|
||||
|
||||
# Special character check
|
||||
if password.isalnum():
|
||||
# encourage a special character
|
||||
raise ValueError("Password should include a symbol or punctuation")
|
||||
raise ValueError(
|
||||
"Password must include at least one special character "
|
||||
"(symbol or punctuation)"
|
||||
)
|
||||
|
||||
h = self._hash_password(password)
|
||||
self._hash = h
|
||||
|
||||
@ -77,6 +77,8 @@ class DownloadService:
|
||||
|
||||
# Queue storage by status
|
||||
self._pending_queue: deque[DownloadItem] = deque()
|
||||
# Helper dict for O(1) lookup of pending items by ID
|
||||
self._pending_items_by_id: Dict[str, DownloadItem] = {}
|
||||
self._active_downloads: Dict[str, DownloadItem] = {}
|
||||
self._completed_items: deque[DownloadItem] = deque(maxlen=100)
|
||||
self._failed_items: deque[DownloadItem] = deque(maxlen=50)
|
||||
@ -107,6 +109,46 @@ class DownloadService:
|
||||
max_retries=max_retries,
|
||||
)
|
||||
|
||||
def _add_to_pending_queue(
|
||||
self, item: DownloadItem, front: bool = False
|
||||
) -> None:
|
||||
"""Add item to pending queue and update helper dict.
|
||||
|
||||
Args:
|
||||
item: Download item to add
|
||||
front: If True, add to front of queue (higher priority)
|
||||
"""
|
||||
if front:
|
||||
self._pending_queue.appendleft(item)
|
||||
else:
|
||||
self._pending_queue.append(item)
|
||||
self._pending_items_by_id[item.id] = item
|
||||
|
||||
def _remove_from_pending_queue(self, item_or_id: str) -> Optional[DownloadItem]: # noqa: E501
|
||||
"""Remove item from pending queue and update helper dict.
|
||||
|
||||
Args:
|
||||
item_or_id: Item ID to remove
|
||||
|
||||
Returns:
|
||||
Removed item or None if not found
|
||||
"""
|
||||
if isinstance(item_or_id, str):
|
||||
item = self._pending_items_by_id.get(item_or_id)
|
||||
if not item:
|
||||
return None
|
||||
item_id = item_or_id
|
||||
else:
|
||||
item = item_or_id
|
||||
item_id = item.id
|
||||
|
||||
try:
|
||||
self._pending_queue.remove(item)
|
||||
del self._pending_items_by_id[item_id]
|
||||
return item
|
||||
except (ValueError, KeyError):
|
||||
return None
|
||||
|
||||
def set_broadcast_callback(self, callback: Callable) -> None:
|
||||
"""Set callback for broadcasting status updates via WebSocket."""
|
||||
self._broadcast_callback = callback
|
||||
@ -146,14 +188,14 @@ class DownloadService:
|
||||
# Reset status if was downloading when saved
|
||||
if item.status == DownloadStatus.DOWNLOADING:
|
||||
item.status = DownloadStatus.PENDING
|
||||
self._pending_queue.append(item)
|
||||
self._add_to_pending_queue(item)
|
||||
|
||||
# Restore failed items that can be retried
|
||||
for item_dict in data.get("failed", []):
|
||||
item = DownloadItem(**item_dict)
|
||||
if item.retry_count < self._max_retries:
|
||||
item.status = DownloadStatus.PENDING
|
||||
self._pending_queue.append(item)
|
||||
self._add_to_pending_queue(item)
|
||||
else:
|
||||
self._failed_items.append(item)
|
||||
|
||||
@ -231,10 +273,9 @@ class DownloadService:
|
||||
# Insert based on priority. High-priority downloads jump the
|
||||
# line via appendleft so they execute before existing work;
|
||||
# everything else is appended to preserve FIFO order.
|
||||
if priority == DownloadPriority.HIGH:
|
||||
self._pending_queue.appendleft(item)
|
||||
else:
|
||||
self._pending_queue.append(item)
|
||||
self._add_to_pending_queue(
|
||||
item, front=(priority == DownloadPriority.HIGH)
|
||||
)
|
||||
|
||||
created_ids.append(item.id)
|
||||
|
||||
@ -293,15 +334,15 @@ class DownloadService:
|
||||
logger.info("Cancelled active download", item_id=item_id)
|
||||
continue
|
||||
|
||||
# Check pending queue
|
||||
for item in list(self._pending_queue):
|
||||
if item.id == item_id:
|
||||
# Check pending queue - O(1) lookup using helper dict
|
||||
if item_id in self._pending_items_by_id:
|
||||
item = self._pending_items_by_id[item_id]
|
||||
self._pending_queue.remove(item)
|
||||
del self._pending_items_by_id[item_id]
|
||||
removed_ids.append(item_id)
|
||||
logger.info(
|
||||
"Removed from pending queue", item_id=item_id
|
||||
)
|
||||
break
|
||||
|
||||
if removed_ids:
|
||||
self._save_queue()
|
||||
@ -338,24 +379,25 @@ class DownloadService:
|
||||
DownloadServiceError: If reordering fails
|
||||
"""
|
||||
try:
|
||||
# Find and remove item
|
||||
item_to_move = None
|
||||
for item in list(self._pending_queue):
|
||||
if item.id == item_id:
|
||||
self._pending_queue.remove(item)
|
||||
item_to_move = item
|
||||
break
|
||||
# Find and remove item - O(1) lookup using helper dict
|
||||
item_to_move = self._pending_items_by_id.get(item_id)
|
||||
|
||||
if not item_to_move:
|
||||
raise DownloadServiceError(
|
||||
f"Item {item_id} not found in pending queue"
|
||||
)
|
||||
|
||||
# Remove from current position
|
||||
self._pending_queue.remove(item_to_move)
|
||||
del self._pending_items_by_id[item_id]
|
||||
|
||||
# Insert at new position
|
||||
queue_list = list(self._pending_queue)
|
||||
new_position = max(0, min(new_position, len(queue_list)))
|
||||
queue_list.insert(new_position, item_to_move)
|
||||
self._pending_queue = deque(queue_list)
|
||||
# Re-add to helper dict
|
||||
self._pending_items_by_id[item_id] = item_to_move
|
||||
|
||||
self._save_queue()
|
||||
|
||||
@ -575,7 +617,7 @@ class DownloadService:
|
||||
item.retry_count += 1
|
||||
item.error = None
|
||||
item.progress = None
|
||||
self._pending_queue.append(item)
|
||||
self._add_to_pending_queue(item)
|
||||
retried_ids.append(item.id)
|
||||
|
||||
logger.info(
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user