This commit is contained in:
2025-10-23 18:10:34 +02:00
parent 5c2691b070
commit 9a64ca5b01
14 changed files with 598 additions and 149 deletions

View File

@@ -15,12 +15,16 @@ class Settings(BaseSettings):
master_password_hash: Optional[str] = Field(
default=None, env="MASTER_PASSWORD_HASH"
)
# For development only. Never rely on this in production deployments.
# ⚠️ WARNING: DEVELOPMENT ONLY - NEVER USE IN PRODUCTION ⚠️
# This field allows setting a plaintext master password via environment
# variable for development/testing purposes only. In production
# deployments, use MASTER_PASSWORD_HASH instead and NEVER set this field.
master_password: Optional[str] = Field(
default=None,
env="MASTER_PASSWORD",
description=(
"Development-only master password; do not enable in production."
"**DEVELOPMENT ONLY** - Plaintext master password. "
"NEVER enable in production. Use MASTER_PASSWORD_HASH instead."
),
)
token_expiry_hours: int = Field(

View File

@@ -48,8 +48,22 @@ class SerieScanner:
basePath: Base directory containing anime series
loader: Loader instance for fetching series information
callback_manager: Optional callback manager for progress updates
Raises:
ValueError: If basePath is invalid or doesn't exist
"""
self.directory: str = basePath
# Validate basePath to prevent directory traversal attacks
if not basePath or not basePath.strip():
raise ValueError("Base path cannot be empty")
# Resolve to absolute path and validate it exists
abs_path = os.path.abspath(basePath)
if not os.path.exists(abs_path):
raise ValueError(f"Base path does not exist: {abs_path}")
if not os.path.isdir(abs_path):
raise ValueError(f"Base path is not a directory: {abs_path}")
self.directory: str = abs_path
self.folderDict: dict[str, Serie] = {}
self.loader: Loader = loader
self._callback_manager: CallbackManager = (
@@ -57,7 +71,7 @@ class SerieScanner:
)
self._current_operation_id: Optional[str] = None
logger.info("Initialized SerieScanner with base path: %s", basePath)
logger.info("Initialized SerieScanner with base path: %s", abs_path)
@property
def callback_manager(self) -> CallbackManager:

View File

@@ -4,6 +4,7 @@ import logging
import os
import re
import shutil
from pathlib import Path
from urllib.parse import quote
import requests
@@ -27,15 +28,27 @@ from .provider_config import (
# Configure persistent loggers but don't add duplicate handlers when module
# is imported multiple times (common in test environments).
# Use absolute paths for log files to prevent security issues
# Determine project root (assuming this file is in src/core/providers/)
_module_dir = Path(__file__).parent
_project_root = _module_dir.parent.parent.parent
_logs_dir = _project_root / "logs"
# Ensure logs directory exists
_logs_dir.mkdir(parents=True, exist_ok=True)
download_error_logger = logging.getLogger("DownloadErrors")
if not download_error_logger.handlers:
download_error_handler = logging.FileHandler("../../download_errors.log")
log_path = _logs_dir / "download_errors.log"
download_error_handler = logging.FileHandler(str(log_path))
download_error_handler.setLevel(logging.ERROR)
download_error_logger.addHandler(download_error_handler)
noKeyFound_logger = logging.getLogger("NoKeyFound")
if not noKeyFound_logger.handlers:
noKeyFound_handler = logging.FileHandler("../../NoKeyFound.log")
log_path = _logs_dir / "no_key_found.log"
noKeyFound_handler = logging.FileHandler(str(log_path))
noKeyFound_handler.setLevel(logging.ERROR)
noKeyFound_logger.addHandler(noKeyFound_handler)
@@ -258,23 +271,52 @@ class AniworldLoader(Loader):
return ""
def _get_key_html(self, key: str):
"""Get cached HTML for series key."""
"""Get cached HTML for series key.
Args:
key: Series identifier (will be URL-encoded for safety)
Returns:
Cached or fetched HTML response
"""
if key in self._KeyHTMLDict:
return self._KeyHTMLDict[key]
# Sanitize key parameter for URL
safe_key = quote(key, safe='')
self._KeyHTMLDict[key] = self.session.get(
f"{self.ANIWORLD_TO}/anime/stream/{key}",
f"{self.ANIWORLD_TO}/anime/stream/{safe_key}",
timeout=self.DEFAULT_REQUEST_TIMEOUT
)
return self._KeyHTMLDict[key]
def _get_episode_html(self, season: int, episode: int, key: str):
"""Get cached HTML for episode."""
"""Get cached HTML for episode.
Args:
season: Season number (validated to be positive)
episode: Episode number (validated to be positive)
key: Series identifier (will be URL-encoded for safety)
Returns:
Cached or fetched HTML response
Raises:
ValueError: If season or episode are invalid
"""
# Validate season and episode numbers
if season < 1 or season > 999:
raise ValueError(f"Invalid season number: {season}")
if episode < 1 or episode > 9999:
raise ValueError(f"Invalid episode number: {episode}")
if key in self._EpisodeHTMLDict:
return self._EpisodeHTMLDict[(key, season, episode)]
# Sanitize key parameter for URL
safe_key = quote(key, safe='')
link = (
f"{self.ANIWORLD_TO}/anime/stream/{key}/"
f"{self.ANIWORLD_TO}/anime/stream/{safe_key}/"
f"staffel-{season}/episode-{episode}"
)
html = self.session.get(link, timeout=self.DEFAULT_REQUEST_TIMEOUT)
@@ -396,7 +438,17 @@ class AniworldLoader(Loader):
).get_link(embeded_link, self.DEFAULT_REQUEST_TIMEOUT)
def get_season_episode_count(self, slug: str) -> dict:
base_url = f"{self.ANIWORLD_TO}/anime/stream/{slug}/"
"""Get episode count for each season.
Args:
slug: Series identifier (will be URL-encoded for safety)
Returns:
Dictionary mapping season numbers to episode counts
"""
# Sanitize slug parameter for URL
safe_slug = quote(slug, safe='')
base_url = f"{self.ANIWORLD_TO}/anime/stream/{safe_slug}/"
response = requests.get(base_url, timeout=self.DEFAULT_REQUEST_TIMEOUT)
soup = BeautifulSoup(response.content, 'html.parser')

View File

@@ -596,7 +596,33 @@ class EnhancedAniWorldLoader(Loader):
@with_error_recovery(max_retries=2, context="get_episode_html")
def _GetEpisodeHTML(self, season: int, episode: int, key: str):
"""Get cached HTML for specific episode."""
"""Get cached HTML for specific episode.
Args:
season: Season number (must be 1-999)
episode: Episode number (must be 1-9999)
key: Series identifier (should be non-empty)
Returns:
Cached or fetched HTML response
Raises:
ValueError: If parameters are invalid
NonRetryableError: If episode not found (404)
RetryableError: If HTTP error occurs
"""
# Validate parameters
if not key or not key.strip():
raise ValueError("Series key cannot be empty")
if season < 1 or season > 999:
raise ValueError(
f"Invalid season number: {season} (must be 1-999)"
)
if episode < 1 or episode > 9999:
raise ValueError(
f"Invalid episode number: {episode} (must be 1-9999)"
)
cache_key = (key, season, episode)
if cache_key in self._EpisodeHTMLDict:
return self._EpisodeHTMLDict[cache_key]

View File

@@ -52,11 +52,13 @@ class Doodstream(Provider):
charset = string.ascii_letters + string.digits
return "".join(random.choices(charset, k=length))
# WARNING: SSL verification disabled for doodstream compatibility
# This is a known limitation with this streaming provider
response = requests.get(
embedded_link,
headers=headers,
timeout=timeout,
verify=False,
verify=True, # Changed from False for security
)
response.raise_for_status()
@@ -71,7 +73,7 @@ class Doodstream(Provider):
raise ValueError(f"Token not found using {embedded_link}.")
md5_response = requests.get(
full_md5_url, headers=headers, timeout=timeout, verify=False
full_md5_url, headers=headers, timeout=timeout, verify=True
)
md5_response.raise_for_status()
video_base_url = md5_response.text.strip()

View File

@@ -1,13 +1,32 @@
import requests
import json
from urllib.parse import urlparse
import requests
# TODO Doesn't work on download yet and has to be implemented
def get_direct_link_from_loadx(embeded_loadx_link: str):
"""Extract direct download link from LoadX streaming provider.
Args:
embeded_loadx_link: Embedded LoadX link
Returns:
str: Direct video URL
Raises:
ValueError: If link extraction fails
"""
# Default timeout for network requests
timeout = 30
response = requests.head(
embeded_loadx_link, allow_redirects=True, verify=False)
embeded_loadx_link,
allow_redirects=True,
verify=True,
timeout=timeout
)
parsed_url = urlparse(response.url)
path_parts = parsed_url.path.split("/")
@@ -19,7 +38,12 @@ def get_direct_link_from_loadx(embeded_loadx_link: str):
post_url = f"https://{host}/player/index.php?data={id_hash}&do=getVideo"
headers = {"X-Requested-With": "XMLHttpRequest"}
response = requests.post(post_url, headers=headers, verify=False)
response = requests.post(
post_url,
headers=headers,
verify=True,
timeout=timeout
)
data = json.loads(response.text)
print(data)

View File

@@ -1,7 +1,7 @@
from typing import Any, List, Optional
from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import BaseModel
from pydantic import BaseModel, field_validator
from src.server.utils.dependencies import get_series_app, require_auth
@@ -97,7 +97,44 @@ async def trigger_rescan(series_app: Any = Depends(get_series_app)) -> dict:
class SearchRequest(BaseModel):
"""Request model for anime search with validation."""
query: str
@field_validator("query")
@classmethod
def validate_query(cls, v: str) -> str:
"""Validate and sanitize search query.
Args:
v: The search query string
Returns:
str: The validated query
Raises:
ValueError: If query is invalid
"""
if not v or not v.strip():
raise ValueError("Search query cannot be empty")
# Limit query length to prevent abuse
if len(v) > 200:
raise ValueError("Search query too long (max 200 characters)")
# Strip and normalize whitespace
normalized = " ".join(v.strip().split())
# Prevent SQL-like injection patterns
dangerous_patterns = [
"--", "/*", "*/", "xp_", "sp_", "exec", "execute"
]
lower_query = normalized.lower()
for pattern in dangerous_patterns:
if pattern in lower_query:
raise ValueError(f"Invalid character sequence: {pattern}")
return normalized
@router.post("/search", response_model=List[AnimeSummary])

View File

@@ -27,7 +27,7 @@ from sqlalchemy import (
func,
)
from sqlalchemy import Enum as SQLEnum
from sqlalchemy.orm import Mapped, mapped_column, relationship
from sqlalchemy.orm import Mapped, mapped_column, relationship, validates
from src.server.database.base import Base, TimestampMixin
@@ -114,6 +114,58 @@ class AnimeSeries(Base, TimestampMixin):
cascade="all, delete-orphan"
)
@validates('key')
def validate_key(self, key: str, value: str) -> str:
"""Validate key field length and format."""
if not value or not value.strip():
raise ValueError("Series key cannot be empty")
if len(value) > 255:
raise ValueError("Series key must be 255 characters or less")
return value.strip()
@validates('name')
def validate_name(self, key: str, value: str) -> str:
"""Validate name field length."""
if not value or not value.strip():
raise ValueError("Series name cannot be empty")
if len(value) > 500:
raise ValueError("Series name must be 500 characters or less")
return value.strip()
@validates('site')
def validate_site(self, key: str, value: str) -> str:
"""Validate site URL length."""
if not value or not value.strip():
raise ValueError("Series site URL cannot be empty")
if len(value) > 500:
raise ValueError("Site URL must be 500 characters or less")
return value.strip()
@validates('folder')
def validate_folder(self, key: str, value: str) -> str:
"""Validate folder path length."""
if not value or not value.strip():
raise ValueError("Series folder path cannot be empty")
if len(value) > 1000:
raise ValueError("Folder path must be 1000 characters or less")
return value.strip()
@validates('cover_url')
def validate_cover_url(self, key: str, value: Optional[str]) -> Optional[str]:
"""Validate cover URL length."""
if value is not None and len(value) > 1000:
raise ValueError("Cover URL must be 1000 characters or less")
return value
@validates('total_episodes')
def validate_total_episodes(self, key: str, value: Optional[int]) -> Optional[int]:
"""Validate total episodes is positive."""
if value is not None and value < 0:
raise ValueError("Total episodes must be non-negative")
if value is not None and value > 10000:
raise ValueError("Total episodes must be 10000 or less")
return value
def __repr__(self) -> str:
return f"<AnimeSeries(id={self.id}, key='{self.key}', name='{self.name}')>"
@@ -190,6 +242,47 @@ class Episode(Base, TimestampMixin):
back_populates="episodes"
)
@validates('season')
def validate_season(self, key: str, value: int) -> int:
"""Validate season number is positive."""
if value < 0:
raise ValueError("Season number must be non-negative")
if value > 1000:
raise ValueError("Season number must be 1000 or less")
return value
@validates('episode_number')
def validate_episode_number(self, key: str, value: int) -> int:
"""Validate episode number is positive."""
if value < 0:
raise ValueError("Episode number must be non-negative")
if value > 10000:
raise ValueError("Episode number must be 10000 or less")
return value
@validates('title')
def validate_title(self, key: str, value: Optional[str]) -> Optional[str]:
"""Validate title length."""
if value is not None and len(value) > 500:
raise ValueError("Episode title must be 500 characters or less")
return value
@validates('file_path')
def validate_file_path(
self, key: str, value: Optional[str]
) -> Optional[str]:
"""Validate file path length."""
if value is not None and len(value) > 1000:
raise ValueError("File path must be 1000 characters or less")
return value
@validates('file_size')
def validate_file_size(self, key: str, value: Optional[int]) -> Optional[int]:
"""Validate file size is non-negative."""
if value is not None and value < 0:
raise ValueError("File size must be non-negative")
return value
def __repr__(self) -> str:
return (
f"<Episode(id={self.id}, series_id={self.series_id}, "
@@ -334,6 +427,87 @@ class DownloadQueueItem(Base, TimestampMixin):
back_populates="download_items"
)
@validates('season')
def validate_season(self, key: str, value: int) -> int:
"""Validate season number is positive."""
if value < 0:
raise ValueError("Season number must be non-negative")
if value > 1000:
raise ValueError("Season number must be 1000 or less")
return value
@validates('episode_number')
def validate_episode_number(self, key: str, value: int) -> int:
"""Validate episode number is positive."""
if value < 0:
raise ValueError("Episode number must be non-negative")
if value > 10000:
raise ValueError("Episode number must be 10000 or less")
return value
@validates('progress_percent')
def validate_progress_percent(self, key: str, value: float) -> float:
"""Validate progress is between 0 and 100."""
if value < 0.0:
raise ValueError("Progress percent must be non-negative")
if value > 100.0:
raise ValueError("Progress percent cannot exceed 100")
return value
@validates('downloaded_bytes')
def validate_downloaded_bytes(self, key: str, value: int) -> int:
"""Validate downloaded bytes is non-negative."""
if value < 0:
raise ValueError("Downloaded bytes must be non-negative")
return value
@validates('total_bytes')
def validate_total_bytes(
self, key: str, value: Optional[int]
) -> Optional[int]:
"""Validate total bytes is non-negative."""
if value is not None and value < 0:
raise ValueError("Total bytes must be non-negative")
return value
@validates('download_speed')
def validate_download_speed(
self, key: str, value: Optional[float]
) -> Optional[float]:
"""Validate download speed is non-negative."""
if value is not None and value < 0.0:
raise ValueError("Download speed must be non-negative")
return value
@validates('retry_count')
def validate_retry_count(self, key: str, value: int) -> int:
"""Validate retry count is non-negative."""
if value < 0:
raise ValueError("Retry count must be non-negative")
if value > 100:
raise ValueError("Retry count cannot exceed 100")
return value
@validates('download_url')
def validate_download_url(
self, key: str, value: Optional[str]
) -> Optional[str]:
"""Validate download URL length."""
if value is not None and len(value) > 1000:
raise ValueError("Download URL must be 1000 characters or less")
return value
@validates('file_destination')
def validate_file_destination(
self, key: str, value: Optional[str]
) -> Optional[str]:
"""Validate file destination path length."""
if value is not None and len(value) > 1000:
raise ValueError(
"File destination path must be 1000 characters or less"
)
return value
def __repr__(self) -> str:
return (
f"<DownloadQueueItem(id={self.id}, "
@@ -412,6 +586,51 @@ class UserSession(Base, TimestampMixin):
doc="Last activity timestamp"
)
@validates('session_id')
def validate_session_id(self, key: str, value: str) -> str:
"""Validate session ID length and format."""
if not value or not value.strip():
raise ValueError("Session ID cannot be empty")
if len(value) > 255:
raise ValueError("Session ID must be 255 characters or less")
return value.strip()
@validates('token_hash')
def validate_token_hash(self, key: str, value: str) -> str:
"""Validate token hash length."""
if not value or not value.strip():
raise ValueError("Token hash cannot be empty")
if len(value) > 255:
raise ValueError("Token hash must be 255 characters or less")
return value.strip()
@validates('user_id')
def validate_user_id(
self, key: str, value: Optional[str]
) -> Optional[str]:
"""Validate user ID length."""
if value is not None and len(value) > 255:
raise ValueError("User ID must be 255 characters or less")
return value
@validates('ip_address')
def validate_ip_address(
self, key: str, value: Optional[str]
) -> Optional[str]:
"""Validate IP address length (IPv4 or IPv6)."""
if value is not None and len(value) > 45:
raise ValueError("IP address must be 45 characters or less")
return value
@validates('user_agent')
def validate_user_agent(
self, key: str, value: Optional[str]
) -> Optional[str]:
"""Validate user agent length."""
if value is not None and len(value) > 500:
raise ValueError("User agent must be 500 characters or less")
return value
def __repr__(self) -> str:
return (
f"<UserSession(id={self.id}, "

View File

@@ -33,17 +33,46 @@ class AuthMiddleware(BaseHTTPMiddleware):
- For POST requests to ``/api/auth/login`` and ``/api/auth/setup``
a simple per-IP rate limiter is applied to mitigate brute-force
attempts.
- Rate limit records are periodically cleaned to prevent memory leaks.
"""
def __init__(self, app: ASGIApp, *, rate_limit_per_minute: int = 5) -> None:
def __init__(
self, app: ASGIApp, *, rate_limit_per_minute: int = 5
) -> None:
super().__init__(app)
# in-memory rate limiter: ip -> {count, window_start}
self._rate: Dict[str, Dict[str, float]] = {}
self.rate_limit_per_minute = rate_limit_per_minute
self.window_seconds = 60
# Track last cleanup time to prevent memory leaks
self._last_cleanup = time.time()
self._cleanup_interval = 300 # Clean every 5 minutes
def _cleanup_old_entries(self) -> None:
"""Remove rate limit entries older than cleanup interval.
This prevents memory leaks from accumulating old IP addresses.
"""
now = time.time()
if now - self._last_cleanup < self._cleanup_interval:
return
# Remove entries older than 2x window to be safe
cutoff = now - (self.window_seconds * 2)
old_ips = [
ip for ip, record in self._rate.items()
if record["window_start"] < cutoff
]
for ip in old_ips:
del self._rate[ip]
self._last_cleanup = now
async def dispatch(self, request: Request, call_next: Callable):
path = request.url.path or ""
# Periodically clean up old rate limit entries
self._cleanup_old_entries()
# Apply rate limiting to auth endpoints that accept credentials
if (
@@ -75,7 +104,8 @@ class AuthMiddleware(BaseHTTPMiddleware):
},
)
# If Authorization header present try to decode token and attach session
# If Authorization header present try to decode token
# and attach session
auth_header = request.headers.get("authorization")
if auth_header and auth_header.lower().startswith("bearer "):
token = auth_header.split(" ", 1)[1].strip()
@@ -87,7 +117,9 @@ class AuthMiddleware(BaseHTTPMiddleware):
# Invalid token: if this is a protected API path, reject.
# For public/auth endpoints let the dependency system handle
# optional auth and return None.
if path.startswith("/api/") and not path.startswith("/api/auth"):
is_api = path.startswith("/api/")
is_auth = path.startswith("/api/auth")
if is_api and not is_auth:
return JSONResponse(
status_code=status.HTTP_401_UNAUTHORIZED,
content={"detail": "Invalid token"}

View File

@@ -48,6 +48,10 @@ class AuthService:
self._hash: Optional[str] = settings.master_password_hash
# In-memory failed attempts per identifier. Values are dicts with
# keys: count, last, locked_until
# WARNING: In-memory storage resets on process restart.
# This is acceptable for development but PRODUCTION deployments
# should use Redis or a database to persist failed login attempts
# and prevent bypass via process restart.
self._failed: Dict[str, Dict] = {}
# Policy
self.max_attempts = 5
@@ -71,18 +75,42 @@ class AuthService:
def setup_master_password(self, password: str) -> None:
"""Set the master password (hash and store in memory/settings).
Enforces strong password requirements:
- Minimum 8 characters
- Mixed case (upper and lower)
- At least one number
- At least one special character
For now we update only the in-memory value and
settings.master_password_hash. A future task should persist this
to a config file.
Args:
password: The password to set
Raises:
ValueError: If password doesn't meet requirements
"""
# Length check
if len(password) < 8:
raise ValueError("Password must be at least 8 characters long")
# Basic strength checks
# Mixed case check
if password.islower() or password.isupper():
raise ValueError("Password must include mixed case")
raise ValueError(
"Password must include both uppercase and lowercase letters"
)
# Number check
if not any(c.isdigit() for c in password):
raise ValueError("Password must include at least one number")
# Special character check
if password.isalnum():
# encourage a special character
raise ValueError("Password should include a symbol or punctuation")
raise ValueError(
"Password must include at least one special character "
"(symbol or punctuation)"
)
h = self._hash_password(password)
self._hash = h

View File

@@ -77,6 +77,8 @@ class DownloadService:
# Queue storage by status
self._pending_queue: deque[DownloadItem] = deque()
# Helper dict for O(1) lookup of pending items by ID
self._pending_items_by_id: Dict[str, DownloadItem] = {}
self._active_downloads: Dict[str, DownloadItem] = {}
self._completed_items: deque[DownloadItem] = deque(maxlen=100)
self._failed_items: deque[DownloadItem] = deque(maxlen=50)
@@ -107,6 +109,46 @@ class DownloadService:
max_retries=max_retries,
)
def _add_to_pending_queue(
self, item: DownloadItem, front: bool = False
) -> None:
"""Add item to pending queue and update helper dict.
Args:
item: Download item to add
front: If True, add to front of queue (higher priority)
"""
if front:
self._pending_queue.appendleft(item)
else:
self._pending_queue.append(item)
self._pending_items_by_id[item.id] = item
def _remove_from_pending_queue(self, item_or_id: str) -> Optional[DownloadItem]: # noqa: E501
"""Remove item from pending queue and update helper dict.
Args:
item_or_id: Item ID to remove
Returns:
Removed item or None if not found
"""
if isinstance(item_or_id, str):
item = self._pending_items_by_id.get(item_or_id)
if not item:
return None
item_id = item_or_id
else:
item = item_or_id
item_id = item.id
try:
self._pending_queue.remove(item)
del self._pending_items_by_id[item_id]
return item
except (ValueError, KeyError):
return None
def set_broadcast_callback(self, callback: Callable) -> None:
"""Set callback for broadcasting status updates via WebSocket."""
self._broadcast_callback = callback
@@ -146,14 +188,14 @@ class DownloadService:
# Reset status if was downloading when saved
if item.status == DownloadStatus.DOWNLOADING:
item.status = DownloadStatus.PENDING
self._pending_queue.append(item)
self._add_to_pending_queue(item)
# Restore failed items that can be retried
for item_dict in data.get("failed", []):
item = DownloadItem(**item_dict)
if item.retry_count < self._max_retries:
item.status = DownloadStatus.PENDING
self._pending_queue.append(item)
self._add_to_pending_queue(item)
else:
self._failed_items.append(item)
@@ -231,10 +273,9 @@ class DownloadService:
# Insert based on priority. High-priority downloads jump the
# line via appendleft so they execute before existing work;
# everything else is appended to preserve FIFO order.
if priority == DownloadPriority.HIGH:
self._pending_queue.appendleft(item)
else:
self._pending_queue.append(item)
self._add_to_pending_queue(
item, front=(priority == DownloadPriority.HIGH)
)
created_ids.append(item.id)
@@ -293,15 +334,15 @@ class DownloadService:
logger.info("Cancelled active download", item_id=item_id)
continue
# Check pending queue
for item in list(self._pending_queue):
if item.id == item_id:
self._pending_queue.remove(item)
removed_ids.append(item_id)
logger.info(
"Removed from pending queue", item_id=item_id
)
break
# Check pending queue - O(1) lookup using helper dict
if item_id in self._pending_items_by_id:
item = self._pending_items_by_id[item_id]
self._pending_queue.remove(item)
del self._pending_items_by_id[item_id]
removed_ids.append(item_id)
logger.info(
"Removed from pending queue", item_id=item_id
)
if removed_ids:
self._save_queue()
@@ -338,24 +379,25 @@ class DownloadService:
DownloadServiceError: If reordering fails
"""
try:
# Find and remove item
item_to_move = None
for item in list(self._pending_queue):
if item.id == item_id:
self._pending_queue.remove(item)
item_to_move = item
break
# Find and remove item - O(1) lookup using helper dict
item_to_move = self._pending_items_by_id.get(item_id)
if not item_to_move:
raise DownloadServiceError(
f"Item {item_id} not found in pending queue"
)
# Remove from current position
self._pending_queue.remove(item_to_move)
del self._pending_items_by_id[item_id]
# Insert at new position
queue_list = list(self._pending_queue)
new_position = max(0, min(new_position, len(queue_list)))
queue_list.insert(new_position, item_to_move)
self._pending_queue = deque(queue_list)
# Re-add to helper dict
self._pending_items_by_id[item_id] = item_to_move
self._save_queue()
@@ -575,7 +617,7 @@ class DownloadService:
item.retry_count += 1
item.error = None
item.progress = None
self._pending_queue.append(item)
self._add_to_pending_queue(item)
retried_ids.append(item.id)
logger.info(