Compare commits


No commits in common. "86eaa8a680cdc0233dfe281409cb33ffe00142f8" and "338e3feb4af2501eaae70897c817f0cae2c786f9" have entirely different histories.

60 changed files with 3425 additions and 3175 deletions

Binary file not shown.


@@ -17,8 +17,8 @@
"keep_days": 30
},
"other": {
"master_password_hash": "$pbkdf2-sha256$29000$tRZCyFnr/d87x/i/19p7Lw$BoD8EF67N97SRs7kIX8SREbotRwvFntS.WCH9ZwTxHY",
"anime_directory": "/home/lukas/Volume/serien/"
"master_password_hash": "$pbkdf2-sha256$29000$854zxnhvzXmPsVbqvXduTQ$G0HVRAt3kyO5eFwvo.ILkpX9JdmyXYJ9MNPTS/UxAGk",
"anime_directory": "/mnt/server/serien/Serien/"
},
"version": "1.0.0"
}


@@ -17,7 +17,8 @@
"keep_days": 30
},
"other": {
"master_password_hash": "$pbkdf2-sha256$29000$JWTsXWstZYyxNiYEQAihFA$K9QPNr2J9biZEX/7SFKU94dnynvyCICrGjKtZcEu6t8"
"master_password_hash": "$pbkdf2-sha256$29000$VCqllLL2vldKyTmHkJIyZg$jNllpzlpENdgCslmS.tG.PGxRZ9pUnrqFEQFveDEcYk",
"anime_directory": "/mnt/server/serien/Serien/"
},
"version": "1.0.0"
}


@@ -17,7 +17,8 @@
"keep_days": 30
},
"other": {
"master_password_hash": "$pbkdf2-sha256$29000$1fo/x1gLYax1bs15L.X8/w$T2GKqjDG7LT9tTZIwX/P2T/uKKuM9IhOD9jmhFUw4A0"
"master_password_hash": "$pbkdf2-sha256$29000$3/t/7733PkdoTckZQyildA$Nz9SdX2ZgqBwyzhQ9FGNcnzG1X.TW9oce3sDxJbVSdY",
"anime_directory": "/mnt/server/serien/Serien/"
},
"version": "1.0.0"
}


@@ -1,24 +0,0 @@
{
"name": "Aniworld",
"data_dir": "data",
"scheduler": {
"enabled": true,
"interval_minutes": 60
},
"logging": {
"level": "INFO",
"file": null,
"max_bytes": null,
"backup_count": 3
},
"backup": {
"enabled": false,
"path": "data/backups",
"keep_days": 30
},
"other": {
"master_password_hash": "$pbkdf2-sha256$29000$nbNWSkkJIeTce48xxrh3bg$QXT6A63JqmSLimtTeI04HzC4eKfQS26xFW7UL9Ry5co",
"anime_directory": "/home/lukas/Volume/serien/"
},
"version": "1.0.0"
}


@@ -1,24 +0,0 @@
{
"name": "Aniworld",
"data_dir": "data",
"scheduler": {
"enabled": true,
"interval_minutes": 60
},
"logging": {
"level": "INFO",
"file": null,
"max_bytes": null,
"backup_count": 3
},
"backup": {
"enabled": false,
"path": "data/backups",
"keep_days": 30
},
"other": {
"master_password_hash": "$pbkdf2-sha256$29000$j5HSWuu9V.rdm9Pa2zunNA$gjQqL753WLBMZtHVOhziVn.vW3Bkq8mGtCzSkbBjSHo",
"anime_directory": "/home/lukas/Volume/serien/"
},
"version": "1.0.0"
}

data/download_queue.json (new file, 327 lines)

@@ -0,0 +1,327 @@
{
"pending": [
{
"id": "ae6424dc-558b-4946-9f07-20db1a09bf33",
"serie_id": "test-series-2",
"serie_folder": "Another Series (2024)",
"serie_name": "Another Series",
"episode": {
"season": 1,
"episode": 1,
"title": null
},
"status": "pending",
"priority": "HIGH",
"added_at": "2025-11-28T17:54:38.593236Z",
"started_at": null,
"completed_at": null,
"progress": null,
"error": null,
"retry_count": 0,
"source_url": null
},
{
"id": "011c2038-9fe3-41cb-844f-ce50c40e415f",
"serie_id": "series-high",
"serie_folder": "Series High (2024)",
"serie_name": "Series High",
"episode": {
"season": 1,
"episode": 1,
"title": null
},
"status": "pending",
"priority": "HIGH",
"added_at": "2025-11-28T17:54:38.632289Z",
"started_at": null,
"completed_at": null,
"progress": null,
"error": null,
"retry_count": 0,
"source_url": null
},
{
"id": "0eee56e0-414d-4cd7-8da7-b5a139abd8b5",
"serie_id": "series-normal",
"serie_folder": "Series Normal (2024)",
"serie_name": "Series Normal",
"episode": {
"season": 1,
"episode": 1,
"title": null
},
"status": "pending",
"priority": "NORMAL",
"added_at": "2025-11-28T17:54:38.635082Z",
"started_at": null,
"completed_at": null,
"progress": null,
"error": null,
"retry_count": 0,
"source_url": null
},
{
"id": "eea9f4f3-98e5-4041-9fc6-92e3d4c6fee6",
"serie_id": "series-low",
"serie_folder": "Series Low (2024)",
"serie_name": "Series Low",
"episode": {
"season": 1,
"episode": 1,
"title": null
},
"status": "pending",
"priority": "LOW",
"added_at": "2025-11-28T17:54:38.637038Z",
"started_at": null,
"completed_at": null,
"progress": null,
"error": null,
"retry_count": 0,
"source_url": null
},
{
"id": "b6f84ea9-86c8-4cc9-90e5-c7c6ce10c593",
"serie_id": "test-series",
"serie_folder": "Test Series (2024)",
"serie_name": "Test Series",
"episode": {
"season": 1,
"episode": 1,
"title": null
},
"status": "pending",
"priority": "NORMAL",
"added_at": "2025-11-28T17:54:38.801266Z",
"started_at": null,
"completed_at": null,
"progress": null,
"error": null,
"retry_count": 0,
"source_url": null
},
{
"id": "412aa28d-9763-41ef-913d-3d63919f9346",
"serie_id": "test-series",
"serie_folder": "Test Series (2024)",
"serie_name": "Test Series",
"episode": {
"season": 1,
"episode": 1,
"title": null
},
"status": "pending",
"priority": "NORMAL",
"added_at": "2025-11-28T17:54:38.867939Z",
"started_at": null,
"completed_at": null,
"progress": null,
"error": null,
"retry_count": 0,
"source_url": null
},
{
"id": "3a036824-2d14-41dd-81b8-094dd322a137",
"serie_id": "invalid-series",
"serie_folder": "Invalid Series (2024)",
"serie_name": "Invalid Series",
"episode": {
"season": 99,
"episode": 99,
"title": null
},
"status": "pending",
"priority": "NORMAL",
"added_at": "2025-11-28T17:54:38.935125Z",
"started_at": null,
"completed_at": null,
"progress": null,
"error": null,
"retry_count": 0,
"source_url": null
},
{
"id": "1f4108ed-5488-4f46-ad5b-fe27e3b04790",
"serie_id": "test-series",
"serie_folder": "Test Series (2024)",
"serie_name": "Test Series",
"episode": {
"season": 1,
"episode": 1,
"title": null
},
"status": "pending",
"priority": "NORMAL",
"added_at": "2025-11-28T17:54:38.968296Z",
"started_at": null,
"completed_at": null,
"progress": null,
"error": null,
"retry_count": 0,
"source_url": null
},
{
"id": "5e880954-1a9f-450a-8008-5b9d6ac07d66",
"serie_id": "series-2",
"serie_folder": "Series 2 (2024)",
"serie_name": "Series 2",
"episode": {
"season": 1,
"episode": 1,
"title": null
},
"status": "pending",
"priority": "NORMAL",
"added_at": "2025-11-28T17:54:39.055885Z",
"started_at": null,
"completed_at": null,
"progress": null,
"error": null,
"retry_count": 0,
"source_url": null
},
{
"id": "2415ac21-509b-4d71-b5b9-b824116d6785",
"serie_id": "series-0",
"serie_folder": "Series 0 (2024)",
"serie_name": "Series 0",
"episode": {
"season": 1,
"episode": 1,
"title": null
},
"status": "pending",
"priority": "NORMAL",
"added_at": "2025-11-28T17:54:39.056795Z",
"started_at": null,
"completed_at": null,
"progress": null,
"error": null,
"retry_count": 0,
"source_url": null
},
{
"id": "716f9823-d59a-4b04-863b-c75fd54bc464",
"serie_id": "series-1",
"serie_folder": "Series 1 (2024)",
"serie_name": "Series 1",
"episode": {
"season": 1,
"episode": 1,
"title": null
},
"status": "pending",
"priority": "NORMAL",
"added_at": "2025-11-28T17:54:39.057486Z",
"started_at": null,
"completed_at": null,
"progress": null,
"error": null,
"retry_count": 0,
"source_url": null
},
{
"id": "36ad4323-daa9-49c4-97e8-a0aec0cca7a1",
"serie_id": "series-4",
"serie_folder": "Series 4 (2024)",
"serie_name": "Series 4",
"episode": {
"season": 1,
"episode": 1,
"title": null
},
"status": "pending",
"priority": "NORMAL",
"added_at": "2025-11-28T17:54:39.058179Z",
"started_at": null,
"completed_at": null,
"progress": null,
"error": null,
"retry_count": 0,
"source_url": null
},
{
"id": "695ee7a9-42bb-4953-9a8a-10bd7f533369",
"serie_id": "series-3",
"serie_folder": "Series 3 (2024)",
"serie_name": "Series 3",
"episode": {
"season": 1,
"episode": 1,
"title": null
},
"status": "pending",
"priority": "NORMAL",
"added_at": "2025-11-28T17:54:39.058816Z",
"started_at": null,
"completed_at": null,
"progress": null,
"error": null,
"retry_count": 0,
"source_url": null
},
{
"id": "aa948908-c410-42ec-85d6-a0298d7d95a5",
"serie_id": "persistent-series",
"serie_folder": "Persistent Series (2024)",
"serie_name": "Persistent Series",
"episode": {
"season": 1,
"episode": 1,
"title": null
},
"status": "pending",
"priority": "NORMAL",
"added_at": "2025-11-28T17:54:39.152427Z",
"started_at": null,
"completed_at": null,
"progress": null,
"error": null,
"retry_count": 0,
"source_url": null
},
{
"id": "2537f20e-f394-4c68-81d5-48be3c0c402a",
"serie_id": "ws-series",
"serie_folder": "WebSocket Series (2024)",
"serie_name": "WebSocket Series",
"episode": {
"season": 1,
"episode": 1,
"title": null
},
"status": "pending",
"priority": "NORMAL",
"added_at": "2025-11-28T17:54:39.219061Z",
"started_at": null,
"completed_at": null,
"progress": null,
"error": null,
"retry_count": 0,
"source_url": null
},
{
"id": "aaaf3b05-cce8-47d5-b350-59c5d72533ad",
"serie_id": "workflow-series",
"serie_folder": "Workflow Test Series (2024)",
"serie_name": "Workflow Test Series",
"episode": {
"season": 1,
"episode": 1,
"title": null
},
"status": "pending",
"priority": "HIGH",
"added_at": "2025-11-28T17:54:39.254462Z",
"started_at": null,
"completed_at": null,
"progress": null,
"error": null,
"retry_count": 0,
"source_url": null
}
],
"active": [],
"failed": [],
"timestamp": "2025-11-28T17:54:39.259761+00:00"
}
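
As a quick sanity check, the new queue file can be summarized in a few lines of Python; the counts in the comments match the sixteen pending items shown above:

```python
import json
from collections import Counter

with open("data/download_queue.json", encoding="utf-8") as fh:
    queue = json.load(fh)

print(len(queue["pending"]))  # 16
print(Counter(item["priority"] for item in queue["pending"]))
# Counter({'NORMAL': 12, 'HIGH': 3, 'LOW': 1})
```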


@@ -178,6 +178,10 @@ grep -rn "data-key\|data-folder\|data-series" src/server/web/templates/ --includ
- [ ] All CRUD operations use `key` for identification
- [ ] Logging uses `key` in messages
3. **`src/server/database/migrations/`**
- [ ] Migration files maintain `key` as unique, indexed column
- [ ] No migrations that use `folder` as identifier
**Validation Commands:**
```bash


@@ -60,9 +60,10 @@ Throughout the codebase, three identifiers are used for anime series:
**Valid examples**: `"attack-on-titan"`, `"one-piece"`, `"86-eighty-six"`, `"re-zero"`
**Invalid examples**: `"Attack On Titan"`, `"attack_on_titan"`, `"attack on titan"`
### Notes
### Migration Notes
- **Backward Compatibility**: API endpoints accepting `anime_id` will check `key` first, then fall back to `folder` lookup (see the sketch below)
- **Deprecation**: Folder-based lookups are deprecated and will be removed in a future version
- **New Code**: Always use `key` for identification; `folder` is metadata only
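
A minimal sketch of that key-first, folder-fallback resolution; `resolve_series` is a hypothetical helper, while `AnimeSeriesService.get_by_key()` and `get_all()` are the service methods used elsewhere in this diff:

```python
from typing import Optional

from sqlalchemy.ext.asyncio import AsyncSession

from src.server.database.models import AnimeSeries
from src.server.database.service import AnimeSeriesService


async def resolve_series(db: AsyncSession, anime_id: str) -> Optional[AnimeSeries]:
    """Resolve anime_id: prefer the unique key, then the deprecated folder."""
    series = await AnimeSeriesService.get_by_key(db, anime_id)
    if series is not None:
        return series
    # Deprecated fallback: no get_by_folder() exists, so compare the folder
    # metadata field across all records.
    for candidate in await AnimeSeriesService.get_all(db):
        if candidate.folder == anime_id:
            return candidate
    return None
```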
## API Endpoints
@@ -163,91 +164,6 @@ All series-related WebSocket events include `key` as the primary identifier in t
- `AnimeSeriesService.get_by_id(id)` - Internal lookup by database ID
- No `get_by_folder()` method exists - folder is never used for lookups
### DownloadQueueItem Fields
| Field | Type | Purpose |
| -------------- | ----------- | ----------------------------------------- |
| `id` | String (PK) | UUID for the queue item |
| `serie_id` | String | Series key for identification |
| `serie_folder` | String | Filesystem folder path |
| `serie_name` | String | Display name for the series |
| `season` | Integer | Season number |
| `episode` | Integer | Episode number |
| `status` | Enum | pending, downloading, completed, failed |
| `priority` | Enum | low, normal, high |
| `progress` | Float | Download progress percentage (0.0-100.0) |
| `error` | String | Error message if failed |
| `retry_count` | Integer | Number of retry attempts |
| `added_at` | DateTime | When item was added to queue |
| `started_at` | DateTime | When download started (nullable) |
| `completed_at` | DateTime | When download completed/failed (nullable) |
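
For orientation, the same fields as a Python sketch; this is an illustrative mirror only, not the actual ORM model (which lives in `src/server/database/models.py` per this diff):

```python
from dataclasses import dataclass
from datetime import datetime
from typing import Optional


@dataclass
class QueueItemView:
    """Illustrative mirror of the DownloadQueueItem fields above."""

    id: str                    # UUID for the queue item
    serie_id: str              # series key for identification
    serie_folder: str          # filesystem folder path
    serie_name: str            # display name
    season: int
    episode: int
    status: str                # pending | downloading | completed | failed
    priority: str              # low | normal | high
    progress: Optional[float]  # 0.0-100.0
    error: Optional[str]
    retry_count: int
    added_at: datetime
    started_at: Optional[datetime]
    completed_at: Optional[datetime]
```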
## Data Storage
### Storage Architecture
The application uses an **SQLite database** as the primary storage for all application data.
| Data Type | Storage Location | Service |
| -------------- | ------------------ | --------------------------------------- |
| Anime Series | `data/aniworld.db` | `AnimeSeriesService` |
| Episodes | `data/aniworld.db` | `AnimeSeriesService` |
| Download Queue | `data/aniworld.db` | `DownloadService` via `QueueRepository` |
| User Sessions | `data/aniworld.db` | `AuthService` |
| Configuration | `data/config.json` | `ConfigService` |
### Download Queue Storage
The download queue is stored in SQLite via `QueueRepository`, which wraps `DownloadQueueService`:
```python
# QueueRepository provides async operations for queue items
repository = QueueRepository(session_factory)
# Save item to database
saved_item = await repository.save_item(download_item)
# Get pending items (ordered by priority and add time)
pending = await repository.get_pending_items()
# Update item status
await repository.update_status(item_id, DownloadStatus.COMPLETED)
# Update download progress
await repository.update_progress(item_id, progress=45.5, downloaded=450, total=1000, speed=2.5)
```
**Queue Persistence Features:**
- Queue state survives server restarts
- Items in `downloading` status are reset to `pending` on startup (sketched below)
- Failed items within retry limit are automatically re-queued
- Completed and failed history is preserved (with limits)
- Real-time progress updates are persisted to database
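
A minimal sketch of the startup recovery described in the list above. `update_status()` is documented in the snippet before the list; `get_items_by_status()` and `MAX_RETRIES` are assumptions for illustration:

```python
from src.server.database.models import DownloadStatus

MAX_RETRIES = 3  # assumed retry limit


async def recover_queue_on_startup(repository: "QueueRepository") -> None:
    # Downloads interrupted by a restart go back to pending.
    for item in await repository.get_items_by_status(DownloadStatus.DOWNLOADING):
        await repository.update_status(item.id, DownloadStatus.PENDING)

    # Failed items still inside the retry budget are re-queued.
    for item in await repository.get_items_by_status(DownloadStatus.FAILED):
        if item.retry_count < MAX_RETRIES:
            await repository.update_status(item.id, DownloadStatus.PENDING)
```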
### Anime Series Database Storage
```python
# Add series to database
await AnimeSeriesService.create(db_session, series_data)
# Query series by key
series = await AnimeSeriesService.get_by_key(db_session, "attack-on-titan")
# Update series
await AnimeSeriesService.update(db_session, series_id, update_data)
```
### Legacy File Storage (Deprecated)
The legacy file-based storage is **deprecated** and will be removed in v3.0.0:
- `Serie.save_to_file()` - Deprecated, use `AnimeSeriesService.create()`
- `Serie.load_from_file()` - Deprecated, use `AnimeSeriesService.get_by_key()`
- `SerieList.add()` - Deprecated, use `SerieList.add_to_db()`
Deprecation warnings are raised when using these methods.
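
One way to catch any remaining legacy call sites early is to promote those warnings to errors in the test suite; this uses only the standard `warnings` machinery:

```python
import warnings

# Fail fast whenever deprecated file-based storage is still used.
warnings.filterwarnings("error", category=DeprecationWarning)
```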
## Core Services
### SeriesApp (`src/core/SeriesApp.py`)


@@ -75,7 +75,7 @@ conda run -n AniWorld python -m uvicorn src.server.fastapi_app:app --host 127.0.
---
## Implementation Notes
## Final Implementation Notes
1. **Incremental Development**: Implement features incrementally, testing each component thoroughly before moving to the next
2. **Code Review**: Review all generated code for adherence to project standards


@@ -14,4 +14,5 @@ pytest==7.4.3
pytest-asyncio==0.21.1
httpx==0.25.2
sqlalchemy>=2.0.35
alembic==1.13.0
aiosqlite>=0.19.0


@@ -7,7 +7,7 @@
# installs dependencies, sets up the database, and starts the application.
#
# Usage:
# ./start.sh [development|production] [--no-install]
# ./start.sh [development|production] [--no-install] [--no-migrate]
#
# Environment Variables:
# ENVIRONMENT: 'development' or 'production' (default: development)
@@ -28,6 +28,7 @@ PROJECT_ROOT="$(dirname "$SCRIPT_DIR")"
CONDA_ENV="${CONDA_ENV:-AniWorld}"
ENVIRONMENT="${1:-development}"
INSTALL_DEPS="${INSTALL_DEPS:-true}"
RUN_MIGRATIONS="${RUN_MIGRATIONS:-true}"
PORT="${PORT:-8000}"
HOST="${HOST:-127.0.0.1}"
@@ -103,6 +104,20 @@ install_dependencies() {
log_success "Dependencies installed."
}
# Run database migrations
run_migrations() {
if [[ "$RUN_MIGRATIONS" != "true" ]]; then
log_warning "Skipping database migrations."
return
fi
log_info "Running database migrations..."
cd "$PROJECT_ROOT"
conda run -n "$CONDA_ENV" \
python -m alembic upgrade head 2>/dev/null || log_warning "No migrations to run."
log_success "Database migrations completed."
}
# Initialize database
init_database() {
log_info "Initializing database..."
@@ -205,6 +220,10 @@ main() {
INSTALL_DEPS="false"
shift
;;
--no-migrate)
RUN_MIGRATIONS="false"
shift
;;
*)
ENVIRONMENT="$1"
shift
@@ -218,6 +237,7 @@
create_env_file
install_dependencies
init_database
run_migrations
start_application
}

Binary file not shown.

Binary file not shown.

Binary file not shown.


@@ -3,23 +3,14 @@ SerieScanner - Scans directories for anime series and missing episodes.
This module provides functionality to scan anime directories, identify
missing episodes, and report progress through callback interfaces.
The scanner supports two modes of operation:
1. File-based mode (legacy): Saves scan results to data files
2. Database mode (preferred): Saves scan results to SQLite database
Database mode is preferred for new code. File-based mode is kept for
backward compatibility with CLI usage.
"""
from __future__ import annotations
import logging
import os
import re
import traceback
import uuid
import warnings
from typing import TYPE_CHECKING, Callable, Iterable, Iterator, Optional
from typing import Callable, Iterable, Iterator, Optional
from src.core.entities.series import Serie
from src.core.exceptions.Exceptions import MatchNotFoundError, NoKeyFoundException
@@ -33,11 +24,6 @@ from src.core.interfaces.callbacks import (
)
from src.core.providers.base_provider import Loader
if TYPE_CHECKING:
from sqlalchemy.ext.asyncio import AsyncSession
from src.server.database.models import AnimeSeries
logger = logging.getLogger(__name__)
error_logger = logging.getLogger("error")
no_key_found_logger = logging.getLogger("series.nokey")
@@ -48,28 +34,13 @@
Scans directories for anime series and identifies missing episodes.
Supports progress callbacks for real-time scanning updates.
The scanner supports two modes:
1. File-based (legacy): Set db_session=None, saves to data files
2. Database mode: Provide db_session, saves to SQLite database
Example:
# File-based mode (legacy)
scanner = SerieScanner("/path/to/anime", loader)
scanner.scan()
# Database mode (preferred)
async with get_db_session() as db:
scanner = SerieScanner("/path/to/anime", loader, db_session=db)
await scanner.scan_async()
"""
def __init__(
self,
basePath: str,
loader: Loader,
callback_manager: Optional[CallbackManager] = None,
db_session: Optional["AsyncSession"] = None
callback_manager: Optional[CallbackManager] = None
) -> None:
"""
Initialize the SerieScanner.
@@ -78,8 +49,6 @@ class SerieScanner:
basePath: Base directory containing anime series
loader: Loader instance for fetching series information
callback_manager: Optional callback manager for progress updates
db_session: Optional database session for database mode.
If provided, scan_async() should be used instead of scan().
Raises:
ValueError: If basePath is invalid or doesn't exist
@@ -102,7 +71,6 @@ class SerieScanner:
callback_manager or CallbackManager()
)
self._current_operation_id: Optional[str] = None
self._db_session: Optional["AsyncSession"] = db_session
logger.info("Initialized SerieScanner with base path: %s", abs_path)
@@ -129,14 +97,7 @@
callback: Optional[Callable[[str, int], None]] = None
) -> None:
"""
Scan directories for anime series and missing episodes (file-based).
This method saves results to data files. For database storage,
use scan_async() instead.
.. deprecated:: 2.0.0
Use :meth:`scan_async` for database-backed storage.
File-based storage will be removed in a future version.
Scan directories for anime series and missing episodes.
Args:
callback: Optional legacy callback function (folder, count)
@@ -144,12 +105,6 @@
Raises:
Exception: If scan fails critically
"""
warnings.warn(
"File-based scan() is deprecated. Use scan_async() for "
"database storage.",
DeprecationWarning,
stacklevel=2
)
# Generate unique operation ID
self._current_operation_id = str(uuid.uuid4())
@@ -336,365 +291,6 @@
raise
async def scan_async(
self,
db: "AsyncSession",
callback: Optional[Callable[[str, int], None]] = None
) -> None:
"""
Scan directories for anime series and save to database.
This is the preferred method for scanning when using database
storage. Results are saved to the database instead of files.
Args:
db: Database session for async operations
callback: Optional legacy callback function (folder, count)
Raises:
Exception: If scan fails critically
Example:
async with get_db_session() as db:
scanner = SerieScanner("/path/to/anime", loader)
await scanner.scan_async(db)
"""
# Generate unique operation ID
self._current_operation_id = str(uuid.uuid4())
logger.info("Starting async scan for missing episodes (database mode)")
# Notify scan starting
self._callback_manager.notify_progress(
ProgressContext(
operation_type=OperationType.SCAN,
operation_id=self._current_operation_id,
phase=ProgressPhase.STARTING,
current=0,
total=0,
percentage=0.0,
message="Initializing scan (database mode)"
)
)
try:
# Get total items to process
total_to_scan = self.get_total_to_scan()
logger.info("Total folders to scan: %d", total_to_scan)
result = self.__find_mp4_files()
counter = 0
saved_to_db = 0
for folder, mp4_files in result:
try:
counter += 1
# Calculate progress
if total_to_scan > 0:
percentage = (counter / total_to_scan) * 100
else:
percentage = 0.0
# Notify progress
self._callback_manager.notify_progress(
ProgressContext(
operation_type=OperationType.SCAN,
operation_id=self._current_operation_id,
phase=ProgressPhase.IN_PROGRESS,
current=counter,
total=total_to_scan,
percentage=percentage,
message=f"Scanning: {folder}",
details=f"Found {len(mp4_files)} episodes"
)
)
# Call legacy callback if provided
if callback:
callback(folder, counter)
serie = self.__read_data_from_file(folder)
if (
serie is not None
and serie.key
and serie.key.strip()
):
# Get missing episodes from provider
missing_episodes, _site = (
self.__get_missing_episodes_and_season(
serie.key, mp4_files
)
)
serie.episodeDict = missing_episodes
serie.folder = folder
# Save to database instead of file
await self._save_serie_to_db(serie, db)
saved_to_db += 1
# Store by key in memory cache
if serie.key in self.keyDict:
logger.error(
"Duplicate series found with key '%s' "
"(folder: '%s')",
serie.key,
folder
)
else:
self.keyDict[serie.key] = serie
logger.debug(
"Stored series with key '%s' (folder: '%s')",
serie.key,
folder
)
except NoKeyFoundException as nkfe:
error_msg = f"Error processing folder '{folder}': {nkfe}"
logger.error(error_msg)
self._callback_manager.notify_error(
ErrorContext(
operation_type=OperationType.SCAN,
operation_id=self._current_operation_id,
error=nkfe,
message=error_msg,
recoverable=True,
metadata={"folder": folder, "key": None}
)
)
except Exception as e:
error_msg = (
f"Folder: '{folder}' - Unexpected error: {e}"
)
error_logger.error(
"%s\n%s",
error_msg,
traceback.format_exc()
)
self._callback_manager.notify_error(
ErrorContext(
operation_type=OperationType.SCAN,
operation_id=self._current_operation_id,
error=e,
message=error_msg,
recoverable=True,
metadata={"folder": folder, "key": None}
)
)
continue
# Notify scan completion
self._callback_manager.notify_completion(
CompletionContext(
operation_type=OperationType.SCAN,
operation_id=self._current_operation_id,
success=True,
message=f"Scan completed. Processed {counter} folders.",
statistics={
"total_folders": counter,
"series_found": len(self.keyDict),
"saved_to_db": saved_to_db
}
)
)
logger.info(
"Async scan completed. Processed %d folders, "
"found %d series, saved %d to database",
counter,
len(self.keyDict),
saved_to_db
)
except Exception as e:
error_msg = f"Critical async scan error: {e}"
logger.error("%s\n%s", error_msg, traceback.format_exc())
self._callback_manager.notify_error(
ErrorContext(
operation_type=OperationType.SCAN,
operation_id=self._current_operation_id,
error=e,
message=error_msg,
recoverable=False
)
)
self._callback_manager.notify_completion(
CompletionContext(
operation_type=OperationType.SCAN,
operation_id=self._current_operation_id,
success=False,
message=error_msg
)
)
raise
async def _save_serie_to_db(
self,
serie: Serie,
db: "AsyncSession"
) -> Optional["AnimeSeries"]:
"""
Save or update a series in the database.
Creates a new record if the series doesn't exist, or updates
the episodes if they have changed.
Args:
serie: Serie instance to save
db: Database session for async operations
Returns:
Created or updated AnimeSeries instance, or None if unchanged
"""
from src.server.database.service import AnimeSeriesService, EpisodeService
# Check if series already exists
existing = await AnimeSeriesService.get_by_key(db, serie.key)
if existing:
# Build existing episode dict from episodes for comparison
existing_episodes = await EpisodeService.get_by_series(
db, existing.id
)
existing_dict: dict[int, list[int]] = {}
for ep in existing_episodes:
if ep.season not in existing_dict:
existing_dict[ep.season] = []
existing_dict[ep.season].append(ep.episode_number)
for season in existing_dict:
existing_dict[season].sort()
# Update episodes if changed
if existing_dict != serie.episodeDict:
# Add new episodes
new_dict = serie.episodeDict or {}
for season, episode_numbers in new_dict.items():
existing_eps = set(existing_dict.get(season, []))
for ep_num in episode_numbers:
if ep_num not in existing_eps:
await EpisodeService.create(
db=db,
series_id=existing.id,
season=season,
episode_number=ep_num,
)
# Update folder if changed
if existing.folder != serie.folder:
await AnimeSeriesService.update(
db,
existing.id,
folder=serie.folder
)
logger.info(
"Updated series in database: %s (key=%s)",
serie.name,
serie.key
)
return existing
else:
logger.debug(
"Series unchanged in database: %s (key=%s)",
serie.name,
serie.key
)
return None
else:
# Create new series
anime_series = await AnimeSeriesService.create(
db=db,
key=serie.key,
name=serie.name,
site=serie.site,
folder=serie.folder,
)
# Create Episode records
if serie.episodeDict:
for season, episode_numbers in serie.episodeDict.items():
for ep_num in episode_numbers:
await EpisodeService.create(
db=db,
series_id=anime_series.id,
season=season,
episode_number=ep_num,
)
logger.info(
"Created series in database: %s (key=%s)",
serie.name,
serie.key
)
return anime_series
async def _update_serie_in_db(
self,
serie: Serie,
db: "AsyncSession"
) -> Optional["AnimeSeries"]:
"""
Update an existing series in the database.
Args:
serie: Serie instance to update
db: Database session for async operations
Returns:
Updated AnimeSeries instance, or None if not found
"""
from src.server.database.service import AnimeSeriesService, EpisodeService
existing = await AnimeSeriesService.get_by_key(db, serie.key)
if not existing:
logger.warning(
"Cannot update non-existent series: %s (key=%s)",
serie.name,
serie.key
)
return None
# Update basic fields
await AnimeSeriesService.update(
db,
existing.id,
name=serie.name,
site=serie.site,
folder=serie.folder,
)
# Update episodes - add any new ones
if serie.episodeDict:
existing_episodes = await EpisodeService.get_by_series(
db, existing.id
)
existing_dict: dict[int, set[int]] = {}
for ep in existing_episodes:
if ep.season not in existing_dict:
existing_dict[ep.season] = set()
existing_dict[ep.season].add(ep.episode_number)
for season, episode_numbers in serie.episodeDict.items():
existing_eps = existing_dict.get(season, set())
for ep_num in episode_numbers:
if ep_num not in existing_eps:
await EpisodeService.create(
db=db,
series_id=existing.id,
season=season,
episode_number=ep_num,
)
logger.info(
"Updated series in database: %s (key=%s)",
serie.name,
serie.key
)
return existing
def __find_mp4_files(self) -> Iterator[tuple[str, list[str]]]:
"""Find all .mp4 files in the directory structure."""
logger.info("Scanning for .mp4 files")


@@ -8,16 +8,10 @@ progress reporting, and error handling.
import asyncio
import logging
import warnings
from typing import Any, Dict, List, Optional
from events import Events
try:
from sqlalchemy.ext.asyncio import AsyncSession
except ImportError: # pragma: no cover - optional dependency
AsyncSession = object # type: ignore
from src.core.entities.SerieList import SerieList
from src.core.entities.series import Serie
from src.core.providers.provider_factory import Loaders
@@ -136,20 +130,15 @@ class SeriesApp:
def __init__(
self,
directory_to_search: str,
db_session: Optional[AsyncSession] = None,
):
"""
Initialize SeriesApp.
Args:
directory_to_search: Base directory for anime series
db_session: Optional database session for database-backed
storage. When provided, SerieList and SerieScanner will
use the database instead of file-based storage.
"""
self.directory_to_search = directory_to_search
self._db_session = db_session
# Initialize events
self._events = Events()
@@ -158,20 +147,15 @@ class SeriesApp:
self.loaders = Loaders()
self.loader = self.loaders.GetLoader(key="aniworld.to")
self.serie_scanner = SerieScanner(
directory_to_search, self.loader, db_session=db_session
)
self.list = SerieList(
self.directory_to_search, db_session=db_session
)
self.serie_scanner = SerieScanner(directory_to_search, self.loader)
self.list = SerieList(self.directory_to_search)
# Synchronous init used during constructor to avoid awaiting
# in __init__
self._init_list_sync()
logger.info(
"SeriesApp initialized for directory: %s (db_session: %s)",
directory_to_search,
"provided" if db_session else "none"
"SeriesApp initialized for directory: %s",
directory_to_search
)
@property
@@ -204,53 +188,6 @@ class SeriesApp:
"""Set scan_status event handler."""
self._events.scan_status = value
@property
def db_session(self) -> Optional[AsyncSession]:
"""
Get the database session.
Returns:
AsyncSession or None: The database session if configured
"""
return self._db_session
def set_db_session(self, session: Optional[AsyncSession]) -> None:
"""
Update the database session.
Also updates the db_session on SerieList and SerieScanner.
Args:
session: The new database session or None
"""
self._db_session = session
self.list._db_session = session
self.serie_scanner._db_session = session
logger.debug(
"Database session updated: %s",
"provided" if session else "none"
)
async def init_from_db_async(self) -> None:
"""
Initialize series list from database (async).
This should be called when using database storage instead of
the synchronous file-based initialization.
"""
if self._db_session:
await self.list.load_series_from_db(self._db_session)
self.series_list = self.list.GetMissingEpisode()
logger.debug(
"Loaded %d series with missing episodes from database",
len(self.series_list)
)
else:
warnings.warn(
"init_from_db_async called without db_session configured",
UserWarning
)
def _init_list_sync(self) -> None:
"""Synchronous initialization helper for constructor."""
self.series_list = self.list.GetMissingEpisode()


@@ -1,120 +1,41 @@
"""Utilities for loading and managing stored anime series metadata.
This module provides the SerieList class for managing collections of anime
series metadata. It supports both file-based and database-backed storage.
The class can operate in two modes:
1. File-based mode (legacy): Reads/writes data files from disk
2. Database mode: Reads/writes to SQLite database via AnimeSeriesService
Database mode is preferred for new code. File-based mode is kept for
backward compatibility with CLI usage.
"""
from __future__ import annotations
"""Utilities for loading and managing stored anime series metadata."""
import logging
import os
import warnings
from json import JSONDecodeError
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional
from typing import Dict, Iterable, List, Optional
from src.core.entities.series import Serie
if TYPE_CHECKING:
from sqlalchemy.ext.asyncio import AsyncSession
from src.server.database.models import AnimeSeries
logger = logging.getLogger(__name__)
class SerieList:
"""
Represents the collection of cached series stored on disk or in the database.
Represents the collection of cached series stored on disk.
Series are identified by their unique 'key' (provider identifier).
The 'folder' is metadata only and not used for lookups.
The class supports two modes of operation:
1. File-based mode (legacy):
Initialize without db_session to use file-based storage.
Series are loaded from 'data' files in the anime directory.
2. Database mode (preferred):
Pass db_session to use database-backed storage via AnimeSeriesService.
Series are loaded from the AnimeSeries table.
Example:
# File-based mode (legacy)
serie_list = SerieList("/path/to/anime")
# Database mode (preferred)
async with get_db_session() as db:
serie_list = SerieList("/path/to/anime", db_session=db)
await serie_list.load_series_from_db()
Attributes:
directory: Path to the anime directory
keyDict: Internal dictionary mapping serie.key to Serie objects
_db_session: Optional database session for database mode
"""
def __init__(
self,
base_path: str,
db_session: Optional["AsyncSession"] = None,
skip_load: bool = False
) -> None:
"""Initialize the SerieList.
Args:
base_path: Path to the anime directory
db_session: Optional database session for database mode.
If provided, use load_series_from_db() instead of
the automatic file-based loading.
skip_load: If True, skip automatic loading of series.
Useful when using database mode to allow async loading.
"""
def __init__(self, base_path: str) -> None:
self.directory: str = base_path
# Internal storage using serie.key as the dictionary key
self.keyDict: Dict[str, Serie] = {}
self._db_session: Optional["AsyncSession"] = db_session
# Only auto-load from files if no db_session and not skipping
if not skip_load and db_session is None:
self.load_series()
self.load_series()
def add(self, serie: Serie) -> None:
"""
Persist a new series if it is not already present (file-based mode).
Persist a new series if it is not already present.
Uses serie.key for identification. The serie.folder is used for
filesystem operations only.
.. deprecated:: 2.0.0
Use :meth:`add_to_db` for database-backed storage.
File-based storage will be removed in a future version.
Args:
serie: The Serie instance to add
Note:
This method creates data files on disk. For database storage,
use add_to_db() instead.
"""
if self.contains(serie.key):
return
warnings.warn(
"File-based storage via add() is deprecated. "
"Use add_to_db() for database storage.",
DeprecationWarning,
stacklevel=2
)
data_path = os.path.join(self.directory, serie.folder, "data")
anime_path = os.path.join(self.directory, serie.folder)
os.makedirs(anime_path, exist_ok=True)
@@ -123,73 +44,6 @@ class SerieList:
# Store by key, not folder
self.keyDict[serie.key] = serie
async def add_to_db(
self,
serie: Serie,
db: "AsyncSession"
) -> Optional["AnimeSeries"]:
"""
Add a series to the database.
Uses serie.key for identification. Creates a new AnimeSeries
record in the database if it doesn't already exist.
Args:
serie: The Serie instance to add
db: Database session for async operations
Returns:
Created AnimeSeries instance, or None if already exists
Example:
async with get_db_session() as db:
result = await serie_list.add_to_db(serie, db)
if result:
print(f"Added series: {result.name}")
"""
from src.server.database.service import AnimeSeriesService, EpisodeService
# Check if series already exists in DB
existing = await AnimeSeriesService.get_by_key(db, serie.key)
if existing:
logger.debug(
"Series already exists in database: %s (key=%s)",
serie.name,
serie.key
)
return None
# Create new series in database
anime_series = await AnimeSeriesService.create(
db=db,
key=serie.key,
name=serie.name,
site=serie.site,
folder=serie.folder,
)
# Create Episode records for each episode in episodeDict
if serie.episodeDict:
for season, episode_numbers in serie.episodeDict.items():
for episode_number in episode_numbers:
await EpisodeService.create(
db=db,
series_id=anime_series.id,
season=season,
episode_number=episode_number,
)
# Also add to in-memory collection
self.keyDict[serie.key] = serie
logger.info(
"Added series to database: %s (key=%s)",
serie.name,
serie.key
)
return anime_series
def contains(self, key: str) -> bool:
"""
Return True when a series identified by ``key`` already exists.
@@ -253,112 +107,6 @@ class SerieList:
error,
)
async def load_series_from_db(self, db: "AsyncSession") -> int:
"""
Load all series from the database into the in-memory collection.
This is the preferred method for populating the series list
when using database-backed storage.
Args:
db: Database session for async operations
Returns:
Number of series loaded from the database
Example:
async with get_db_session() as db:
serie_list = SerieList("/path/to/anime", skip_load=True)
count = await serie_list.load_series_from_db(db)
print(f"Loaded {count} series from database")
"""
from src.server.database.service import AnimeSeriesService
# Clear existing in-memory data
self.keyDict.clear()
# Load all series from database (with episodes for episodeDict)
anime_series_list = await AnimeSeriesService.get_all(
db, with_episodes=True
)
for anime_series in anime_series_list:
serie = self._convert_from_db(anime_series)
self.keyDict[serie.key] = serie
logger.info(
"Loaded %d series from database",
len(self.keyDict)
)
return len(self.keyDict)
@staticmethod
def _convert_from_db(anime_series: "AnimeSeries") -> Serie:
"""
Convert an AnimeSeries database model to a Serie entity.
Args:
anime_series: AnimeSeries model from database
(must have episodes relationship loaded)
Returns:
Serie entity instance
"""
# Build episode_dict from episodes relationship
episode_dict: dict[int, list[int]] = {}
if anime_series.episodes:
for episode in anime_series.episodes:
season = episode.season
if season not in episode_dict:
episode_dict[season] = []
episode_dict[season].append(episode.episode_number)
# Sort episode numbers within each season
for season in episode_dict:
episode_dict[season].sort()
return Serie(
key=anime_series.key,
name=anime_series.name,
site=anime_series.site,
folder=anime_series.folder,
episodeDict=episode_dict
)
@staticmethod
def _convert_to_db_dict(serie: Serie) -> dict:
"""
Convert a Serie entity to a dictionary for database creation.
Args:
serie: Serie entity instance
Returns:
Dictionary suitable for AnimeSeriesService.create()
"""
return {
"key": serie.key,
"name": serie.name,
"site": serie.site,
"folder": serie.folder,
}
async def contains_in_db(self, key: str, db: "AsyncSession") -> bool:
"""
Check if a series with the given key exists in the database.
Args:
key: The unique provider identifier for the series
db: Database session for async operations
Returns:
True if the series exists in the database
"""
from src.server.database.service import AnimeSeriesService
existing = await AnimeSeriesService.get_by_key(db, key)
return existing is not None
def GetMissingEpisode(self) -> List[Serie]:
"""Return all series that still contain missing episodes."""
return [


@@ -1,5 +1,4 @@
import json
import warnings
class Serie:
@@ -155,46 +154,13 @@ class Serie:
)
def save_to_file(self, filename: str):
"""Save Serie object to JSON file.
.. deprecated::
File-based storage is deprecated. Use database storage via
`AnimeSeriesService.create()` instead. This method will be
removed in v3.0.0.
Args:
filename: Path to save the JSON file
"""
warnings.warn(
"save_to_file() is deprecated and will be removed in v3.0.0. "
"Use database storage via AnimeSeriesService.create() instead.",
DeprecationWarning,
stacklevel=2
)
"""Save Serie object to JSON file."""
with open(filename, "w", encoding="utf-8") as file:
json.dump(self.to_dict(), file, indent=4)
@classmethod
def load_from_file(cls, filename: str) -> "Serie":
"""Load Serie object from JSON file.
.. deprecated::
File-based storage is deprecated. Use database storage via
`AnimeSeriesService.get_by_key()` instead. This method will be
removed in v3.0.0.
Args:
filename: Path to load the JSON file from
Returns:
Serie: The loaded Serie object
"""
warnings.warn(
"load_from_file() is deprecated and will be removed in v3.0.0. "
"Use database storage via AnimeSeriesService instead.",
DeprecationWarning,
stacklevel=2
)
"""Load Serie object from JSON file."""
with open(filename, "r", encoding="utf-8") as file:
data = json.load(file)
return cls.from_dict(data)


@@ -229,6 +229,37 @@ class DatabaseIntegrityChecker:
logger.warning(msg)
issues_found += count
# Check for invalid progress percentages
stmt = select(DownloadQueueItem).where(
(DownloadQueueItem.progress < 0) |
(DownloadQueueItem.progress > 100)
)
invalid_progress = self.session.execute(stmt).scalars().all()
if invalid_progress:
count = len(invalid_progress)
msg = (
f"Found {count} queue items with invalid progress "
f"percentages"
)
self.issues.append(msg)
logger.warning(msg)
issues_found += count
# Check for queue items with invalid status
valid_statuses = {'pending', 'downloading', 'completed', 'failed'}
stmt = select(DownloadQueueItem).where(
~DownloadQueueItem.status.in_(valid_statuses)
)
invalid_status = self.session.execute(stmt).scalars().all()
if invalid_status:
count = len(invalid_status)
msg = f"Found {count} queue items with invalid status"
self.issues.append(msg)
logger.warning(msg)
issues_found += count
if issues_found == 0:
logger.info("No data consistency issues found")


@@ -4,14 +4,11 @@ from typing import Any, List, Optional
from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import BaseModel, Field
from sqlalchemy.ext.asyncio import AsyncSession
from src.core.entities.series import Serie
from src.server.database.service import AnimeSeriesService
from src.server.services.anime_service import AnimeService, AnimeServiceError
from src.server.utils.dependencies import (
get_anime_service,
get_optional_database_session,
get_series_app,
require_auth,
)
@@ -585,7 +582,6 @@
request: AddSeriesRequest,
_auth: dict = Depends(require_auth),
series_app: Any = Depends(get_series_app),
db: Optional[AsyncSession] = Depends(get_optional_database_session),
) -> dict:
"""Add a new series to the library.
@@ -593,9 +589,6 @@
The `key` is the URL-safe identifier used for all lookups.
The `name` is stored as display metadata along with a
filesystem-friendly `folder` name derived from the name.
Series are saved to the database using AnimeSeriesService when a
database is available, falling back to in-memory storage otherwise.
Args:
request: Request containing the series link and name.
@@ -603,10 +596,9 @@
- name: Display name for the series
_auth: Ensures the caller is authenticated (value unused)
series_app: Core `SeriesApp` instance provided via dependency
db: Optional database session for async operations
Returns:
Dict[str, Any]: Status payload with success message, key, and db_id
Dict[str, Any]: Status payload with success message and key
Raises:
HTTPException: If adding the series fails or link is invalid
@@ -625,6 +617,13 @@
detail="Series name cannot be empty",
)
# Check if series_app has the list attribute
if not hasattr(series_app, "list"):
raise HTTPException(
status_code=status.HTTP_501_NOT_IMPLEMENTED,
detail="Series list functionality not available",
)
# Extract key from link URL
# Expected format: https://aniworld.to/anime/stream/{key}
link = request.link.strip()
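
The extraction logic itself is elided between these hunks; a hypothetical version consistent with the documented URL shape could look like this:

```python
from urllib.parse import urlparse


def extract_key(link: str) -> str:
    # Take the last path segment of https://aniworld.to/anime/stream/{key}.
    segments = [s for s in urlparse(link.strip()).path.split("/") if s]
    if len(segments) < 3 or segments[:2] != ["anime", "stream"]:
        raise ValueError(f"Unexpected link format: {link}")
    return segments[-1]


assert extract_key("https://aniworld.to/anime/stream/attack-on-titan") == "attack-on-titan"
```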
@@ -647,68 +646,36 @@
# Create folder from name (filesystem-friendly)
folder = request.name.strip()
db_id = None
# Try to save to database if available
if db is not None:
# Check if series already exists in database
existing = await AnimeSeriesService.get_by_key(db, key)
if existing:
return {
"status": "exists",
"message": f"Series already exists: {request.name}",
"key": key,
"folder": existing.folder,
"db_id": existing.id
}
# Save to database using AnimeSeriesService
anime_series = await AnimeSeriesService.create(
db=db,
key=key,
name=request.name.strip(),
site="aniworld.to",
folder=folder,
)
db_id = anime_series.id
logger.info(
"Added series to database: %s (key=%s, db_id=%d)",
request.name,
key,
db_id
)
# Create a new Serie object
# key: unique identifier extracted from link
# name: display name from request
# folder: filesystem folder name (derived from name)
# episodeDict: empty for new series
serie = Serie(
key=key,
name=request.name.strip(),
site="aniworld.to",
folder=folder,
episodeDict={}
)
# Also add to in-memory cache if series_app has the list attribute
if series_app and hasattr(series_app, "list"):
serie = Serie(
key=key,
name=request.name.strip(),
site="aniworld.to",
folder=folder,
episodeDict={}
)
# Add to in-memory cache
if hasattr(series_app.list, 'keyDict'):
# Direct update without file saving
series_app.list.keyDict[key] = serie
elif hasattr(series_app.list, 'add'):
# Legacy: use add method (may create file with deprecation warning)
with warnings.catch_warnings():
warnings.simplefilter("ignore", DeprecationWarning)
series_app.list.add(serie)
# Add the series to the list
series_app.list.add(serie)
# Refresh the series list to update the cache
if hasattr(series_app, "refresh_series_list"):
series_app.refresh_series_list()
return {
"status": "success",
"message": f"Successfully added series: {request.name}",
"key": key,
"folder": folder,
"db_id": db_id
"folder": folder
}
except HTTPException:
raise
except Exception as exc:
logger.error("Failed to add series: %s", exc, exc_info=True)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to add series: {str(exc)}",


@@ -26,7 +26,7 @@ optional_bearer = HTTPBearer(auto_error=False)
@router.post("/setup", status_code=http_status.HTTP_201_CREATED)
async def setup_auth(req: SetupRequest):
def setup_auth(req: SetupRequest):
"""Initial setup endpoint to configure the master password.
This endpoint also initializes the configuration with default values
@@ -57,20 +57,17 @@ async def setup_auth(req: SetupRequest):
config.other['master_password_hash'] = password_hash
# Store anime directory in config's other field if provided
anime_directory = None
if hasattr(req, 'anime_directory') and req.anime_directory:
anime_directory = req.anime_directory.strip()
if anime_directory:
config.other['anime_directory'] = anime_directory
config.other['anime_directory'] = req.anime_directory
# Save the config with the password hash and anime directory
config_service.save_config(config, create_backup=False)
return {"status": "ok"}
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e)) from e
return {"status": "ok"}
@router.post("/login", response_model=LoginResponse)
def login(req: LoginRequest):

View File

@@ -1,4 +1,4 @@
from typing import Any, Dict, List, Optional
from typing import Dict, List, Optional
from fastapi import APIRouter, Depends, HTTPException, status
@@ -210,10 +210,10 @@ def update_advanced_config(
) from e
@router.post("/directory", response_model=Dict[str, Any])
async def update_directory(
@router.post("/directory", response_model=Dict[str, str])
def update_directory(
directory_config: Dict[str, str], auth: dict = Depends(require_auth)
) -> Dict[str, Any]:
) -> Dict[str, str]:
"""Update anime directory configuration.
Args:
@@ -235,15 +235,13 @@ async def update_directory(
app_config = config_service.load_config()
# Store directory in other section
app_config.other["anime_directory"] = directory
if "anime_directory" not in app_config.other:
app_config.other["anime_directory"] = directory
else:
app_config.other["anime_directory"] = directory
config_service.save_config(app_config)
response: Dict[str, Any] = {
"message": "Anime directory updated successfully"
}
return response
return {"message": "Anime directory updated successfully"}
except ConfigServiceError as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,


@@ -13,7 +13,7 @@ This package provides persistent storage for anime series, episodes, download qu
Install required dependencies:
```bash
pip install sqlalchemy aiosqlite
pip install sqlalchemy alembic aiosqlite
```
Or use the project requirements:
@@ -163,6 +163,24 @@ from src.config.settings import settings
settings.database_url = "sqlite:///./data/aniworld.db"
```
## Migrations (Future)
Alembic is installed for database migrations:
```bash
# Initialize Alembic
alembic init alembic
# Generate migration
alembic revision --autogenerate -m "Description"
# Apply migrations
alembic upgrade head
# Rollback
alembic downgrade -1
```
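
The same upgrade can also be applied programmatically, which is what `start.sh` effectively does via `python -m alembic upgrade head`; this sketch assumes the default `alembic.ini` generated by `alembic init`:

```python
from alembic import command
from alembic.config import Config


def upgrade_to_head(ini_path: str = "alembic.ini") -> None:
    # Equivalent to `alembic upgrade head` on the command line.
    command.upgrade(Config(ini_path), "head")
```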
## Testing
Run database tests:
@@ -178,7 +196,8 @@ The test suite uses an in-memory SQLite database for isolation and speed.
- **base.py**: Base declarative class and mixins
- **models.py**: SQLAlchemy ORM models (4 models)
- **connection.py**: Engine, session factory, dependency injection
- **__init__.py**: Package exports
- **migrations.py**: Alembic migration placeholder
- **__init__.py**: Package exports
- **service.py**: Service layer with CRUD operations
## Service Layer
@@ -413,4 +432,5 @@ Solution: Ensure referenced records exist before creating relationships.
## Further Reading
- [SQLAlchemy 2.0 Documentation](https://docs.sqlalchemy.org/en/20/)
- [Alembic Tutorial](https://alembic.sqlalchemy.org/en/latest/tutorial.html)
- [FastAPI with Databases](https://fastapi.tiangolo.com/tutorial/sql-databases/)

View File

@@ -30,6 +30,7 @@ from src.server.database.init import (
create_database_backup,
create_database_schema,
get_database_info,
get_migration_guide,
get_schema_version,
initialize_database,
seed_initial_data,
@@ -63,6 +64,7 @@ __all__ = [
"check_database_health",
"create_database_backup",
"get_database_info",
"get_migration_guide",
"CURRENT_SCHEMA_VERSION",
"EXPECTED_TABLES",
# Models

View File

@@ -86,24 +86,19 @@ async def init_db() -> None:
db_url = _get_database_url()
logger.info(f"Initializing database: {db_url}")
# Build engine kwargs based on database type
is_sqlite = "sqlite" in db_url
engine_kwargs = {
"echo": settings.log_level == "DEBUG",
"poolclass": pool.StaticPool if is_sqlite else pool.QueuePool,
"pool_pre_ping": True,
}
# Only add pool_size and max_overflow for non-SQLite databases
if not is_sqlite:
engine_kwargs["pool_size"] = 5
engine_kwargs["max_overflow"] = 10
# Create async engine
_engine = create_async_engine(db_url, **engine_kwargs)
_engine = create_async_engine(
db_url,
echo=settings.log_level == "DEBUG",
poolclass=pool.StaticPool if "sqlite" in db_url else pool.QueuePool,
pool_size=5 if "sqlite" not in db_url else None,
max_overflow=10 if "sqlite" not in db_url else None,
pool_pre_ping=True,
future=True,
)
# Configure SQLite if needed
if is_sqlite:
if "sqlite" in db_url:
_configure_sqlite_engine(_engine)
# Create async session factory
@@ -117,13 +112,12 @@ async def init_db() -> None:
# Create sync engine for initial setup
sync_url = settings.database_url
is_sqlite_sync = "sqlite" in sync_url
sync_engine_kwargs = {
"echo": settings.log_level == "DEBUG",
"poolclass": pool.StaticPool if is_sqlite_sync else pool.QueuePool,
"pool_pre_ping": True,
}
_sync_engine = create_engine(sync_url, **sync_engine_kwargs)
_sync_engine = create_engine(
sync_url,
echo=settings.log_level == "DEBUG",
poolclass=pool.StaticPool if "sqlite" in sync_url else pool.QueuePool,
pool_pre_ping=True,
)
# Create sync session factory
_sync_session_factory = sessionmaker(
@@ -264,35 +258,3 @@ def get_sync_session() -> Session:
)
return _sync_session_factory()
def get_async_session_factory() -> AsyncSession:
"""Get a new async database session (factory function).
Creates a new session instance for use in repository patterns.
The caller is responsible for committing/rolling back and closing.
Returns:
AsyncSession: New database session for async operations
Raises:
RuntimeError: If database is not initialized
Example:
session = get_async_session_factory()
try:
result = await session.execute(select(AnimeSeries))
await session.commit()
return result.scalars().all()
except Exception:
await session.rollback()
raise
finally:
await session.close()
"""
if _session_factory is None:
raise RuntimeError(
"Database not initialized. Call init_db() first."
)
return _session_factory()

View File

@@ -0,0 +1,479 @@
"""Example integration of database service with existing services.
This file demonstrates how to integrate the database service layer with
existing application services like AnimeService and DownloadService.
These examples show patterns for:
- Persisting scan results to database
- Loading queue from database on startup
- Syncing download progress to database
- Maintaining consistency between in-memory state and database
"""
from __future__ import annotations
import logging
from typing import List, Optional
from sqlalchemy.ext.asyncio import AsyncSession
from src.core.entities.series import Serie
from src.server.database.models import DownloadPriority, DownloadStatus
from src.server.database.service import (
AnimeSeriesService,
DownloadQueueService,
EpisodeService,
)
logger = logging.getLogger(__name__)
# ============================================================================
# Example 1: Persist Scan Results
# ============================================================================
async def persist_scan_results(
db: AsyncSession,
series_list: List[Serie],
) -> None:
"""Persist scan results to database.
Updates or creates anime series and their episodes based on
scan results from SerieScanner.
Args:
db: Database session
series_list: List of Serie objects from scan
"""
logger.info(f"Persisting {len(series_list)} series to database")
for serie in series_list:
# Check if series exists
existing = await AnimeSeriesService.get_by_key(db, serie.key)
if existing:
# Update existing series
await AnimeSeriesService.update(
db,
existing.id,
name=serie.name,
site=serie.site,
folder=serie.folder,
episode_dict=serie.episode_dict,
)
series_id = existing.id
else:
# Create new series
new_series = await AnimeSeriesService.create(
db,
key=serie.key,
name=serie.name,
site=serie.site,
folder=serie.folder,
episode_dict=serie.episode_dict,
)
series_id = new_series.id
# Update episodes for this series
await _update_episodes(db, series_id, serie)
await db.commit()
logger.info("Scan results persisted successfully")
async def _update_episodes(
db: AsyncSession,
series_id: int,
serie: Serie,
) -> None:
"""Update episodes for a series.
Args:
db: Database session
series_id: Series ID in database
serie: Serie object with episode information
"""
# Get existing episodes
existing_episodes = await EpisodeService.get_by_series(db, series_id)
existing_map = {
(ep.season, ep.episode_number): ep
for ep in existing_episodes
}
# Iterate through episode_dict to create/update episodes
for season, episodes in serie.episode_dict.items():
for ep_num in episodes:
key = (int(season), int(ep_num))
if key in existing_map:
# Episode exists, check if downloaded
episode = existing_map[key]
# Update if needed (e.g., file path changed)
if not episode.is_downloaded:
# Check if file exists locally
# This would be done by checking serie.local_episodes
pass
else:
# Create new episode
await EpisodeService.create(
db,
series_id=series_id,
season=int(season),
episode_number=int(ep_num),
is_downloaded=False,
)
# ============================================================================
# Example 2: Load Queue from Database
# ============================================================================
async def load_queue_from_database(
db: AsyncSession,
) -> List[dict]:
"""Load download queue from database.
Retrieves pending and active download items from database and
converts them to format suitable for DownloadService.
Args:
db: Database session
Returns:
List of download items as dictionaries
"""
logger.info("Loading download queue from database")
# Get pending and active items
pending = await DownloadQueueService.get_pending(db)
active = await DownloadQueueService.get_active(db)
all_items = pending + active
# Convert to dictionary format for DownloadService
queue_items = []
for item in all_items:
queue_items.append({
"id": item.id,
"series_id": item.series_id,
"season": item.season,
"episode_number": item.episode_number,
"status": item.status.value,
"priority": item.priority.value,
"progress_percent": item.progress_percent,
"downloaded_bytes": item.downloaded_bytes,
"total_bytes": item.total_bytes,
"download_speed": item.download_speed,
"error_message": item.error_message,
"retry_count": item.retry_count,
})
logger.info(f"Loaded {len(queue_items)} items from database")
return queue_items
# ============================================================================
# Example 3: Sync Download Progress to Database
# ============================================================================
async def sync_download_progress(
db: AsyncSession,
item_id: int,
progress_percent: float,
downloaded_bytes: int,
total_bytes: Optional[int] = None,
download_speed: Optional[float] = None,
) -> None:
"""Sync download progress to database.
Updates download queue item progress in database. This would be called
from the download progress callback.
Args:
db: Database session
item_id: Download queue item ID
progress_percent: Progress percentage (0-100)
downloaded_bytes: Bytes downloaded
total_bytes: Optional total file size
download_speed: Optional current speed (bytes/sec)
"""
await DownloadQueueService.update_progress(
db,
item_id,
progress_percent,
downloaded_bytes,
total_bytes,
download_speed,
)
await db.commit()
async def mark_download_complete(
db: AsyncSession,
item_id: int,
file_path: str,
file_size: int,
) -> None:
"""Mark download as complete in database.
Updates download queue item status and marks episode as downloaded.
Args:
db: Database session
item_id: Download queue item ID
file_path: Path to downloaded file
file_size: File size in bytes
"""
# Get download item
item = await DownloadQueueService.get_by_id(db, item_id)
if not item:
logger.error(f"Download item {item_id} not found")
return
# Update download status
await DownloadQueueService.update_status(
db,
item_id,
DownloadStatus.COMPLETED,
)
# Find or create episode and mark as downloaded
episode = await EpisodeService.get_by_episode(
db,
item.series_id,
item.season,
item.episode_number,
)
if episode:
await EpisodeService.mark_downloaded(
db,
episode.id,
file_path,
file_size,
)
else:
# Create episode
episode = await EpisodeService.create(
db,
series_id=item.series_id,
season=item.season,
episode_number=item.episode_number,
file_path=file_path,
file_size=file_size,
is_downloaded=True,
)
await db.commit()
logger.info(
f"Marked download complete: S{item.season:02d}E{item.episode_number:02d}"
)
async def mark_download_failed(
db: AsyncSession,
item_id: int,
error_message: str,
) -> None:
"""Mark download as failed in database.
Args:
db: Database session
item_id: Download queue item ID
error_message: Error description
"""
await DownloadQueueService.update_status(
db,
item_id,
DownloadStatus.FAILED,
error_message=error_message,
)
await db.commit()
# ============================================================================
# Example 4: Add Episodes to Download Queue
# ============================================================================
async def add_episodes_to_queue(
db: AsyncSession,
series_key: str,
episodes: List[tuple[int, int]], # List of (season, episode) tuples
priority: DownloadPriority = DownloadPriority.NORMAL,
) -> int:
"""Add multiple episodes to download queue.
Args:
db: Database session
series_key: Series provider key
episodes: List of (season, episode_number) tuples
priority: Download priority
Returns:
Number of episodes added to queue
"""
# Get series
series = await AnimeSeriesService.get_by_key(db, series_key)
if not series:
logger.error(f"Series not found: {series_key}")
return 0
added_count = 0
# Fetch the queue once up front instead of re-querying it per episode
existing_items = await DownloadQueueService.get_all(db)
for season, episode_number in episodes:
# Check if this episode is already queued
already_queued = any(
item.series_id == series.id
and item.season == season
and item.episode_number == episode_number
and item.status in (DownloadStatus.PENDING, DownloadStatus.DOWNLOADING)
for item in existing_items
)
if not already_queued:
await DownloadQueueService.create(
db,
series_id=series.id,
season=season,
episode_number=episode_number,
priority=priority,
)
added_count += 1
await db.commit()
logger.info(f"Added {added_count} episodes to download queue")
return added_count
# ============================================================================
# Example 5: Integration with AnimeService
# ============================================================================
class EnhancedAnimeService:
"""Enhanced AnimeService with database persistence.
This is an example of how to wrap the existing AnimeService with
database persistence capabilities.
"""
def __init__(self, db_session_factory):
"""Initialize enhanced anime service.
Args:
db_session_factory: Async session factory for database access
"""
self.db_session_factory = db_session_factory
async def rescan_with_persistence(self, directory: str) -> dict:
"""Rescan directory and persist results.
Args:
directory: Directory to scan
Returns:
Scan results dictionary
"""
# Import here to avoid circular dependencies
from src.core.SeriesApp import SeriesApp
# Perform scan
app = SeriesApp(directory)
series_list = app.ReScan()
# Persist to database
async with self.db_session_factory() as db:
await persist_scan_results(db, series_list)
return {
"total_series": len(series_list),
"message": "Scan completed and persisted to database",
}
async def get_series_with_missing_episodes(self) -> List[dict]:
"""Get series with missing episodes from database.
Returns:
List of series with missing episodes
"""
async with self.db_session_factory() as db:
# Get all series
all_series = await AnimeSeriesService.get_all(
db,
with_episodes=True,
)
# Filter series with missing episodes
series_with_missing = []
for series in all_series:
if series.episode_dict:
total_episodes = sum(
len(eps) for eps in series.episode_dict.values()
)
downloaded_episodes = sum(
1 for ep in series.episodes if ep.is_downloaded
)
if downloaded_episodes < total_episodes:
series_with_missing.append({
"id": series.id,
"key": series.key,
"name": series.name,
"total_episodes": total_episodes,
"downloaded_episodes": downloaded_episodes,
"missing_episodes": total_episodes - downloaded_episodes,
})
return series_with_missing
# ============================================================================
# Usage Example
# ============================================================================
async def example_usage():
"""Example usage of database service integration."""
from src.server.database import get_db_session
# Get database session
async with get_db_session() as db:
# Example 1: Add episodes to queue
added = await add_episodes_to_queue(
db,
series_key="attack-on-titan",
episodes=[(1, 1), (1, 2), (1, 3)],
priority=DownloadPriority.HIGH,
)
print(f"Added {added} episodes to queue")
# Example 2: Load queue
queue_items = await load_queue_from_database(db)
print(f"Queue has {len(queue_items)} items")
# Example 3: Update progress
if queue_items:
await sync_download_progress(
db,
item_id=queue_items[0]["id"],
progress_percent=50.0,
downloaded_bytes=500000,
total_bytes=1000000,
)
# Example 4: Mark complete
if queue_items:
await mark_download_complete(
db,
item_id=queue_items[0]["id"],
file_path="/path/to/file.mp4",
file_size=1000000,
)
if __name__ == "__main__":
import asyncio
asyncio.run(example_usage())

View File

@ -2,9 +2,12 @@
This module provides comprehensive database initialization functionality:
- Schema creation and validation
- Initial data migration
- Database health checks
- Schema versioning support
- Migration utilities
For production deployments, consider using Alembic for managed migrations.
"""
from __future__ import annotations
@ -44,7 +47,7 @@ EXPECTED_INDEXES = {
"episodes": ["ix_episodes_series_id"],
"download_queue": [
"ix_download_queue_series_id",
"ix_download_queue_episode_id",
"ix_download_queue_status",
],
"user_sessions": [
"ix_user_sessions_session_id",
@ -313,6 +316,7 @@ async def get_schema_version(engine: Optional[AsyncEngine] = None) -> str:
"""Get current database schema version.
Returns version string based on existing tables and structure.
For production, consider using Alembic versioning.
Args:
engine: Optional database engine (uses default if not provided)
@ -350,6 +354,8 @@ async def create_schema_version_table(
) -> None:
"""Create schema version tracking table.
Future enhancement for tracking schema migrations with Alembic.
Args:
engine: Optional database engine (uses default if not provided)
"""
@ -581,6 +587,60 @@ def get_database_info() -> Dict[str, Any]:
}
def get_migration_guide() -> str:
"""Get migration guide for production deployments.
Returns:
Migration guide text
"""
return """
Database Migration Guide
========================
Current Setup: SQLAlchemy create_all()
- Automatically creates tables on startup
- Suitable for development and single-instance deployments
- Schema changes require manual handling
For Production with Alembic:
============================
1. Initialize Alembic (already installed):
alembic init alembic
2. Configure alembic/env.py:
from src.server.database.base import Base
target_metadata = Base.metadata
3. Configure alembic.ini:
sqlalchemy.url = <your-database-url>
4. Generate initial migration:
alembic revision --autogenerate -m "Initial schema v1.0.0"
5. Review migration in alembic/versions/
6. Apply migration:
alembic upgrade head
7. For future schema changes:
- Modify models in src/server/database/models.py
- Generate migration: alembic revision --autogenerate -m "Description"
- Review generated migration
- Test in staging environment
- Apply: alembic upgrade head
- For rollback: alembic downgrade -1
Best Practices:
==============
- Always backup database before migrations
- Test migrations in staging first
- Review auto-generated migrations carefully
- Keep migrations in version control
- Document breaking changes
"""
# =============================================================================
# Public API
# =============================================================================
@ -596,6 +656,7 @@ __all__ = [
"check_database_health",
"create_database_backup",
"get_database_info",
"get_migration_guide",
"CURRENT_SCHEMA_VERSION",
"EXPECTED_TABLES",
]

View File

@ -0,0 +1,167 @@
"""Database migration utilities.
This module provides utilities for database migrations and schema versioning.
Alembic integration can be added when needed for production environments.
For now, we use SQLAlchemy's create_all for automatic schema creation.
"""
from __future__ import annotations
import logging
from typing import Optional
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncEngine
from src.server.database.base import Base
from src.server.database.connection import get_engine, get_sync_engine
logger = logging.getLogger(__name__)
async def initialize_schema(engine: Optional[AsyncEngine] = None) -> None:
"""Initialize database schema.
Creates all tables defined in Base metadata if they don't exist.
This is a simple migration strategy suitable for single-instance deployments.
For production with multiple instances, consider using Alembic:
- alembic init alembic
- alembic revision --autogenerate -m "Initial schema"
- alembic upgrade head
Args:
engine: Optional database engine (uses default if not provided)
Raises:
RuntimeError: If database is not initialized
"""
if engine is None:
engine = get_engine()
logger.info("Initializing database schema...")
# Create all tables
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
logger.info("Database schema initialized successfully")
async def check_schema_version(engine: Optional[AsyncEngine] = None) -> str:
"""Check current database schema version.
Returns a simple version identifier based on existing tables.
For production, consider using Alembic for proper versioning.
Args:
engine: Optional database engine (uses default if not provided)
Returns:
Schema version string
Raises:
RuntimeError: If database is not initialized
"""
if engine is None:
engine = get_engine()
async with engine.connect() as conn:
# Check which tables exist
result = await conn.execute(
text(
"SELECT name FROM sqlite_master "
"WHERE type='table' AND name NOT LIKE 'sqlite_%'"
)
)
tables = [row[0] for row in result]
if not tables:
return "empty"
elif len(tables) == 4 and all(
t in tables for t in [
"anime_series",
"episodes",
"download_queue",
"user_sessions",
]
):
return "v1.0"
else:
return "custom"
def get_migration_info() -> str:
"""Get information about database migration setup.
Returns:
Migration setup information
"""
return """
Database Migration Information
==============================
Current Strategy: SQLAlchemy create_all()
- Automatically creates tables on startup
- Suitable for development and single-instance deployments
- Schema changes require manual handling
For Production Migrations (Alembic):
====================================
1. Initialize Alembic:
alembic init alembic
2. Configure alembic/env.py:
- Import Base from src.server.database.base
- Set target_metadata = Base.metadata
3. Configure alembic.ini:
- Set sqlalchemy.url to your database URL
4. Generate initial migration:
alembic revision --autogenerate -m "Initial schema"
5. Apply migrations:
alembic upgrade head
6. For future changes:
- Modify models in src/server/database/models.py
- Generate migration: alembic revision --autogenerate -m "Description"
- Review generated migration in alembic/versions/
- Apply: alembic upgrade head
Benefits of Alembic:
- Version control for database schema
- Automatic migration generation from model changes
- Rollback support with downgrade scripts
- Multi-instance deployment support
- Safe schema changes in production
"""
# =============================================================================
# Future Alembic Integration
# =============================================================================
#
# When ready to use Alembic, follow these steps:
#
# 1. Install Alembic (already in requirements.txt):
# pip install alembic
#
# 2. Initialize Alembic from project root:
# alembic init alembic
#
# 3. Update alembic/env.py to use our Base:
# from src.server.database.base import Base
# target_metadata = Base.metadata
#
# 4. Configure alembic.ini with DATABASE_URL from settings
#
# 5. Generate initial migration:
# alembic revision --autogenerate -m "Initial schema"
#
# 6. Review generated migration and apply:
# alembic upgrade head
#
# =============================================================================

View File

@ -0,0 +1,236 @@
"""
Initial database schema migration.
This migration creates the base tables for the Aniworld application,
including users, anime, downloads, and configuration tables.
Version: 20250124_001
Created: 2025-01-24
"""
import logging
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession
from ..migrations.base import Migration, MigrationError
logger = logging.getLogger(__name__)
class InitialSchemaMigration(Migration):
"""
Creates initial database schema.
This migration sets up all core tables needed for the application:
- users: User accounts and authentication
- anime: Anime series metadata
- episodes: Episode information
- downloads: Download queue and history
- config: Application configuration
"""
def __init__(self):
"""Initialize the initial schema migration."""
super().__init__(
version="20250124_001",
description="Create initial database schema",
)
async def upgrade(self, session: AsyncSession) -> None:
"""
Create all initial tables.
Args:
session: Database session
Raises:
MigrationError: If table creation fails
"""
try:
# Create users table
await session.execute(
text(
"""
CREATE TABLE IF NOT EXISTS users (
id INTEGER PRIMARY KEY AUTOINCREMENT,
username TEXT NOT NULL UNIQUE,
email TEXT,
password_hash TEXT NOT NULL,
is_active BOOLEAN DEFAULT 1,
is_admin BOOLEAN DEFAULT 0,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)
"""
)
)
# Create anime table
await session.execute(
text(
"""
CREATE TABLE IF NOT EXISTS anime (
id INTEGER PRIMARY KEY AUTOINCREMENT,
title TEXT NOT NULL,
original_title TEXT,
description TEXT,
genres TEXT,
release_year INTEGER,
status TEXT,
total_episodes INTEGER,
cover_image_url TEXT,
aniworld_url TEXT,
mal_id INTEGER,
anilist_id INTEGER,
added_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)
"""
)
)
# Create episodes table
await session.execute(
text(
"""
CREATE TABLE IF NOT EXISTS episodes (
id INTEGER PRIMARY KEY AUTOINCREMENT,
anime_id INTEGER NOT NULL,
episode_number INTEGER NOT NULL,
season_number INTEGER DEFAULT 1,
title TEXT,
description TEXT,
duration_minutes INTEGER,
air_date DATE,
stream_url TEXT,
download_url TEXT,
file_path TEXT,
file_size_bytes INTEGER,
is_downloaded BOOLEAN DEFAULT 0,
download_progress REAL DEFAULT 0.0,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY (anime_id) REFERENCES anime(id)
ON DELETE CASCADE,
UNIQUE (anime_id, season_number, episode_number)
)
"""
)
)
# Create downloads table
await session.execute(
text(
"""
CREATE TABLE IF NOT EXISTS downloads (
id INTEGER PRIMARY KEY AUTOINCREMENT,
episode_id INTEGER NOT NULL,
user_id INTEGER,
status TEXT NOT NULL DEFAULT 'pending',
priority INTEGER DEFAULT 5,
progress REAL DEFAULT 0.0,
download_speed_mbps REAL,
eta_seconds INTEGER,
started_at TIMESTAMP,
completed_at TIMESTAMP,
failed_at TIMESTAMP,
error_message TEXT,
retry_count INTEGER DEFAULT 0,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY (episode_id) REFERENCES episodes(id)
ON DELETE CASCADE,
FOREIGN KEY (user_id) REFERENCES users(id)
ON DELETE SET NULL
)
"""
)
)
# Create config table
await session.execute(
text(
"""
CREATE TABLE IF NOT EXISTS config (
id INTEGER PRIMARY KEY AUTOINCREMENT,
key TEXT NOT NULL UNIQUE,
value TEXT NOT NULL,
category TEXT DEFAULT 'general',
description TEXT,
is_secret BOOLEAN DEFAULT 0,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)
"""
)
)
# Create indexes for better performance
await session.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_anime_title "
"ON anime(title)"
)
)
await session.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_episodes_anime_id "
"ON episodes(anime_id)"
)
)
await session.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_downloads_status "
"ON downloads(status)"
)
)
await session.execute(
text(
"CREATE INDEX IF NOT EXISTS "
"idx_downloads_episode_id ON downloads(episode_id)"
)
)
logger.info("Initial schema created successfully")
except Exception as e:
logger.error(f"Failed to create initial schema: {e}")
raise MigrationError(
f"Initial schema creation failed: {e}"
) from e
async def downgrade(self, session: AsyncSession) -> None:
"""
Drop all initial tables.
Args:
session: Database session
Raises:
MigrationError: If table dropping fails
"""
try:
# Drop tables in reverse order to respect foreign keys
tables = [
"downloads",
"episodes",
"anime",
"users",
"config",
]
for table in tables:
await session.execute(text(f"DROP TABLE IF EXISTS {table}"))
logger.debug(f"Dropped table: {table}")
logger.info("Initial schema rolled back successfully")
except Exception as e:
logger.error(f"Failed to rollback initial schema: {e}")
raise MigrationError(
f"Initial schema rollback failed: {e}"
) from e

View File

@ -0,0 +1,17 @@
"""
Database migration system for Aniworld application.
This package provides tools for managing database schema changes,
including migration creation, execution, and rollback capabilities.
"""
from .base import Migration, MigrationError
from .runner import MigrationRunner
from .validator import MigrationValidator
__all__ = [
"Migration",
"MigrationError",
"MigrationRunner",
"MigrationValidator",
]

View File

@ -0,0 +1,128 @@
"""
Base migration classes and utilities.
This module provides the foundation for database migrations,
including the abstract Migration class and error handling.
"""
from abc import ABC, abstractmethod
from datetime import datetime
from typing import Optional
from sqlalchemy.ext.asyncio import AsyncSession
class MigrationError(Exception):
"""Base exception for migration-related errors."""
pass
class Migration(ABC):
"""
Abstract base class for database migrations.
Each migration should inherit from this class and implement
the upgrade and downgrade methods.
Attributes:
version: Unique version identifier (e.g., "20250124_001")
description: Human-readable description of the migration
created_at: Timestamp when migration was created
"""
def __init__(
self,
version: str,
description: str,
created_at: Optional[datetime] = None,
):
"""
Initialize migration.
Args:
version: Unique version identifier
description: Human-readable description
created_at: Creation timestamp (defaults to now)
"""
self.version = version
self.description = description
self.created_at = created_at or datetime.now()
@abstractmethod
async def upgrade(self, session: AsyncSession) -> None:
"""
Apply the migration.
Args:
session: Database session for executing changes
Raises:
MigrationError: If migration fails
"""
pass
@abstractmethod
async def downgrade(self, session: AsyncSession) -> None:
"""
Revert the migration.
Args:
session: Database session for reverting changes
Raises:
MigrationError: If rollback fails
"""
pass
def __repr__(self) -> str:
"""Return string representation of migration."""
return f"Migration({self.version}: {self.description})"
def __eq__(self, other: object) -> bool:
"""Check equality based on version."""
if not isinstance(other, Migration):
return False
return self.version == other.version
def __hash__(self) -> int:
"""Return hash based on version."""
return hash(self.version)
class MigrationHistory:
"""
Tracks applied migrations in the database.
This is a plain record class rather than an ORM model: rows are
written to the migration_history table via raw SQL in MigrationRunner.
"""
__tablename__ = "migration_history"
def __init__(
self,
version: str,
description: str,
applied_at: datetime,
execution_time_ms: int,
success: bool = True,
error_message: Optional[str] = None,
):
"""
Initialize migration history record.
Args:
version: Migration version identifier
description: Migration description
applied_at: Timestamp when migration was applied
execution_time_ms: Time taken to execute in milliseconds
success: Whether migration succeeded
error_message: Error message if migration failed
"""
self.version = version
self.description = description
self.applied_at = applied_at
self.execution_time_ms = execution_time_ms
self.success = success
self.error_message = error_message

View File

@ -0,0 +1,323 @@
"""
Migration runner for executing database migrations.
This module handles the execution of migrations in the correct order,
tracks migration history, and provides rollback capabilities.
"""
import importlib.util
import logging
import time
from datetime import datetime
from pathlib import Path
from typing import List, Optional
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession
from .base import Migration, MigrationError, MigrationHistory
logger = logging.getLogger(__name__)
class MigrationRunner:
"""
Manages database migration execution and tracking.
This class handles loading migrations, executing them in order,
tracking their status, and rolling back when needed.
"""
def __init__(self, migrations_dir: Path, session: AsyncSession):
"""
Initialize migration runner.
Args:
migrations_dir: Directory containing migration files
session: Database session for executing migrations
"""
self.migrations_dir = migrations_dir
self.session = session
self._migrations: List[Migration] = []
async def initialize(self) -> None:
"""
Initialize migration system by creating tracking table if needed.
Raises:
MigrationError: If initialization fails
"""
try:
# Create migration_history table if it doesn't exist
create_table_sql = """
CREATE TABLE IF NOT EXISTS migration_history (
id INTEGER PRIMARY KEY AUTOINCREMENT,
version TEXT NOT NULL UNIQUE,
description TEXT NOT NULL,
applied_at TIMESTAMP NOT NULL,
execution_time_ms INTEGER NOT NULL,
success BOOLEAN NOT NULL DEFAULT 1,
error_message TEXT
)
"""
await self.session.execute(text(create_table_sql))
await self.session.commit()
logger.info("Migration system initialized")
except Exception as e:
logger.error(f"Failed to initialize migration system: {e}")
raise MigrationError(f"Initialization failed: {e}") from e
def load_migrations(self) -> None:
"""
Load all migration files from the migrations directory.
Migration files should be named in format: {version}_{description}.py
and contain a Migration class that inherits from base.Migration.
Raises:
MigrationError: If loading migrations fails
"""
try:
self._migrations.clear()
if not self.migrations_dir.exists():
logger.warning(f"Migrations directory does not exist: {self.migrations_dir}")
return
# Find all Python files in migrations directory
migration_files = sorted(self.migrations_dir.glob("*.py"))
migration_files = [f for f in migration_files if f.name != "__init__.py"]
for file_path in migration_files:
try:
# Import the migration module dynamically
spec = importlib.util.spec_from_file_location(
f"migration.{file_path.stem}", file_path
)
if spec and spec.loader:
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
# Find Migration subclass in module
for attr_name in dir(module):
attr = getattr(module, attr_name)
if (
isinstance(attr, type)
and issubclass(attr, Migration)
and attr != Migration
):
migration_instance = attr()
self._migrations.append(migration_instance)
logger.debug(f"Loaded migration: {migration_instance.version}")
break
except Exception as e:
logger.error(f"Failed to load migration {file_path.name}: {e}")
raise MigrationError(f"Failed to load {file_path.name}: {e}") from e
# Sort migrations by version
self._migrations.sort(key=lambda m: m.version)
logger.info(f"Loaded {len(self._migrations)} migrations")
except Exception as e:
logger.error(f"Failed to load migrations: {e}")
raise MigrationError(f"Loading migrations failed: {e}") from e
async def get_applied_migrations(self) -> List[str]:
"""
Get list of already applied migration versions.
Returns:
List of migration versions that have been applied
Raises:
MigrationError: If query fails
"""
try:
result = await self.session.execute(
text("SELECT version FROM migration_history WHERE success = 1 ORDER BY version")
)
versions = [row[0] for row in result.fetchall()]
return versions
except Exception as e:
logger.error(f"Failed to get applied migrations: {e}")
raise MigrationError(f"Query failed: {e}") from e
async def get_pending_migrations(self) -> List[Migration]:
"""
Get list of migrations that haven't been applied yet.
Returns:
List of pending Migration objects
Raises:
MigrationError: If check fails
"""
applied = await self.get_applied_migrations()
pending = [m for m in self._migrations if m.version not in applied]
return pending
async def apply_migration(self, migration: Migration) -> None:
"""
Apply a single migration.
Args:
migration: Migration to apply
Raises:
MigrationError: If migration fails
"""
start_time = time.time()
success = False
error_message = None
try:
logger.info(f"Applying migration: {migration.version} - {migration.description}")
# Execute the migration
await migration.upgrade(self.session)
await self.session.commit()
success = True
execution_time_ms = int((time.time() - start_time) * 1000)
logger.info(
f"Migration {migration.version} applied successfully in {execution_time_ms}ms"
)
except Exception as e:
error_message = str(e)
execution_time_ms = int((time.time() - start_time) * 1000)
logger.error(f"Migration {migration.version} failed: {e}")
await self.session.rollback()
raise MigrationError(f"Migration {migration.version} failed: {e}") from e
finally:
# Record migration in history
try:
history_record = MigrationHistory(
version=migration.version,
description=migration.description,
applied_at=datetime.now(),
execution_time_ms=execution_time_ms,
success=success,
error_message=error_message,
)
insert_sql = """
INSERT INTO migration_history
(version, description, applied_at, execution_time_ms, success, error_message)
VALUES (:version, :description, :applied_at, :execution_time_ms, :success, :error_message)
"""
await self.session.execute(
text(insert_sql),
{
"version": history_record.version,
"description": history_record.description,
"applied_at": history_record.applied_at,
"execution_time_ms": history_record.execution_time_ms,
"success": history_record.success,
"error_message": history_record.error_message,
},
)
await self.session.commit()
except Exception as e:
logger.error(f"Failed to record migration history: {e}")
async def run_migrations(self, target_version: Optional[str] = None) -> int:
"""
Run all pending migrations up to target version.
Args:
target_version: Stop at this version (None = run all)
Returns:
Number of migrations applied
Raises:
MigrationError: If migrations fail
"""
pending = await self.get_pending_migrations()
if target_version:
pending = [m for m in pending if m.version <= target_version]
if not pending:
logger.info("No pending migrations to apply")
return 0
logger.info(f"Applying {len(pending)} pending migrations")
for migration in pending:
await self.apply_migration(migration)
return len(pending)
async def rollback_migration(self, migration: Migration) -> None:
"""
Rollback a single migration.
Args:
migration: Migration to rollback
Raises:
MigrationError: If rollback fails
"""
start_time = time.time()
try:
logger.info(f"Rolling back migration: {migration.version}")
# Execute the downgrade
await migration.downgrade(self.session)
await self.session.commit()
execution_time_ms = int((time.time() - start_time) * 1000)
# Remove from history
delete_sql = "DELETE FROM migration_history WHERE version = :version"
await self.session.execute(text(delete_sql), {"version": migration.version})
await self.session.commit()
logger.info(
f"Migration {migration.version} rolled back successfully in {execution_time_ms}ms"
)
except Exception as e:
logger.error(f"Rollback of {migration.version} failed: {e}")
await self.session.rollback()
raise MigrationError(f"Rollback of {migration.version} failed: {e}") from e
async def rollback(self, steps: int = 1) -> int:
"""
Rollback the last N migrations.
Args:
steps: Number of migrations to rollback
Returns:
Number of migrations rolled back
Raises:
MigrationError: If rollback fails
"""
applied = await self.get_applied_migrations()
if not applied:
logger.info("No migrations to rollback")
return 0
# Get migrations to rollback (in reverse order)
to_rollback = applied[-steps:]
to_rollback.reverse()
migrations_to_rollback = [m for m in self._migrations if m.version in to_rollback]
logger.info(f"Rolling back {len(migrations_to_rollback)} migrations")
for migration in migrations_to_rollback:
await self.rollback_migration(migration)
return len(migrations_to_rollback)
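# A typical startup sequence wires the pieces above together. A sketch,
# assuming an async session factory and a migrations directory; both
# names are illustrative:
async def migrate_on_startup(session_factory, migrations_dir: Path) -> None:
    """Apply all pending migrations during application startup."""
    async with session_factory() as session:
        runner = MigrationRunner(migrations_dir, session)
        await runner.initialize()  # ensure migration_history table exists
        runner.load_migrations()   # discover {version}_{description}.py files
        applied = await runner.run_migrations()
        logger.info(f"Startup migrations applied: {applied}")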

View File

@ -0,0 +1,222 @@
"""
Migration validator for ensuring migration safety and integrity.
This module provides validation utilities to check migrations
before they are executed, ensuring they meet quality standards.
"""
import logging
from typing import List, Optional, Set
from .base import Migration, MigrationError
logger = logging.getLogger(__name__)
class MigrationValidator:
"""
Validates migrations before execution.
Performs various checks to ensure migrations are safe to run,
including version uniqueness, naming conventions, and
dependency resolution.
"""
def __init__(self):
"""Initialize migration validator."""
self.errors: List[str] = []
self.warnings: List[str] = []
def reset(self) -> None:
"""Clear validation results."""
self.errors.clear()
self.warnings.clear()
def validate_migration(self, migration: Migration) -> bool:
"""
Validate a single migration.
Args:
migration: Migration to validate
Returns:
True if migration is valid, False otherwise
"""
self.reset()
# Check version format
if not self._validate_version_format(migration.version):
self.errors.append(
f"Invalid version format: {migration.version}. "
"Expected format: YYYYMMDD_NNN"
)
# Check description
if not migration.description or len(migration.description) < 5:
self.errors.append(
f"Migration {migration.version} has invalid "
f"description: '{migration.description}'"
)
# Check for implementation
if not hasattr(migration, "upgrade") or not callable(
getattr(migration, "upgrade")
):
self.errors.append(
f"Migration {migration.version} missing upgrade method"
)
if not hasattr(migration, "downgrade") or not callable(
getattr(migration, "downgrade")
):
self.errors.append(
f"Migration {migration.version} missing downgrade method"
)
return len(self.errors) == 0
def validate_migrations(self, migrations: List[Migration]) -> bool:
"""
Validate a list of migrations.
Args:
migrations: List of migrations to validate
Returns:
True if all migrations are valid, False otherwise
"""
self.reset()
if not migrations:
self.warnings.append("No migrations to validate")
return True
# Check for duplicate versions
versions: Set[str] = set()
for migration in migrations:
if migration.version in versions:
self.errors.append(
f"Duplicate migration version: {migration.version}"
)
versions.add(migration.version)
# Return early if duplicates found
if self.errors:
return False
# Validate each migration
for migration in migrations:
if not self.validate_migration(migration):
logger.error(
f"Migration {migration.version} "
f"validation failed: {self.errors}"
)
return False
# Check version ordering
sorted_versions = sorted([m.version for m in migrations])
actual_versions = [m.version for m in migrations]
if sorted_versions != actual_versions:
self.warnings.append(
"Migrations are not in chronological order"
)
return len(self.errors) == 0
def _validate_version_format(self, version: str) -> bool:
"""
Validate version string format.
Args:
version: Version string to validate
Returns:
True if format is valid
"""
# Expected format: YYYYMMDD_NNN or YYYYMMDD_NNN_description
if not version:
return False
parts = version.split("_")
if len(parts) < 2:
return False
# Check date part (YYYYMMDD)
date_part = parts[0]
if len(date_part) != 8 or not date_part.isdigit():
return False
# Check sequence part (NNN)
seq_part = parts[1]
if not seq_part.isdigit():
return False
return True
def check_migration_conflicts(
self,
pending: List[Migration],
applied: List[str],
) -> Optional[str]:
"""
Check for conflicts between pending and applied migrations.
Args:
pending: List of pending migrations
applied: List of applied migration versions
Returns:
Error message if conflicts found, None otherwise
"""
# Check if any pending migration has version lower than applied
if not applied:
return None
latest_applied = max(applied)
for migration in pending:
if migration.version < latest_applied:
return (
f"Migration {migration.version} is older than "
f"latest applied migration {latest_applied}. "
"This may indicate a merge conflict."
)
return None
def get_validation_report(self) -> str:
"""
Get formatted validation report.
Returns:
Formatted report string
"""
report = []
if self.errors:
report.append("Validation Errors:")
for error in self.errors:
report.append(f" - {error}")
if self.warnings:
report.append("Validation Warnings:")
for warning in self.warnings:
report.append(f" - {warning}")
if not self.errors and not self.warnings:
report.append("All validations passed")
return "\n".join(report)
def raise_if_invalid(self) -> None:
"""
Raise exception if validation failed.
Raises:
MigrationError: If validation errors exist
"""
if self.errors:
error_msg = "\n".join(self.errors)
raise MigrationError(
f"Migration validation failed:\n{error_msg}"
)
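# A short sketch of using the validator as a pre-flight check before the
# runner applies anything; `runner` is assumed to be a MigrationRunner
# that has already called load_migrations():
async def validate_before_migrate(runner) -> None:
    """Raise MigrationError if pending migrations look unsafe."""
    validator = MigrationValidator()
    pending = await runner.get_pending_migrations()
    applied = await runner.get_applied_migrations()
    if not validator.validate_migrations(pending):
        logger.error(validator.get_validation_report())
        validator.raise_if_invalid()
    conflict = validator.check_migration_conflicts(pending, applied)
    if conflict:
        raise MigrationError(conflict)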

View File

@ -15,7 +15,18 @@ from datetime import datetime, timezone
from enum import Enum
from typing import List, Optional
from sqlalchemy import (
JSON,
Boolean,
DateTime,
Float,
ForeignKey,
Integer,
String,
Text,
func,
)
from sqlalchemy import Enum as SQLEnum
from sqlalchemy.orm import Mapped, mapped_column, relationship, validates
from src.server.database.base import Base, TimestampMixin
@ -40,6 +51,10 @@ class AnimeSeries(Base, TimestampMixin):
name: Display name of the series
site: Provider site URL
folder: Filesystem folder name (metadata only, not for lookups)
description: Optional series description
status: Current status (ongoing, completed, etc.)
total_episodes: Total number of episodes
cover_url: URL to series cover image
episodes: Relationship to Episode models (via id foreign key)
download_items: Relationship to DownloadQueueItem models (via id foreign key)
created_at: Creation timestamp (from TimestampMixin)
@ -74,6 +89,30 @@ class AnimeSeries(Base, TimestampMixin):
doc="Filesystem folder name - METADATA ONLY, not for lookups"
)
# Metadata
description: Mapped[Optional[str]] = mapped_column(
Text, nullable=True,
doc="Series description"
)
status: Mapped[Optional[str]] = mapped_column(
String(50), nullable=True,
doc="Series status (ongoing, completed, etc.)"
)
total_episodes: Mapped[Optional[int]] = mapped_column(
Integer, nullable=True,
doc="Total number of episodes"
)
cover_url: Mapped[Optional[str]] = mapped_column(
String(1000), nullable=True,
doc="URL to cover image"
)
# JSON field for episode dictionary (season -> [episodes])
episode_dict: Mapped[Optional[dict]] = mapped_column(
JSON, nullable=True,
doc="Episode dictionary {season: [episodes]}"
)
# Relationships
episodes: Mapped[List["Episode"]] = relationship(
"Episode",
@ -122,6 +161,22 @@ class AnimeSeries(Base, TimestampMixin):
raise ValueError("Folder path must be 1000 characters or less")
return value.strip()
@validates('cover_url')
def validate_cover_url(self, key: str, value: Optional[str]) -> Optional[str]:
"""Validate cover URL length."""
if value is not None and len(value) > 1000:
raise ValueError("Cover URL must be 1000 characters or less")
return value
@validates('total_episodes')
def validate_total_episodes(self, key: str, value: Optional[int]) -> Optional[int]:
"""Validate total episodes is positive."""
if value is not None and value < 0:
raise ValueError("Total episodes must be non-negative")
if value is not None and value > 10000:
raise ValueError("Total episodes must be 10000 or less")
return value
def __repr__(self) -> str:
return f"<AnimeSeries(id={self.id}, key='{self.key}', name='{self.name}')>"
@ -139,7 +194,9 @@ class Episode(Base, TimestampMixin):
episode_number: Episode number within season
title: Episode title
file_path: Local file path if downloaded
file_size: File size in bytes
is_downloaded: Whether episode is downloaded
download_date: When episode was downloaded
series: Relationship to AnimeSeries
created_at: Creation timestamp (from TimestampMixin)
updated_at: Last update timestamp (from TimestampMixin)
@ -177,10 +234,18 @@ class Episode(Base, TimestampMixin):
String(1000), nullable=True,
doc="Local file path"
)
file_size: Mapped[Optional[int]] = mapped_column(
Integer, nullable=True,
doc="File size in bytes"
)
is_downloaded: Mapped[bool] = mapped_column(
Boolean, default=False, nullable=False,
doc="Whether episode is downloaded"
)
download_date: Mapped[Optional[datetime]] = mapped_column(
DateTime(timezone=True), nullable=True,
doc="When episode was downloaded"
)
# Relationship
series: Mapped["AnimeSeries"] = relationship(
@ -222,6 +287,13 @@ class Episode(Base, TimestampMixin):
raise ValueError("File path must be 1000 characters or less")
return value
@validates('file_size')
def validate_file_size(self, key: str, value: Optional[int]) -> Optional[int]:
"""Validate file size is non-negative."""
if value is not None and value < 0:
raise ValueError("File size must be non-negative")
return value
def __repr__(self) -> str:
return (
f"<Episode(id={self.id}, series_id={self.series_id}, "
@ -249,20 +321,27 @@ class DownloadPriority(str, Enum):
class DownloadQueueItem(Base, TimestampMixin):
"""SQLAlchemy model for download queue items.
Tracks download queue with status, progress, and error information.
Provides persistence for the DownloadService queue state.
Attributes:
id: Primary key
series_id: Foreign key to AnimeSeries
season: Season number
episode_number: Episode number
status: Current download status
priority: Download priority
progress_percent: Download progress (0-100)
downloaded_bytes: Bytes downloaded
total_bytes: Total file size
download_speed: Current speed in bytes/sec
error_message: Error description if failed
retry_count: Number of retry attempts
download_url: Provider download URL
file_destination: Target file path
started_at: When download started
completed_at: When download completed
series: Relationship to AnimeSeries
created_at: Creation timestamp (from TimestampMixin)
updated_at: Last update timestamp (from TimestampMixin)
"""
@ -280,11 +359,47 @@ class DownloadQueueItem(Base, TimestampMixin):
index=True
)
# Episode identification
season: Mapped[int] = mapped_column(
Integer, nullable=False,
doc="Season number"
)
episode_number: Mapped[int] = mapped_column(
Integer, nullable=False,
doc="Episode number"
)
# Queue management
status: Mapped[DownloadStatus] = mapped_column(
SQLEnum(DownloadStatus),
default=DownloadStatus.PENDING,
nullable=False,
index=True,
doc="Current download status"
)
priority: Mapped[DownloadPriority] = mapped_column(
SQLEnum(DownloadPriority),
default=DownloadPriority.NORMAL,
nullable=False,
doc="Download priority"
)
# Progress tracking
progress_percent: Mapped[float] = mapped_column(
Float, default=0.0, nullable=False,
doc="Progress percentage (0-100)"
)
downloaded_bytes: Mapped[int] = mapped_column(
Integer, default=0, nullable=False,
doc="Bytes downloaded"
)
total_bytes: Mapped[Optional[int]] = mapped_column(
Integer, nullable=True,
doc="Total file size"
)
download_speed: Mapped[Optional[float]] = mapped_column(
Float, nullable=True,
doc="Current download speed (bytes/sec)"
)
# Error handling
@ -292,6 +407,10 @@ class DownloadQueueItem(Base, TimestampMixin):
Text, nullable=True,
doc="Error description"
)
retry_count: Mapped[int] = mapped_column(
Integer, default=0, nullable=False,
doc="Number of retry attempts"
)
# Download details
download_url: Mapped[Optional[str]] = mapped_column(
@ -318,9 +437,67 @@ class DownloadQueueItem(Base, TimestampMixin):
"AnimeSeries",
back_populates="download_items"
)
episode: Mapped["Episode"] = relationship(
"Episode"
)
@validates('season')
def validate_season(self, key: str, value: int) -> int:
"""Validate season number is positive."""
if value < 0:
raise ValueError("Season number must be non-negative")
if value > 1000:
raise ValueError("Season number must be 1000 or less")
return value
@validates('episode_number')
def validate_episode_number(self, key: str, value: int) -> int:
"""Validate episode number is positive."""
if value < 0:
raise ValueError("Episode number must be non-negative")
if value > 10000:
raise ValueError("Episode number must be 10000 or less")
return value
@validates('progress_percent')
def validate_progress_percent(self, key: str, value: float) -> float:
"""Validate progress is between 0 and 100."""
if value < 0.0:
raise ValueError("Progress percent must be non-negative")
if value > 100.0:
raise ValueError("Progress percent cannot exceed 100")
return value
@validates('downloaded_bytes')
def validate_downloaded_bytes(self, key: str, value: int) -> int:
"""Validate downloaded bytes is non-negative."""
if value < 0:
raise ValueError("Downloaded bytes must be non-negative")
return value
@validates('total_bytes')
def validate_total_bytes(
self, key: str, value: Optional[int]
) -> Optional[int]:
"""Validate total bytes is non-negative."""
if value is not None and value < 0:
raise ValueError("Total bytes must be non-negative")
return value
@validates('download_speed')
def validate_download_speed(
self, key: str, value: Optional[float]
) -> Optional[float]:
"""Validate download speed is non-negative."""
if value is not None and value < 0.0:
raise ValueError("Download speed must be non-negative")
return value
@validates('retry_count')
def validate_retry_count(self, key: str, value: int) -> int:
"""Validate retry count is non-negative."""
if value < 0:
raise ValueError("Retry count must be non-negative")
if value > 100:
raise ValueError("Retry count cannot exceed 100")
return value
@validates('download_url')
def validate_download_url(
@ -346,7 +523,8 @@ class DownloadQueueItem(Base, TimestampMixin):
return (
f"<DownloadQueueItem(id={self.id}, "
f"series_id={self.series_id}, "
f"episode_id={self.episode_id})>"
f"S{self.season:02d}E{self.episode_number:02d}, "
f"status={self.status})>"
)
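# Because the checks above are registered with @validates, they fire on
# ordinary attribute assignment, not only at flush time. A small sketch
# (values illustrative):
#
#     item = DownloadQueueItem(series_id=1, season=1, episode_number=5)
#     item.progress_percent = 42.5   # accepted
#     item.progress_percent = 150.0  # raises ValueError immediately,
#                                    # before anything reaches the database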

View File

@ -15,7 +15,7 @@ from __future__ import annotations
import logging
from datetime import datetime, timedelta, timezone
from typing import Dict, List, Optional
from sqlalchemy import delete, select, update
from sqlalchemy.ext.asyncio import AsyncSession
@ -23,7 +23,9 @@ from sqlalchemy.orm import Session, selectinload
from src.server.database.models import (
AnimeSeries,
DownloadPriority,
DownloadQueueItem,
DownloadStatus,
Episode,
UserSession,
)
@ -55,6 +57,11 @@ class AnimeSeriesService:
name: str,
site: str,
folder: str,
description: Optional[str] = None,
status: Optional[str] = None,
total_episodes: Optional[int] = None,
cover_url: Optional[str] = None,
episode_dict: Optional[Dict] = None,
) -> AnimeSeries:
"""Create a new anime series.
@ -64,6 +71,11 @@ class AnimeSeriesService:
name: Series name
site: Provider site URL
folder: Local filesystem path
description: Optional series description
status: Optional series status
total_episodes: Optional total episode count
cover_url: Optional cover image URL
episode_dict: Optional episode dictionary
Returns:
Created AnimeSeries instance
@ -76,6 +88,11 @@ class AnimeSeriesService:
name=name,
site=site,
folder=folder,
description=description,
status=status,
total_episodes=total_episodes,
cover_url=cover_url,
episode_dict=episode_dict,
)
db.add(series)
await db.flush()
@ -245,6 +262,7 @@ class EpisodeService:
episode_number: int,
title: Optional[str] = None,
file_path: Optional[str] = None,
file_size: Optional[int] = None,
is_downloaded: bool = False,
) -> Episode:
"""Create a new episode.
@ -256,6 +274,7 @@ class EpisodeService:
episode_number: Episode number within season
title: Optional episode title
file_path: Optional local file path
file_size: Optional file size in bytes
is_downloaded: Whether episode is downloaded
Returns:
@ -267,7 +286,9 @@ class EpisodeService:
episode_number=episode_number,
title=title,
file_path=file_path,
file_size=file_size,
is_downloaded=is_downloaded,
download_date=datetime.now(timezone.utc) if is_downloaded else None,
)
db.add(episode)
await db.flush()
@ -351,6 +372,7 @@ class EpisodeService:
db: AsyncSession,
episode_id: int,
file_path: str,
file_size: int,
) -> Optional[Episode]:
"""Mark episode as downloaded.
@ -358,6 +380,7 @@ class EpisodeService:
db: Database session
episode_id: Episode primary key
file_path: Local file path
file_size: File size in bytes
Returns:
Updated Episode instance or None if not found
@ -368,6 +391,8 @@ class EpisodeService:
episode.is_downloaded = True
episode.file_path = file_path
episode.file_size = file_size
episode.download_date = datetime.now(timezone.utc)
await db.flush()
await db.refresh(episode)
@ -402,14 +427,17 @@ class EpisodeService:
class DownloadQueueService:
"""Service for download queue CRUD operations.
Provides methods for managing the download queue with status tracking,
priority management, and progress updates.
"""
@staticmethod
async def create(
db: AsyncSession,
series_id: int,
season: int,
episode_number: int,
priority: DownloadPriority = DownloadPriority.NORMAL,
download_url: Optional[str] = None,
file_destination: Optional[str] = None,
) -> DownloadQueueItem:
@ -418,7 +446,9 @@ class DownloadQueueService:
Args:
db: Database session
series_id: Foreign key to AnimeSeries
season: Season number
episode_number: Episode number
priority: Download priority
download_url: Optional provider download URL
file_destination: Optional target file path
@ -427,7 +457,10 @@ class DownloadQueueService:
"""
item = DownloadQueueItem(
series_id=series_id,
season=season,
episode_number=episode_number,
status=DownloadStatus.PENDING,
priority=priority,
download_url=download_url,
file_destination=file_destination,
)
@ -435,8 +468,8 @@ class DownloadQueueService:
await db.flush()
await db.refresh(item)
logger.info(
f"Added to download queue: episode_id={episode_id} "
f"for series_id={series_id}"
f"Added to download queue: S{season:02d}E{episode_number:02d} "
f"for series_id={series_id} with priority={priority}"
)
return item
@ -460,25 +493,68 @@ class DownloadQueueService:
return result.scalar_one_or_none()
@staticmethod
async def get_by_status(
db: AsyncSession,
status: DownloadStatus,
limit: Optional[int] = None,
) -> List[DownloadQueueItem]:
"""Get download queue items by status.
Args:
db: Database session
status: Download status filter
limit: Optional limit for results
Returns:
List of DownloadQueueItem instances
"""
query = select(DownloadQueueItem).where(
DownloadQueueItem.status == status
)
# Order by priority (HIGH first) then creation time
query = query.order_by(
DownloadQueueItem.priority.desc(),
DownloadQueueItem.created_at.asc(),
)
if limit:
query = query.limit(limit)
result = await db.execute(query)
return list(result.scalars().all())
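# One caveat worth noting: priority is persisted as a string enum, so
# priority.desc() sorts alphabetically (NORMAL before LOW before HIGH)
# rather than by urgency. If strict HIGH-first ordering is required, a
# CASE expression is one option. A sketch, assuming HIGH/NORMAL/LOW
# members on DownloadPriority:
#
#     from sqlalchemy import case
#
#     priority_rank = case(
#         (DownloadQueueItem.priority == DownloadPriority.HIGH, 0),
#         (DownloadQueueItem.priority == DownloadPriority.NORMAL, 1),
#         else_=2,
#     )
#     query = query.order_by(
#         priority_rank, DownloadQueueItem.created_at.asc()
#     )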
@staticmethod
async def get_pending(
db: AsyncSession,
limit: Optional[int] = None,
) -> List[DownloadQueueItem]:
"""Get pending download queue items.
Args:
db: Database session
limit: Optional limit for results
Returns:
List of pending DownloadQueueItem instances ordered by priority
"""
return await DownloadQueueService.get_by_status(
db, DownloadStatus.PENDING, limit
)
@staticmethod
async def get_active(db: AsyncSession) -> List[DownloadQueueItem]:
"""Get active download queue items.
Args:
db: Database session
Returns:
List of downloading DownloadQueueItem instances
"""
return await DownloadQueueService.get_by_status(
db, DownloadStatus.DOWNLOADING
)
@staticmethod
async def get_all(
@ -500,6 +576,7 @@ class DownloadQueueService:
query = query.options(selectinload(DownloadQueueItem.series))
query = query.order_by(
DownloadQueueItem.priority.desc(),
DownloadQueueItem.created_at.asc(),
)
@ -507,17 +584,19 @@ class DownloadQueueService:
return list(result.scalars().all())
@staticmethod
async def update_status(
db: AsyncSession,
item_id: int,
status: DownloadStatus,
error_message: Optional[str] = None,
) -> Optional[DownloadQueueItem]:
"""Set error message on download queue item.
"""Update download queue item status.
Args:
db: Database session
item_id: Item primary key
status: New download status
error_message: Optional error message for failed status
Returns:
Updated DownloadQueueItem instance or None if not found
@ -526,11 +605,61 @@ class DownloadQueueService:
if not item:
return None
item.status = status
# Update timestamps based on status
if status == DownloadStatus.DOWNLOADING and not item.started_at:
item.started_at = datetime.now(timezone.utc)
elif status in (DownloadStatus.COMPLETED, DownloadStatus.FAILED):
item.completed_at = datetime.now(timezone.utc)
# Set error message for failed downloads
if status == DownloadStatus.FAILED and error_message:
item.error_message = error_message
item.retry_count += 1
await db.flush()
await db.refresh(item)
logger.debug(f"Updated download queue item {item_id} status to {status}")
return item
@staticmethod
async def update_progress(
db: AsyncSession,
item_id: int,
progress_percent: float,
downloaded_bytes: int,
total_bytes: Optional[int] = None,
download_speed: Optional[float] = None,
) -> Optional[DownloadQueueItem]:
"""Update download progress.
Args:
db: Database session
item_id: Item primary key
progress_percent: Progress percentage (0-100)
downloaded_bytes: Bytes downloaded
total_bytes: Optional total file size
download_speed: Optional current speed (bytes/sec)
Returns:
Updated DownloadQueueItem instance or None if not found
"""
item = await DownloadQueueService.get_by_id(db, item_id)
if not item:
return None
item.progress_percent = progress_percent
item.downloaded_bytes = downloaded_bytes
if total_bytes is not None:
item.total_bytes = total_bytes
if download_speed is not None:
item.download_speed = download_speed
await db.flush()
await db.refresh(item)
logger.debug(f"Set error on download queue item {item_id}")
return item
@staticmethod
@ -553,30 +682,57 @@ class DownloadQueueService:
return deleted
@staticmethod
async def clear_completed(db: AsyncSession) -> int:
"""Clear completed downloads from queue.
Args:
db: Database session
Returns:
Number of items cleared
"""
result = await db.execute(
delete(DownloadQueueItem).where(
DownloadQueueItem.status == DownloadStatus.COMPLETED
)
)
count = result.rowcount
logger.info(f"Cleared {count} completed downloads from queue")
return count
@staticmethod
async def retry_failed(
db: AsyncSession,
max_retries: int = 3,
) -> List[DownloadQueueItem]:
"""Retry failed downloads that haven't exceeded max retries.
Args:
db: Database session
max_retries: Maximum number of retry attempts
Returns:
List of items marked for retry
"""
result = await db.execute(
select(DownloadQueueItem).where(
DownloadQueueItem.status == DownloadStatus.FAILED,
DownloadQueueItem.retry_count < max_retries,
)
)
items = list(result.scalars().all())
for item in items:
item.status = DownloadStatus.PENDING
item.error_message = None
item.progress_percent = 0.0
item.downloaded_bytes = 0
item.started_at = None
item.completed_at = None
await db.flush()
logger.info(f"Marked {len(items)} failed downloads for retry")
return items
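# A periodic maintenance task could combine the two helpers above. A
# sketch; the scheduler wiring and session handling are assumed:
async def queue_maintenance(db: AsyncSession) -> None:
    """Prune finished rows, then requeue eligible failed downloads."""
    cleared = await DownloadQueueService.clear_completed(db)
    retried = await DownloadQueueService.retry_failed(db, max_retries=3)
    await db.commit()
    logger.info(
        f"Queue maintenance: cleared={cleared}, retried={len(retried)}"
    )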
# ============================================================================

View File

@ -51,15 +51,6 @@ async def lifespan(app: FastAPI):
try:
logger.info("Starting FastAPI application...")
# Initialize database first (required for other services)
try:
from src.server.database.connection import init_db
await init_db()
logger.info("Database initialized successfully")
except Exception as e:
logger.error("Failed to initialize database: %s", e, exc_info=True)
raise # Database is required, fail startup if it fails
# Load configuration from config.json and sync with settings
try:
from src.server.services.config_service import get_config_service
@ -95,24 +86,6 @@ async def lifespan(app: FastAPI):
# Subscribe to progress events
progress_service.subscribe("progress_updated", progress_event_handler)
# Initialize download service and restore queue from database
# Only if anime directory is configured
try:
from src.server.utils.dependencies import get_download_service
if settings.anime_directory:
download_service = get_download_service()
await download_service.initialize()
logger.info("Download service initialized and queue restored")
else:
logger.info(
"Download service initialization skipped - "
"anime directory not configured"
)
except Exception as e:
logger.warning("Failed to initialize download service: %s", e)
# Continue startup - download service can be initialized later
logger.info("FastAPI application started successfully")
logger.info("Server running on http://127.0.0.1:8000")
logger.info(
@ -138,14 +111,6 @@ async def lifespan(app: FastAPI):
except Exception as e:
logger.error("Error stopping download service: %s", e, exc_info=True)
# Close database connections
try:
from src.server.database.connection import close_db
await close_db()
logger.info("Database connections closed")
except Exception as e:
logger.error("Error closing database: %s", e, exc_info=True)
logger.info("FastAPI application shutdown complete")

View File

@ -70,6 +70,8 @@ class AnimeSeriesResponse(BaseModel):
)
)
alt_titles: List[str] = Field(default_factory=list, description="Alternative titles")
description: Optional[str] = Field(None, description="Short series description")
total_episodes: Optional[int] = Field(None, ge=0, description="Declared total episode count if known")
episodes: List[EpisodeInfo] = Field(default_factory=list, description="Known episodes information")
missing_episodes: List[MissingEpisodeInfo] = Field(default_factory=list, description="Detected missing episode ranges")
thumbnail: Optional[HttpUrl] = Field(None, description="Optional thumbnail image URL")

View File

@ -222,7 +222,7 @@ class AnimeService:
loop
)
except Exception as exc:
logger.error("Error handling scan status event: %s", exc)
logger.error("Error handling scan status event", error=str(exc))
@lru_cache(maxsize=128)
def _cached_list_missing(self) -> list[dict]:

View File

@ -4,7 +4,7 @@ This service handles:
- Loading and saving configuration to JSON files
- Configuration validation
- Backup and restore functionality
- Configuration migration for version updates
"""
import json
@ -35,8 +35,8 @@ class ConfigBackupError(ConfigServiceError):
class ConfigService:
"""Service for managing application configuration persistence.
Handles loading, saving, validation, backup, and migration of
configuration files. Uses JSON format for human-readable and
version-control friendly storage.
"""
@ -84,6 +84,11 @@ class ConfigService:
with open(self.config_path, "r", encoding="utf-8") as f:
data = json.load(f)
# Check if migration is needed
file_version = data.get("version", "1.0.0")
if file_version != self.CONFIG_VERSION:
data = self._migrate_config(data, file_version)
# Remove version key before constructing AppConfig
data.pop("version", None)
@ -323,6 +328,26 @@ class ConfigService:
except (OSError, IOError):
# Ignore errors during cleanup
continue
def _migrate_config(
self, data: Dict, from_version: str # noqa: ARG002
) -> Dict:
"""Migrate configuration from old version to current.
Args:
data: Configuration data to migrate
from_version: Version to migrate from (reserved for future use)
Returns:
Dict: Migrated configuration data
"""
# Currently only one version exists
# Future migrations would go here
# Example:
# if from_version == "1.0.0" and self.CONFIG_VERSION == "2.0.0":
# data = self._migrate_1_0_to_2_0(data)
return data
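Should a schema bump ever land, the hook above would dispatch to per-version helpers; a hypothetical sketch (the helper name and the legacy "retention" key are invented for illustration and are not part of the current schema):

def _migrate_0_9_to_1_0(self, data: Dict) -> Dict:
    # Hypothetical example: fold a legacy top-level key into the
    # "backup" section introduced in 1.0.0
    if "retention" in data and "backup" in data:
        data["backup"]["keep_days"] = data.pop("retention")
    return data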
# Singleton instance

View File

@ -2,19 +2,18 @@
This module provides a simplified queue management system for handling
anime episode downloads with manual start/stop controls, progress tracking,
database persistence, and retry functionality.
The service uses SQLite database for persistent storage via QueueRepository
while maintaining an in-memory cache for performance.
persistence, and retry functionality.
"""
from __future__ import annotations
import asyncio
import json
import uuid
from collections import deque
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime, timezone
from typing import TYPE_CHECKING, Dict, List, Optional
from pathlib import Path
from typing import Dict, List, Optional
import structlog
@ -29,9 +28,6 @@ from src.server.models.download import (
from src.server.services.anime_service import AnimeService, AnimeServiceError
from src.server.services.progress_service import ProgressService, get_progress_service
if TYPE_CHECKING:
from src.server.services.queue_repository import QueueRepository
logger = structlog.get_logger(__name__)
@ -46,7 +42,7 @@ class DownloadService:
- Manual download start/stop
- FIFO queue processing
- Real-time progress tracking
- Database persistence via QueueRepository
- Queue persistence and recovery
- Automatic retry logic
- WebSocket broadcast support
"""
@ -54,28 +50,24 @@ class DownloadService:
def __init__(
self,
anime_service: AnimeService,
queue_repository: Optional["QueueRepository"] = None,
max_retries: int = 3,
persistence_path: str = "./data/download_queue.json",
progress_service: Optional[ProgressService] = None,
):
"""Initialize the download service.
Args:
anime_service: Service for anime operations
queue_repository: Optional repository for database persistence.
If not provided, will use default singleton.
max_retries: Maximum retry attempts for failed downloads
persistence_path: Path to persist queue state
progress_service: Optional progress service for tracking
"""
self._anime_service = anime_service
self._max_retries = max_retries
self._persistence_path = Path(persistence_path)
self._progress_service = progress_service or get_progress_service()
# Database repository for persistence
self._queue_repository = queue_repository
self._db_initialized = False
# In-memory cache for performance (synced with database)
# Queue storage by status
self._pending_queue: deque[DownloadItem] = deque()
# Helper dict for O(1) lookup of pending items by ID
self._pending_items_by_id: Dict[str, DownloadItem] = {}
@ -100,108 +92,14 @@ class DownloadService:
# Track if queue progress has been initialized
self._queue_progress_initialized: bool = False
# Load persisted queue
self._load_queue()
logger.info(
"DownloadService initialized",
max_retries=max_retries,
)
def _get_repository(self) -> "QueueRepository":
"""Get the queue repository, initializing if needed.
Returns:
QueueRepository instance
"""
if self._queue_repository is None:
from src.server.services.queue_repository import get_queue_repository
self._queue_repository = get_queue_repository()
return self._queue_repository
async def initialize(self) -> None:
"""Initialize the service by loading queue state from database.
Should be called after database is initialized during app startup.
Note: With the simplified model, status/priority/progress are now
managed in-memory only. The database stores the queue items
for persistence across restarts.
"""
if self._db_initialized:
return
try:
repository = self._get_repository()
# Load all items from database - they all start as PENDING
# since status is now managed in-memory only
all_items = await repository.get_all_items()
for item in all_items:
# All items from database are treated as pending
item.status = DownloadStatus.PENDING
self._add_to_pending_queue(item)
self._db_initialized = True
logger.info(
"Queue restored from database: pending_count=%d",
len(self._pending_queue),
)
except Exception as e:
logger.error("Failed to load queue from database: %s", e, exc_info=True)
# Continue without persistence - queue will work in memory only
self._db_initialized = True
async def _save_to_database(self, item: DownloadItem) -> DownloadItem:
"""Save or update an item in the database.
Args:
item: Download item to save
Returns:
Saved item with database ID
"""
try:
repository = self._get_repository()
return await repository.save_item(item)
except Exception as e:
logger.error("Failed to save item to database: %s", e)
return item
async def _set_error_in_database(
self,
item_id: str,
error: str,
) -> bool:
"""Set error message on an item in the database.
Args:
item_id: Download item ID
error: Error message
Returns:
True if update succeeded
"""
try:
repository = self._get_repository()
return await repository.set_error(item_id, error)
except Exception as e:
logger.error("Failed to set error in database: %s", e)
return False
async def _delete_from_database(self, item_id: str) -> bool:
"""Delete an item from the database.
Args:
item_id: Download item ID
Returns:
True if delete succeeded
"""
try:
repository = self._get_repository()
return await repository.delete_item(item_id)
except Exception as e:
logger.error("Failed to delete from database: %s", e)
return False
async def _init_queue_progress(self) -> None:
"""Initialize the download queue progress tracking.
@ -221,7 +119,7 @@ class DownloadService:
)
self._queue_progress_initialized = True
except Exception as e:
logger.error("Failed to initialize queue progress: %s", e)
logger.error("Failed to initialize queue progress", error=str(e))
def _add_to_pending_queue(
self, item: DownloadItem, front: bool = False
@ -267,6 +165,69 @@ class DownloadService:
"""Generate unique identifier for download items."""
return str(uuid.uuid4())
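The pending queue pairs a deque with an ID dict, which is the invariant _add_to_pending_queue and remove_from_queue rely on; a condensed sketch of that bookkeeping (item type simplified to object):

from collections import deque

pending: deque = deque()
by_id: dict[str, object] = {}

def add(item_id: str, item: object, front: bool = False) -> None:
    # deque gives O(1) appends at either end (front=True re-queues a
    # cancelled download); the dict gives O(1) membership and lookup
    if front:
        pending.appendleft(item)
    else:
        pending.append(item)
    by_id[item_id] = item

def remove(item_id: str) -> None:
    item = by_id.pop(item_id)
    pending.remove(item)  # O(n), but only on explicit user removal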
def _load_queue(self) -> None:
"""Load persisted queue from disk."""
try:
if self._persistence_path.exists():
with open(self._persistence_path, "r", encoding="utf-8") as f:
data = json.load(f)
# Restore pending items
for item_dict in data.get("pending", []):
item = DownloadItem(**item_dict)
# Reset status if it was downloading when saved
if item.status == DownloadStatus.DOWNLOADING:
item.status = DownloadStatus.PENDING
self._add_to_pending_queue(item)
# Restore failed items that can be retried
for item_dict in data.get("failed", []):
item = DownloadItem(**item_dict)
if item.retry_count < self._max_retries:
item.status = DownloadStatus.PENDING
self._add_to_pending_queue(item)
else:
self._failed_items.append(item)
logger.info(
"Queue restored from disk",
pending_count=len(self._pending_queue),
failed_count=len(self._failed_items),
)
except Exception as e:
logger.error("Failed to load persisted queue", error=str(e))
def _save_queue(self) -> None:
"""Persist current queue state to disk."""
try:
self._persistence_path.parent.mkdir(parents=True, exist_ok=True)
active_items = (
[self._active_download] if self._active_download else []
)
data = {
"pending": [
item.model_dump(mode="json")
for item in self._pending_queue
],
"active": [
item.model_dump(mode="json") for item in active_items
],
"failed": [
item.model_dump(mode="json")
for item in self._failed_items
],
"timestamp": datetime.now(timezone.utc).isoformat(),
}
with open(self._persistence_path, "w", encoding="utf-8") as f:
json.dump(data, f, indent=2)
logger.debug("Queue persisted to disk")
except Exception as e:
logger.error("Failed to persist queue", error=str(e))
async def add_to_queue(
self,
serie_id: str,
@ -313,23 +274,22 @@ class DownloadService:
added_at=datetime.now(timezone.utc),
)
# Save to database first to get persistent ID
saved_item = await self._save_to_database(item)
# Always append to end (FIFO order)
self._add_to_pending_queue(item, front=False)
# Add to in-memory cache
self._add_to_pending_queue(saved_item, front=False)
created_ids.append(saved_item.id)
created_ids.append(item.id)
logger.info(
"Item added to queue",
item_id=saved_item.id,
item_id=item.id,
serie_key=serie_id,
serie_name=serie_name,
season=episode.season,
episode=episode.episode,
)
self._save_queue()
# Notify via progress service
queue_status = await self.get_queue_status()
await self._progress_service.update_progress(
@ -346,7 +306,7 @@ class DownloadService:
return created_ids
except Exception as e:
logger.error("Failed to add items to queue: %s", e)
logger.error("Failed to add items to queue", error=str(e))
raise DownloadServiceError(f"Failed to add items: {str(e)}") from e
async def remove_from_queue(self, item_ids: List[str]) -> List[str]:
@ -373,10 +333,8 @@ class DownloadService:
item.completed_at = datetime.now(timezone.utc)
self._failed_items.append(item)
self._active_download = None
# Delete cancelled item from database
await self._delete_from_database(item_id)
removed_ids.append(item_id)
logger.info("Cancelled active download: item_id=%s", item_id)
logger.info("Cancelled active download", item_id=item_id)
continue
# Check pending queue - O(1) lookup using helper dict
@ -384,14 +342,13 @@ class DownloadService:
item = self._pending_items_by_id[item_id]
self._pending_queue.remove(item)
del self._pending_items_by_id[item_id]
# Delete from database
await self._delete_from_database(item_id)
removed_ids.append(item_id)
logger.info(
"Removed from pending queue", item_id=item_id
)
if removed_ids:
self._save_queue()
# Notify via progress service
queue_status = await self.get_queue_status()
await self._progress_service.update_progress(
@ -408,7 +365,7 @@ class DownloadService:
return removed_ids
except Exception as e:
logger.error("Failed to remove items: %s", e)
logger.error("Failed to remove items", error=str(e))
raise DownloadServiceError(
f"Failed to remove items: {str(e)}"
) from e
@ -422,10 +379,6 @@ class DownloadService:
Raises:
DownloadServiceError: If reordering fails
Note:
Reordering is done in-memory only. Database priority is not
updated since the in-memory queue defines the actual order.
"""
try:
# Build new queue based on specified order
@ -446,6 +399,9 @@ class DownloadService:
# Replace queue
self._pending_queue = new_queue
# Save updated queue
self._save_queue()
# Notify via progress service
queue_status = await self.get_queue_status()
await self._progress_service.update_progress(
@ -462,7 +418,7 @@ class DownloadService:
logger.info("Queue reordered", reordered_count=len(item_ids))
except Exception as e:
logger.error("Failed to reorder queue: %s", e)
logger.error("Failed to reorder queue", error=str(e))
raise DownloadServiceError(
f"Failed to reorder queue: {str(e)}"
) from e
@ -506,7 +462,7 @@ class DownloadService:
return "queue_started"
except Exception as e:
logger.error("Failed to start queue processing: %s", e)
logger.error("Failed to start queue processing", error=str(e))
raise DownloadServiceError(
f"Failed to start queue processing: {str(e)}"
) from e
@ -736,15 +692,13 @@ class DownloadService:
Number of items cleared
"""
count = len(self._pending_queue)
# Delete all pending items from database
for item_id in list(self._pending_items_by_id.keys()):
await self._delete_from_database(item_id)
self._pending_queue.clear()
self._pending_items_by_id.clear()
logger.info("Cleared pending items", count=count)
# Save queue state
self._save_queue()
# Notify via progress service
if count > 0:
queue_status = await self.get_queue_status()
@ -795,15 +749,14 @@ class DownloadService:
self._add_to_pending_queue(item)
retried_ids.append(item.id)
# Status is now managed in-memory only
logger.info(
"Retrying failed item: item_id=%s, retry_count=%d",
item.id,
item.retry_count,
"Retrying failed item",
item_id=item.id,
retry_count=item.retry_count
)
if retried_ids:
self._save_queue()
# Notify via progress service
queue_status = await self.get_queue_status()
await self._progress_service.update_progress(
@ -820,7 +773,7 @@ class DownloadService:
return retried_ids
except Exception as e:
logger.error("Failed to retry items: %s", e)
logger.error("Failed to retry items", error=str(e))
raise DownloadServiceError(
f"Failed to retry: {str(e)}"
) from e
@ -837,17 +790,18 @@ class DownloadService:
logger.info("Skipping download due to shutdown")
return
# Update status in memory (status is now in-memory only)
# Update status
item.status = DownloadStatus.DOWNLOADING
item.started_at = datetime.now(timezone.utc)
self._active_download = item
logger.info(
"Starting download: item_id=%s, serie_key=%s, S%02dE%02d",
item.id,
item.serie_id,
item.episode.season,
item.episode.episode,
"Starting download",
item_id=item.id,
serie_key=item.serie_id,
serie_name=item.serie_name,
season=item.episode.season,
episode=item.episode.episode,
)
# Execute download via anime service
@ -855,8 +809,7 @@ class DownloadService:
# - download started/progress/completed/failed events
# - All updates forwarded to ProgressService
# - ProgressService broadcasts to WebSocket clients
# Use serie_folder for filesystem operations
# and serie_id (key) for identification
# Use serie_folder for filesystem operations and serie_id (key) for identification
if not item.serie_folder:
raise DownloadServiceError(
f"Missing serie_folder for download item {item.id}. "
@ -882,11 +835,8 @@ class DownloadService:
self._completed_items.append(item)
# Delete completed item from database (status is in-memory)
await self._delete_from_database(item.id)
logger.info(
"Download completed successfully: item_id=%s", item.id
"Download completed successfully", item_id=item.id
)
else:
raise AnimeServiceError("Download returned False")
@ -894,18 +844,14 @@ class DownloadService:
except asyncio.CancelledError:
# Handle task cancellation during shutdown
logger.info(
"Download cancelled during shutdown: item_id=%s",
item.id,
"Download cancelled during shutdown",
item_id=item.id,
)
item.status = DownloadStatus.CANCELLED
item.completed_at = datetime.now(timezone.utc)
# Delete cancelled item from database
await self._delete_from_database(item.id)
# Return item to pending queue if not shutting down
if not self._is_shutting_down:
self._add_to_pending_queue(item, front=True)
# Re-save to database as pending
await self._save_to_database(item)
raise # Re-raise to properly cancel the task
except Exception as e:
@ -915,14 +861,11 @@ class DownloadService:
item.error = str(e)
self._failed_items.append(item)
# Set error in database
await self._set_error_in_database(item.id, str(e))
logger.error(
"Download failed: item_id=%s, error=%s, retry_count=%d",
item.id,
str(e),
item.retry_count,
"Download failed",
item_id=item.id,
error=str(e),
retry_count=item.retry_count,
)
# Note: Failure is already broadcast by AnimeService
# via ProgressService when SeriesApp fires failed event
@ -931,6 +874,8 @@ class DownloadService:
# Remove from active downloads
if self._active_download and self._active_download.id == item.id:
self._active_download = None
self._save_queue()
async def start(self) -> None:
"""Initialize the download queue service (compatibility method).
@ -951,15 +896,17 @@ class DownloadService:
self._is_stopped = True
# Cancel active download task if running
active_task = self._active_download_task
if active_task and not active_task.done():
if self._active_download_task and not self._active_download_task.done():
logger.info("Cancelling active download task...")
active_task.cancel()
self._active_download_task.cancel()
try:
await active_task
await self._active_download_task
except asyncio.CancelledError:
logger.info("Active download task cancelled")
# Save final state
self._save_queue()
# Shutdown executor immediately, don't wait for tasks
logger.info("Shutting down thread pool executor...")
self._executor.shutdown(wait=False, cancel_futures=True)

View File

@ -1,460 +0,0 @@
"""Queue repository adapter for database-backed download queue operations.
This module provides a repository adapter that wraps the DownloadQueueService
and provides the interface needed by DownloadService for queue persistence.
The repository pattern abstracts the database operations from the business
logic, allowing the DownloadService to work with domain models (DownloadItem)
while the repository handles conversion to/from database models.
"""
from __future__ import annotations
import logging
from datetime import datetime, timezone
from typing import Callable, List, Optional
from sqlalchemy.ext.asyncio import AsyncSession
from src.server.database.models import DownloadQueueItem as DBDownloadQueueItem
from src.server.database.service import (
AnimeSeriesService,
DownloadQueueService,
EpisodeService,
)
from src.server.models.download import (
DownloadItem,
DownloadPriority,
DownloadStatus,
EpisodeIdentifier,
)
logger = logging.getLogger(__name__)
class QueueRepositoryError(Exception):
"""Repository-level exception for queue operations."""
class QueueRepository:
"""Repository adapter for database-backed download queue operations.
Provides a clean interface for queue operations while handling
model conversion between Pydantic (DownloadItem) and SQLAlchemy
(DownloadQueueItem) models.
Note: The database model (DownloadQueueItem) is simplified and only
stores episode_id as a foreign key. Status, priority, progress, and
retry_count are managed in-memory by the DownloadService.
Attributes:
_db_session_factory: Factory function to create database sessions
"""
def __init__(
self,
db_session_factory: Callable[[], AsyncSession],
) -> None:
"""Initialize the queue repository.
Args:
db_session_factory: Factory function that returns AsyncSession
"""
self._db_session_factory = db_session_factory
logger.info("QueueRepository initialized")
# =========================================================================
# Model Conversion Methods
# =========================================================================
def _from_db_model(
self,
db_item: DBDownloadQueueItem,
item_id: Optional[str] = None,
) -> DownloadItem:
"""Convert database model to DownloadItem.
Note: Since the database model is simplified, status, priority,
progress, and retry_count default to initial values.
Args:
db_item: SQLAlchemy download queue item
item_id: Optional override for item ID
Returns:
Pydantic download item with default status/priority
"""
# Get episode info from the related Episode object
episode = db_item.episode
series = db_item.series
episode_identifier = EpisodeIdentifier(
season=episode.season if episode else 1,
episode=episode.episode_number if episode else 1,
title=episode.title if episode else None,
)
return DownloadItem(
id=item_id or str(db_item.id),
serie_id=series.key if series else "",
serie_folder=series.folder if series else "",
serie_name=series.name if series else "",
episode=episode_identifier,
status=DownloadStatus.PENDING, # Default - managed in-memory
priority=DownloadPriority.NORMAL, # Default - managed in-memory
added_at=db_item.created_at or datetime.now(timezone.utc),
started_at=db_item.started_at,
completed_at=db_item.completed_at,
progress=None, # Managed in-memory
error=db_item.error_message,
retry_count=0, # Managed in-memory
source_url=db_item.download_url,
)
# =========================================================================
# CRUD Operations
# =========================================================================
async def save_item(
self,
item: DownloadItem,
db: Optional[AsyncSession] = None,
) -> DownloadItem:
"""Save a download item to the database.
Creates a new record if the item doesn't exist in the database.
Note: Status, priority, progress, and retry_count are NOT persisted.
Args:
item: Download item to save
db: Optional existing database session
Returns:
Saved download item with database ID
Raises:
QueueRepositoryError: If save operation fails
"""
session = db or self._db_session_factory()
manage_session = db is None
try:
# Find series by key
series = await AnimeSeriesService.get_by_key(session, item.serie_id)
if not series:
# Create series if it doesn't exist
series = await AnimeSeriesService.create(
db=session,
key=item.serie_id,
name=item.serie_name,
site="", # Will be updated later if needed
folder=item.serie_folder,
)
logger.info(
"Created new series for queue item: key=%s, name=%s",
item.serie_id,
item.serie_name,
)
# Find or create episode
episode = await EpisodeService.get_by_episode(
session,
series.id,
item.episode.season,
item.episode.episode,
)
if not episode:
# Create episode if it doesn't exist
episode = await EpisodeService.create(
db=session,
series_id=series.id,
season=item.episode.season,
episode_number=item.episode.episode,
title=item.episode.title,
)
logger.info(
"Created new episode for queue item: S%02dE%02d",
item.episode.season,
item.episode.episode,
)
# Create queue item
db_item = await DownloadQueueService.create(
db=session,
series_id=series.id,
episode_id=episode.id,
download_url=str(item.source_url) if item.source_url else None,
)
if manage_session:
await session.commit()
# Update the item ID with the database ID
item.id = str(db_item.id)
logger.debug(
"Saved queue item to database: item_id=%s, serie_key=%s",
item.id,
item.serie_id,
)
return item
except Exception as e:
if manage_session:
await session.rollback()
logger.error("Failed to save queue item: %s", e)
raise QueueRepositoryError(f"Failed to save item: {e}") from e
finally:
if manage_session:
await session.close()
async def get_item(
self,
item_id: str,
db: Optional[AsyncSession] = None,
) -> Optional[DownloadItem]:
"""Get a download item by ID.
Args:
item_id: Download item ID (database ID as string)
db: Optional existing database session
Returns:
Download item or None if not found
Raises:
QueueRepositoryError: If query fails
"""
session = db or self._db_session_factory()
manage_session = db is None
try:
db_item = await DownloadQueueService.get_by_id(
session, int(item_id)
)
if not db_item:
return None
return self._from_db_model(db_item, item_id)
except ValueError:
# Invalid ID format
return None
except Exception as e:
logger.error("Failed to get queue item: %s", e)
raise QueueRepositoryError(f"Failed to get item: {e}") from e
finally:
if manage_session:
await session.close()
async def get_all_items(
self,
db: Optional[AsyncSession] = None,
) -> List[DownloadItem]:
"""Get all download items regardless of status.
Note: All items are returned with default status (PENDING) since
status is now managed in-memory by the DownloadService.
Args:
db: Optional existing database session
Returns:
List of all download items
Raises:
QueueRepositoryError: If query fails
"""
session = db or self._db_session_factory()
manage_session = db is None
try:
db_items = await DownloadQueueService.get_all(
session, with_series=True
)
return [self._from_db_model(item) for item in db_items]
except Exception as e:
logger.error("Failed to get all items: %s", e)
raise QueueRepositoryError(f"Failed to get all items: {e}") from e
finally:
if manage_session:
await session.close()
async def set_error(
self,
item_id: str,
error: str,
db: Optional[AsyncSession] = None,
) -> bool:
"""Set error message on a download item.
Args:
item_id: Download item ID
error: Error message
db: Optional existing database session
Returns:
True if update succeeded, False if item not found
Raises:
QueueRepositoryError: If update fails
"""
session = db or self._db_session_factory()
manage_session = db is None
try:
result = await DownloadQueueService.set_error(
session,
int(item_id),
error,
)
if manage_session:
await session.commit()
success = result is not None
if success:
logger.debug(
"Set error on queue item: item_id=%s",
item_id,
)
return success
except ValueError:
return False
except Exception as e:
if manage_session:
await session.rollback()
logger.error("Failed to set error: %s", e)
raise QueueRepositoryError(f"Failed to set error: {e}") from e
finally:
if manage_session:
await session.close()
async def delete_item(
self,
item_id: str,
db: Optional[AsyncSession] = None,
) -> bool:
"""Delete a download item from the database.
Args:
item_id: Download item ID
db: Optional existing database session
Returns:
True if item was deleted, False if not found
Raises:
QueueRepositoryError: If delete fails
"""
session = db or self._db_session_factory()
manage_session = db is None
try:
result = await DownloadQueueService.delete(session, int(item_id))
if manage_session:
await session.commit()
if result:
logger.debug("Deleted queue item: item_id=%s", item_id)
return result
except ValueError:
return False
except Exception as e:
if manage_session:
await session.rollback()
logger.error("Failed to delete item: %s", e)
raise QueueRepositoryError(f"Failed to delete item: {e}") from e
finally:
if manage_session:
await session.close()
async def clear_all(
self,
db: Optional[AsyncSession] = None,
) -> int:
"""Clear all download items from the queue.
Args:
db: Optional existing database session
Returns:
Number of items cleared
Raises:
QueueRepositoryError: If operation fails
"""
session = db or self._db_session_factory()
manage_session = db is None
try:
# Get all items first to count them
all_items = await DownloadQueueService.get_all(session)
count = len(all_items)
# Delete each item
for item in all_items:
await DownloadQueueService.delete(session, item.id)
if manage_session:
await session.commit()
logger.info("Cleared all items from queue: count=%d", count)
return count
except Exception as e:
if manage_session:
await session.rollback()
logger.error("Failed to clear queue: %s", e)
raise QueueRepositoryError(f"Failed to clear queue: {e}") from e
finally:
if manage_session:
await session.close()
# Singleton instance
_queue_repository_instance: Optional[QueueRepository] = None
def get_queue_repository(
db_session_factory: Optional[Callable[[], AsyncSession]] = None,
) -> QueueRepository:
"""Get or create the QueueRepository singleton.
Args:
db_session_factory: Optional factory function for database sessions.
If not provided, uses default from connection module.
Returns:
QueueRepository singleton instance
"""
global _queue_repository_instance
if _queue_repository_instance is None:
if db_session_factory is None:
# Use default session factory
from src.server.database.connection import get_async_session_factory
db_session_factory = get_async_session_factory
_queue_repository_instance = QueueRepository(db_session_factory)
return _queue_repository_instance
def reset_queue_repository() -> None:
"""Reset the QueueRepository singleton.
Used for testing to ensure fresh state between tests.
"""
global _queue_repository_instance
_queue_repository_instance = None
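Before its removal in this commit, the singleton above was consumed roughly like this (a sketch limited to the calls defined in this file; the DownloadItem is assumed to be built elsewhere):

async def example(item: DownloadItem) -> None:
    repo = get_queue_repository()
    saved = await repo.save_item(item)   # assigns the database ID
    items = await repo.get_all_items()   # everything comes back PENDING
    await repo.set_error(saved.id, "Network timeout")
    await repo.delete_item(saved.id)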

View File

@ -415,7 +415,7 @@ class ScanService:
message="Initializing scan...",
)
except Exception as e:
logger.error("Failed to start progress tracking: %s", e)
logger.error("Failed to start progress tracking", error=str(e))
# Emit scan started event
await self._emit_scan_event({
@ -479,7 +479,7 @@ class ScanService:
folder=scan_progress.folder,
)
except Exception as e:
logger.debug("Progress update skipped: %s", e)
logger.debug("Progress update skipped", error=str(e))
# Emit progress event with key as primary identifier
await self._emit_scan_event({
@ -541,7 +541,7 @@ class ScanService:
error_message=completion_context.message,
)
except Exception as e:
logger.debug("Progress completion skipped: %s", e)
logger.debug("Progress completion skipped", error=str(e))
# Emit completion event
await self._emit_scan_event({
@ -598,7 +598,7 @@ class ScanService:
error_message="Scan cancelled by user",
)
except Exception as e:
logger.debug("Progress cancellation skipped: %s", e)
logger.debug("Progress cancellation skipped", error=str(e))
logger.info("Scan cancelled")
return True

View File

@ -65,10 +65,6 @@ def get_series_app() -> SeriesApp:
Raises:
HTTPException: If SeriesApp is not initialized or anime directory
is not configured
Note:
This creates a SeriesApp without database support. For database-
backed storage, use get_series_app_with_db() instead.
"""
global _series_app
@ -107,6 +103,7 @@ def reset_series_app() -> None:
_series_app = None
async def get_database_session() -> AsyncGenerator:
"""
Dependency to get database session.
@ -137,75 +134,6 @@ async def get_database_session() -> AsyncGenerator:
)
async def get_optional_database_session() -> AsyncGenerator:
"""
Dependency to get optional database session.
Unlike get_database_session(), this returns None if the database
is not available, allowing endpoints to fall back to other storage.
Yields:
AsyncSession or None: Database session if available, None otherwise
Example:
@app.post("/anime/add")
async def add_anime(
db: Optional[AsyncSession] = Depends(get_optional_database_session)
):
if db:
# Use database
await AnimeSeriesService.create(db, ...)
else:
# Fall back to file-based storage
series_app.list.add(serie)
"""
try:
from src.server.database import get_db_session
async with get_db_session() as session:
yield session
except (ImportError, RuntimeError):
# Database not available - yield None
yield None
async def get_series_app_with_db(
db: AsyncSession = Depends(get_optional_database_session),
) -> SeriesApp:
"""
Dependency to get SeriesApp instance with database support.
This creates or returns a SeriesApp instance and injects the
database session for database-backed storage.
Args:
db: Optional database session from dependency injection
Returns:
SeriesApp: The main application instance with database support
Raises:
HTTPException: If SeriesApp is not initialized or anime directory
is not configured
Example:
@app.post("/api/anime/scan")
async def scan_anime(
series_app: SeriesApp = Depends(get_series_app_with_db)
):
# series_app has db_session configured
await series_app.serie_scanner.scan_async()
"""
# Get the base SeriesApp
app = get_series_app()
# Inject database session if available
if db:
app.set_db_session(db)
return app
def get_current_user(
credentials: Optional[HTTPAuthorizationCredentials] = Depends(
http_bearer_security

View File

@ -72,14 +72,11 @@ async def anime_service(mock_series_app, progress_service):
@pytest.fixture
async def download_service(anime_service, progress_service):
"""Create a DownloadService with mock queue repository."""
from tests.unit.test_download_service import MockQueueRepository
mock_repo = MockQueueRepository()
"""Create a DownloadService."""
service = DownloadService(
anime_service=anime_service,
progress_service=progress_service,
queue_repository=mock_repo,
persistence_path="/tmp/test_integration_progress_queue.json",
)
yield service
await service.stop()

View File

@ -88,10 +88,9 @@ def progress_service():
@pytest.fixture
async def download_service(mock_series_app, progress_service, tmp_path):
"""Create a DownloadService with mock repository for testing."""
from tests.unit.test_download_service import MockQueueRepository
mock_repo = MockQueueRepository()
"""Create a DownloadService with dependencies."""
import uuid
persistence_path = tmp_path / f"test_queue_{uuid.uuid4()}.json"
anime_service = AnimeService(
series_app=mock_series_app,
@ -102,7 +101,7 @@ async def download_service(mock_series_app, progress_service, tmp_path):
service = DownloadService(
anime_service=anime_service,
progress_service=progress_service,
queue_repository=mock_repo,
persistence_path=str(persistence_path),
)
yield service
await service.stop()
@ -320,6 +319,8 @@ class TestServiceIdentifierConsistency:
- Persisted data contains serie_id (key)
- Data can be restored with correct identifiers
"""
import json
# Add item to queue
await download_service.add_to_queue(
serie_id="jujutsu-kaisen",
@ -329,13 +330,18 @@ class TestServiceIdentifierConsistency:
priority=DownloadPriority.NORMAL,
)
# Verify item is in pending queue (in-memory cache synced with DB)
pending_items = list(download_service._pending_queue)
assert len(pending_items) == 1
# Read persisted data
persistence_path = download_service._persistence_path
with open(persistence_path, "r") as f:
data = json.load(f)
persisted_item = pending_items[0]
assert persisted_item.serie_id == "jujutsu-kaisen"
assert persisted_item.serie_folder == "Jujutsu Kaisen (2020)"
# Verify persisted data structure
assert "pending" in data
assert len(data["pending"]) == 1
persisted_item = data["pending"][0]
assert persisted_item["serie_id"] == "jujutsu-kaisen"
assert persisted_item["serie_folder"] == "Jujutsu Kaisen (2020)"
class TestWebSocketIdentifierConsistency:

View File

@ -69,17 +69,16 @@ async def anime_service(mock_series_app, progress_service):
@pytest.fixture
async def download_service(anime_service, progress_service, tmp_path):
"""Create a DownloadService with mock repository for testing.
"""Create a DownloadService with dependencies.
Uses mock repository to ensure each test has isolated queue storage.
Uses tmp_path to ensure each test has isolated queue storage.
"""
from tests.unit.test_download_service import MockQueueRepository
mock_repo = MockQueueRepository()
import uuid
persistence_path = tmp_path / f"test_queue_{uuid.uuid4()}.json"
service = DownloadService(
anime_service=anime_service,
progress_service=progress_service,
queue_repository=mock_repo,
persistence_path=str(persistence_path),
)
yield service, progress_service
await service.stop()

View File

@ -28,13 +28,12 @@ class TestDownloadQueueStress:
@pytest.fixture
def download_service(self, mock_anime_service, tmp_path):
"""Create download service with mock repository."""
from tests.unit.test_download_service import MockQueueRepository
mock_repo = MockQueueRepository()
"""Create download service with mock."""
persistence_path = str(tmp_path / "test_queue.json")
service = DownloadService(
anime_service=mock_anime_service,
max_retries=3,
queue_repository=mock_repo,
persistence_path=persistence_path,
)
return service
@ -177,13 +176,12 @@ class TestDownloadMemoryUsage:
@pytest.fixture
def download_service(self, mock_anime_service, tmp_path):
"""Create download service with mock repository."""
from tests.unit.test_download_service import MockQueueRepository
mock_repo = MockQueueRepository()
"""Create download service with mock."""
persistence_path = str(tmp_path / "test_queue.json")
service = DownloadService(
anime_service=mock_anime_service,
max_retries=3,
queue_repository=mock_repo,
persistence_path=persistence_path,
)
return service
@ -234,13 +232,12 @@ class TestDownloadConcurrency:
@pytest.fixture
def download_service(self, mock_anime_service, tmp_path):
"""Create download service with mock repository."""
from tests.unit.test_download_service import MockQueueRepository
mock_repo = MockQueueRepository()
"""Create download service with mock."""
persistence_path = str(tmp_path / "test_queue.json")
service = DownloadService(
anime_service=mock_anime_service,
max_retries=3,
queue_repository=mock_repo,
persistence_path=persistence_path,
)
return service
@ -324,12 +321,11 @@ class TestDownloadErrorHandling:
self, mock_failing_anime_service, tmp_path
):
"""Create download service with failing mock."""
from tests.unit.test_download_service import MockQueueRepository
mock_repo = MockQueueRepository()
persistence_path = str(tmp_path / "test_queue.json")
service = DownloadService(
anime_service=mock_failing_anime_service,
max_retries=3,
queue_repository=mock_repo,
persistence_path=persistence_path,
)
return service
@ -342,13 +338,12 @@ class TestDownloadErrorHandling:
@pytest.fixture
def download_service(self, mock_anime_service, tmp_path):
"""Create download service with mock repository."""
from tests.unit.test_download_service import MockQueueRepository
mock_repo = MockQueueRepository()
"""Create download service with mock."""
persistence_path = str(tmp_path / "test_queue.json")
service = DownloadService(
anime_service=mock_anime_service,
max_retries=3,
queue_repository=mock_repo,
persistence_path=persistence_path,
)
return service

View File

@ -65,6 +65,7 @@ class TestAnimeSeriesResponse:
title="Attack on Titan",
folder="Attack on Titan (2013)",
episodes=[ep],
total_episodes=12,
)
assert series.key == "attack-on-titan"

View File

@ -318,6 +318,25 @@ class TestConfigServiceBackups:
assert len(backups) == 3 # Should only keep max_backups
class TestConfigServiceMigration:
"""Test configuration migration."""
def test_migration_preserves_data(self, config_service, sample_config):
"""Test that migration preserves configuration data."""
# Manually save config with old version
data = sample_config.model_dump()
data["version"] = "0.9.0" # Old version
with open(config_service.config_path, "w", encoding="utf-8") as f:
json.dump(data, f)
# Load should migrate automatically
loaded = config_service.load_config()
assert loaded.name == sample_config.name
assert loaded.data_dir == sample_config.data_dir
class TestConfigServiceSingleton:
"""Test singleton instance management."""

View File

@ -25,6 +25,7 @@ from src.server.database.init import (
create_database_backup,
create_database_schema,
get_database_info,
get_migration_guide,
get_schema_version,
initialize_database,
seed_initial_data,
@ -371,6 +372,16 @@ def test_get_database_info():
assert set(info["expected_tables"]) == EXPECTED_TABLES
def test_get_migration_guide():
"""Test getting migration guide."""
guide = get_migration_guide()
assert isinstance(guide, str)
assert "Alembic" in guide
assert "alembic init" in guide
assert "alembic upgrade head" in guide
# =============================================================================
# Integration Tests
# =============================================================================

View File

@ -14,7 +14,9 @@ from sqlalchemy.orm import Session, sessionmaker
from src.server.database.base import Base, SoftDeleteMixin, TimestampMixin
from src.server.database.models import (
AnimeSeries,
DownloadPriority,
DownloadQueueItem,
DownloadStatus,
Episode,
UserSession,
)
@ -47,6 +49,11 @@ class TestAnimeSeries:
name="Attack on Titan",
site="https://aniworld.to",
folder="/anime/attack-on-titan",
description="Epic anime about titans",
status="completed",
total_episodes=75,
cover_url="https://example.com/cover.jpg",
episode_dict={1: [1, 2, 3], 2: [1, 2, 3, 4]},
)
db_session.add(series)
@ -165,7 +172,9 @@ class TestEpisode:
episode_number=5,
title="The Fifth Episode",
file_path="/anime/test/S01E05.mp4",
file_size=524288000, # 500 MB
is_downloaded=True,
download_date=datetime.now(timezone.utc),
)
db_session.add(episode)
@ -216,17 +225,17 @@ class TestDownloadQueueItem:
db_session.add(series)
db_session.commit()
episode = Episode(
item = DownloadQueueItem(
series_id=series.id,
season=1,
episode_number=3,
)
db_session.add(episode)
db_session.commit()
item = DownloadQueueItem(
series_id=series.id,
episode_id=episode.id,
status=DownloadStatus.DOWNLOADING,
priority=DownloadPriority.HIGH,
progress_percent=45.5,
downloaded_bytes=250000000,
total_bytes=550000000,
download_speed=2500000.0,
retry_count=0,
download_url="https://example.com/download/ep3",
file_destination="/anime/download/S01E03.mp4",
)
@ -236,38 +245,37 @@ class TestDownloadQueueItem:
# Verify saved
assert item.id is not None
assert item.episode_id == episode.id
assert item.series_id == series.id
assert item.status == DownloadStatus.DOWNLOADING
assert item.priority == DownloadPriority.HIGH
assert item.progress_percent == 45.5
assert item.retry_count == 0
def test_download_item_episode_relationship(self, db_session: Session):
"""Test download item episode relationship."""
def test_download_item_status_enum(self, db_session: Session):
"""Test download status enum values."""
series = AnimeSeries(
key="relationship-test",
name="Relationship Test",
key="status-test",
name="Status Test",
site="https://example.com",
folder="/anime/relationship",
folder="/anime/status",
)
db_session.add(series)
db_session.commit()
episode = Episode(
item = DownloadQueueItem(
series_id=series.id,
season=1,
episode_number=1,
)
db_session.add(episode)
db_session.commit()
item = DownloadQueueItem(
series_id=series.id,
episode_id=episode.id,
status=DownloadStatus.PENDING,
)
db_session.add(item)
db_session.commit()
# Verify relationship
assert item.episode.id == episode.id
assert item.series.id == series.id
# Update status
item.status = DownloadStatus.COMPLETED
db_session.commit()
# Verify status change
assert item.status == DownloadStatus.COMPLETED
def test_download_item_error_handling(self, db_session: Session):
"""Test download item with error information."""
@ -280,24 +288,21 @@ class TestDownloadQueueItem:
db_session.add(series)
db_session.commit()
episode = Episode(
item = DownloadQueueItem(
series_id=series.id,
season=1,
episode_number=1,
)
db_session.add(episode)
db_session.commit()
item = DownloadQueueItem(
series_id=series.id,
episode_id=episode.id,
status=DownloadStatus.FAILED,
error_message="Network timeout after 30 seconds",
retry_count=2,
)
db_session.add(item)
db_session.commit()
# Verify error info
assert item.status == DownloadStatus.FAILED
assert item.error_message == "Network timeout after 30 seconds"
assert item.retry_count == 2
class TestUserSession:
@ -497,31 +502,32 @@ class TestDatabaseQueries:
db_session.add(series)
db_session.commit()
# Create episodes and items
for i in range(3):
episode = Episode(
# Create items with different statuses
for i, status in enumerate([
DownloadStatus.PENDING,
DownloadStatus.DOWNLOADING,
DownloadStatus.COMPLETED,
]):
item = DownloadQueueItem(
series_id=series.id,
season=1,
episode_number=i + 1,
)
db_session.add(episode)
db_session.commit()
item = DownloadQueueItem(
series_id=series.id,
episode_id=episode.id,
status=status,
)
db_session.add(item)
db_session.commit()
# Query all items
# Query pending items
result = db_session.execute(
select(DownloadQueueItem)
select(DownloadQueueItem).where(
DownloadQueueItem.status == DownloadStatus.PENDING
)
)
items = result.scalars().all()
pending = result.scalars().all()
# Verify query
assert len(items) == 3
assert len(pending) == 1
assert pending[0].episode_number == 1
def test_query_active_sessions(self, db_session: Session):
"""Test querying active user sessions."""

View File

@ -10,6 +10,7 @@ from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker
from src.server.database.base import Base
from src.server.database.models import DownloadPriority, DownloadStatus
from src.server.database.service import (
AnimeSeriesService,
DownloadQueueService,
@ -64,11 +65,17 @@ async def test_create_anime_series(db_session):
name="Test Anime",
site="https://example.com",
folder="/path/to/anime",
description="A test anime",
status="ongoing",
total_episodes=12,
cover_url="https://example.com/cover.jpg",
)
assert series.id is not None
assert series.key == "test-anime-1"
assert series.name == "Test Anime"
assert series.description == "A test anime"
assert series.total_episodes == 12
@pytest.mark.asyncio
@ -153,11 +160,13 @@ async def test_update_anime_series(db_session):
db_session,
series.id,
name="Updated Name",
total_episodes=24,
)
await db_session.commit()
assert updated is not None
assert updated.name == "Updated Name"
assert updated.total_episodes == 24
@pytest.mark.asyncio
@ -299,12 +308,14 @@ async def test_mark_episode_downloaded(db_session):
db_session,
episode.id,
file_path="/path/to/file.mp4",
file_size=1024000,
)
await db_session.commit()
assert updated is not None
assert updated.is_downloaded is True
assert updated.file_path == "/path/to/file.mp4"
assert updated.download_date is not None
# ============================================================================
@ -325,30 +336,23 @@ async def test_create_download_queue_item(db_session):
)
await db_session.commit()
# Create episode
episode = await EpisodeService.create(
db_session,
series_id=series.id,
season=1,
episode_number=1,
)
await db_session.commit()
# Add to queue
item = await DownloadQueueService.create(
db_session,
series_id=series.id,
episode_id=episode.id,
season=1,
episode_number=1,
priority=DownloadPriority.HIGH,
)
assert item.id is not None
assert item.episode_id == episode.id
assert item.series_id == series.id
assert item.status == DownloadStatus.PENDING
assert item.priority == DownloadPriority.HIGH
@pytest.mark.asyncio
async def test_get_download_queue_item_by_episode(db_session):
"""Test retrieving download queue item by episode."""
async def test_get_pending_downloads(db_session):
"""Test retrieving pending downloads."""
# Create series
series = await AnimeSeriesService.create(
db_session,
@ -358,32 +362,29 @@ async def test_get_download_queue_item_by_episode(db_session):
folder="/path/test5",
)
# Create episode
episode = await EpisodeService.create(
# Add pending items
await DownloadQueueService.create(
db_session,
series_id=series.id,
season=1,
episode_number=1,
)
await db_session.commit()
# Add to queue
await DownloadQueueService.create(
db_session,
series_id=series.id,
episode_id=episode.id,
season=1,
episode_number=2,
)
await db_session.commit()
# Retrieve by episode
item = await DownloadQueueService.get_by_episode(db_session, episode.id)
assert item is not None
assert item.episode_id == episode.id
# Retrieve pending
pending = await DownloadQueueService.get_pending(db_session)
assert len(pending) == 2
@pytest.mark.asyncio
async def test_set_download_error(db_session):
"""Test setting error on download queue item."""
async def test_update_download_status(db_session):
"""Test updating download status."""
# Create series and queue item
series = await AnimeSeriesService.create(
db_session,
@ -392,34 +393,30 @@ async def test_set_download_error(db_session):
site="https://example.com",
folder="/path/test6",
)
episode = await EpisodeService.create(
item = await DownloadQueueService.create(
db_session,
series_id=series.id,
season=1,
episode_number=1,
)
item = await DownloadQueueService.create(
db_session,
series_id=series.id,
episode_id=episode.id,
)
await db_session.commit()
# Set error
updated = await DownloadQueueService.set_error(
# Update status
updated = await DownloadQueueService.update_status(
db_session,
item.id,
"Network error",
DownloadStatus.DOWNLOADING,
)
await db_session.commit()
assert updated is not None
assert updated.error_message == "Network error"
assert updated.status == DownloadStatus.DOWNLOADING
assert updated.started_at is not None
@pytest.mark.asyncio
async def test_delete_download_queue_item_by_episode(db_session):
"""Test deleting download queue item by episode."""
async def test_update_download_progress(db_session):
"""Test updating download progress."""
# Create series and queue item
series = await AnimeSeriesService.create(
db_session,
@ -428,31 +425,109 @@ async def test_delete_download_queue_item_by_episode(db_session):
site="https://example.com",
folder="/path/test7",
)
episode = await EpisodeService.create(
item = await DownloadQueueService.create(
db_session,
series_id=series.id,
season=1,
episode_number=1,
)
await DownloadQueueService.create(
await db_session.commit()
# Update progress
updated = await DownloadQueueService.update_progress(
db_session,
item.id,
progress_percent=50.0,
downloaded_bytes=500000,
total_bytes=1000000,
download_speed=50000.0,
)
await db_session.commit()
assert updated is not None
assert updated.progress_percent == 50.0
assert updated.downloaded_bytes == 500000
assert updated.total_bytes == 1000000
@pytest.mark.asyncio
async def test_clear_completed_downloads(db_session):
"""Test clearing completed downloads."""
# Create series and completed items
series = await AnimeSeriesService.create(
db_session,
key="test-series-8",
name="Test Series 8",
site="https://example.com",
folder="/path/test8",
)
item1 = await DownloadQueueService.create(
db_session,
series_id=series.id,
episode_id=episode.id,
season=1,
episode_number=1,
)
await db_session.commit()
# Delete by episode
deleted = await DownloadQueueService.delete_by_episode(
item2 = await DownloadQueueService.create(
db_session,
episode.id,
series_id=series.id,
season=1,
episode_number=2,
)
# Mark items as completed
await DownloadQueueService.update_status(
db_session,
item1.id,
DownloadStatus.COMPLETED,
)
await DownloadQueueService.update_status(
db_session,
item2.id,
DownloadStatus.COMPLETED,
)
await db_session.commit()
assert deleted is True
# Clear completed
count = await DownloadQueueService.clear_completed(db_session)
await db_session.commit()
# Verify deleted
item = await DownloadQueueService.get_by_episode(db_session, episode.id)
assert item is None
assert count == 2
@pytest.mark.asyncio
async def test_retry_failed_downloads(db_session):
"""Test retrying failed downloads."""
# Create series and failed item
series = await AnimeSeriesService.create(
db_session,
key="test-series-9",
name="Test Series 9",
site="https://example.com",
folder="/path/test9",
)
item = await DownloadQueueService.create(
db_session,
series_id=series.id,
season=1,
episode_number=1,
)
# Mark as failed
await DownloadQueueService.update_status(
db_session,
item.id,
DownloadStatus.FAILED,
error_message="Network error",
)
await db_session.commit()
# Retry
retried = await DownloadQueueService.retry_failed(db_session)
await db_session.commit()
assert len(retried) == 1
assert retried[0].status == DownloadStatus.PENDING
assert retried[0].error_message is None
# ============================================================================

View File

@ -102,20 +102,27 @@ async def anime_service(mock_series_app, progress_service):
@pytest.fixture
async def download_service(anime_service, progress_service):
"""Create a DownloadService with mock repository for testing."""
from tests.unit.test_download_service import MockQueueRepository
"""Create a DownloadService with dependencies."""
import os
persistence_path = "/tmp/test_download_progress_queue.json"
mock_repo = MockQueueRepository()
# Remove any existing queue file
if os.path.exists(persistence_path):
os.remove(persistence_path)
service = DownloadService(
anime_service=anime_service,
progress_service=progress_service,
queue_repository=mock_repo,
persistence_path=persistence_path,
)
yield service, progress_service
await service.stop()
# Clean up after test
if os.path.exists(persistence_path):
os.remove(persistence_path)
class TestDownloadProgressWebSocket:

View File

@ -1,13 +1,14 @@
"""Unit tests for the download queue service.
Tests cover queue management, manual download control, database persistence,
Tests cover queue management, manual download control, persistence,
and error scenarios for the simplified download service.
"""
from __future__ import annotations
import asyncio
import json
from datetime import datetime, timezone
from typing import Dict, List, Optional
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock
import pytest
@ -22,58 +23,6 @@ from src.server.services.anime_service import AnimeService
from src.server.services.download_service import DownloadService, DownloadServiceError
class MockQueueRepository:
"""Mock implementation of QueueRepository for testing.
This provides in-memory storage that mimics the simplified database
repository behavior without requiring actual database connections.
Note: The repository is simplified - status, priority, progress are
now managed in-memory by DownloadService, not stored in database.
"""
def __init__(self):
"""Initialize mock repository with in-memory storage."""
self._items: Dict[str, DownloadItem] = {}
async def save_item(self, item: DownloadItem) -> DownloadItem:
"""Save item to in-memory storage."""
self._items[item.id] = item
return item
async def get_item(self, item_id: str) -> Optional[DownloadItem]:
"""Get item by ID from in-memory storage."""
return self._items.get(item_id)
async def get_all_items(self) -> List[DownloadItem]:
"""Get all items in storage."""
return list(self._items.values())
async def set_error(
self,
item_id: str,
error: str,
) -> bool:
"""Set error message on an item."""
if item_id not in self._items:
return False
self._items[item_id].error = error
return True
async def delete_item(self, item_id: str) -> bool:
"""Delete item from storage."""
if item_id in self._items:
del self._items[item_id]
return True
return False
async def clear_all(self) -> int:
"""Clear all items."""
count = len(self._items)
self._items.clear()
return count
@pytest.fixture
def mock_anime_service():
"""Create a mock AnimeService."""
@ -83,18 +32,18 @@ def mock_anime_service():
@pytest.fixture
def mock_queue_repository():
"""Create a mock QueueRepository for testing."""
return MockQueueRepository()
def temp_persistence_path(tmp_path):
"""Create a temporary persistence path."""
return str(tmp_path / "test_queue.json")
@pytest.fixture
def download_service(mock_anime_service, mock_queue_repository):
def download_service(mock_anime_service, temp_persistence_path):
"""Create a DownloadService instance for testing."""
return DownloadService(
anime_service=mock_anime_service,
queue_repository=mock_queue_repository,
max_retries=3,
persistence_path=temp_persistence_path,
)
@ -102,12 +51,12 @@ class TestDownloadServiceInitialization:
"""Test download service initialization."""
def test_initialization_creates_queues(
self, mock_anime_service, mock_queue_repository
self, mock_anime_service, temp_persistence_path
):
"""Test that initialization creates empty queues."""
service = DownloadService(
anime_service=mock_anime_service,
queue_repository=mock_queue_repository,
persistence_path=temp_persistence_path,
)
assert len(service._pending_queue) == 0
@ -116,30 +65,45 @@ class TestDownloadServiceInitialization:
assert len(service._failed_items) == 0
assert service._is_stopped is True
@pytest.mark.asyncio
async def test_initialization_loads_persisted_queue(
self, mock_anime_service, mock_queue_repository
def test_initialization_loads_persisted_queue(
self, mock_anime_service, temp_persistence_path
):
"""Test that initialization loads persisted queue from database."""
# Pre-populate the mock repository with a pending item
test_item = DownloadItem(
id="test-id-1",
serie_id="series-1",
serie_folder="test-series",
serie_name="Test Series",
episode=EpisodeIdentifier(season=1, episode=1),
status=DownloadStatus.PENDING,
priority=DownloadPriority.NORMAL,
added_at=datetime.now(timezone.utc),
)
await mock_queue_repository.save_item(test_item)
"""Test that initialization loads persisted queue state."""
# Create a persisted queue file
persistence_file = Path(temp_persistence_path)
persistence_file.parent.mkdir(parents=True, exist_ok=True)
test_data = {
"pending": [
{
"id": "test-id-1",
"serie_id": "series-1",
"serie_folder": "test-series", # Added missing field
"serie_name": "Test Series",
"episode": {"season": 1, "episode": 1, "title": None},
"status": "pending",
"priority": "NORMAL", # Must be uppercase
"added_at": datetime.now(timezone.utc).isoformat(),
"started_at": None,
"completed_at": None,
"progress": None,
"error": None,
"retry_count": 0,
"source_url": None,
}
],
"active": [],
"failed": [],
"timestamp": datetime.now(timezone.utc).isoformat(),
}
with open(persistence_file, "w", encoding="utf-8") as f:
json.dump(test_data, f)
# Create service and initialize from database
service = DownloadService(
anime_service=mock_anime_service,
queue_repository=mock_queue_repository,
persistence_path=temp_persistence_path,
)
await service.initialize()
assert len(service._pending_queue) == 1
assert service._pending_queue[0].id == "test-id-1"
@ -427,13 +391,11 @@ class TestQueueControl:
class TestPersistence:
"""Test queue persistence functionality with database backend."""
"""Test queue persistence functionality."""
@pytest.mark.asyncio
async def test_queue_persistence(
self, download_service, mock_queue_repository
):
"""Test that queue state is persisted to database."""
async def test_queue_persistence(self, download_service):
"""Test that queue state is persisted to disk."""
await download_service.add_to_queue(
serie_id="series-1",
serie_folder="series",
@ -441,20 +403,26 @@ class TestPersistence:
episodes=[EpisodeIdentifier(season=1, episode=1)],
)
# Item should be saved in mock repository
all_items = await mock_queue_repository.get_all_items()
assert len(all_items) == 1
assert all_items[0].serie_id == "series-1"
# Persistence file should exist
persistence_path = Path(download_service._persistence_path)
assert persistence_path.exists()
# Check file contents
with open(persistence_path, "r") as f:
data = json.load(f)
assert len(data["pending"]) == 1
assert data["pending"][0]["serie_id"] == "series-1"
@pytest.mark.asyncio
async def test_queue_recovery_after_restart(
self, mock_anime_service, mock_queue_repository
self, mock_anime_service, temp_persistence_path
):
"""Test that queue is recovered after service restart."""
# Create and populate first service
service1 = DownloadService(
anime_service=mock_anime_service,
queue_repository=mock_queue_repository,
persistence_path=temp_persistence_path,
)
await service1.add_to_queue(
@ -467,13 +435,11 @@ class TestPersistence:
],
)
# Create new service with same repository (simulating restart)
# Create new service with same persistence path
service2 = DownloadService(
anime_service=mock_anime_service,
queue_repository=mock_queue_repository,
persistence_path=temp_persistence_path,
)
# Initialize to load from database to recover state
await service2.initialize()
# Should recover pending items
assert len(service2._pending_queue) == 2
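
These persistence tests pin down the on-disk JSON layout. For orientation, the write side could look like the following sketch; the helper name _save_queue_state and the to_dict() serializer on DownloadItem are assumptions, and only the file layout is taken from the assertions above.

import json
from datetime import datetime, timezone
from pathlib import Path

def _save_queue_state(service) -> None:
    """Write queue state in the layout the tests assert (sketch)."""
    state = {
        "pending": [item.to_dict() for item in service._pending_queue],  # to_dict() assumed
        "active": [],
        "failed": [],
        "timestamp": datetime.now(timezone.utc).isoformat(),
    }
    path = Path(service._persistence_path)
    path.parent.mkdir(parents=True, exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        json.dump(state, f)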

View File

@@ -0,0 +1,419 @@
"""
Tests for database migration system.
This module tests the migration runner, validator, and base classes.
"""
from datetime import datetime
from pathlib import Path
from unittest.mock import AsyncMock, Mock, patch
import pytest
from src.server.database.migrations.base import (
Migration,
MigrationError,
MigrationHistory,
)
from src.server.database.migrations.runner import MigrationRunner
from src.server.database.migrations.validator import MigrationValidator
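
The tests below subclass Migration only with trivial upgrade/downgrade bodies. For orientation, a concrete migration of the same shape might look like this sketch; it assumes a SQLAlchemy-style async session, and the table index and SQL are purely illustrative, not part of the codebase.

from sqlalchemy import text

class ExampleAddSeasonIndex(Migration):
    """Illustrative migration; instantiate as
    ExampleAddSeasonIndex(version="20250124_001", description="Add season index")."""

    async def upgrade(self, session):
        # Apply the schema change; the runner commits or rolls back around it.
        await session.execute(text(
            "CREATE INDEX IF NOT EXISTS ix_episodes_season ON episodes (season)"
        ))

    async def downgrade(self, session):
        # Reverse the change so the migration can be rolled back.
        await session.execute(text("DROP INDEX IF EXISTS ix_episodes_season"))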
class TestMigration:
"""Tests for base Migration class."""
def test_migration_initialization(self):
"""Test migration can be initialized with basic attributes."""
class TestMig(Migration):
async def upgrade(self, session):
return None
async def downgrade(self, session):
return None
mig = TestMig(
version="20250124_001", description="Test migration"
)
assert mig.version == "20250124_001"
assert mig.description == "Test migration"
assert isinstance(mig.created_at, datetime)
def test_migration_equality(self):
"""Test migrations are equal based on version."""
class TestMig1(Migration):
async def upgrade(self, session):
return None
async def downgrade(self, session):
return None
class TestMig2(Migration):
async def upgrade(self, session):
return None
async def downgrade(self, session):
return None
mig1 = TestMig1(version="20250124_001", description="Test 1")
mig2 = TestMig2(version="20250124_001", description="Test 2")
mig3 = TestMig1(version="20250124_002", description="Test 3")
assert mig1 == mig2
assert mig1 != mig3
assert hash(mig1) == hash(mig2)
assert hash(mig1) != hash(mig3)
def test_migration_repr(self):
"""Test migration string representation."""
class TestMig(Migration):
async def upgrade(self, session):
return None
async def downgrade(self, session):
return None
mig = TestMig(
version="20250124_001", description="Test migration"
)
assert "20250124_001" in repr(mig)
assert "Test migration" in repr(mig)
class TestMigrationHistory:
"""Tests for MigrationHistory class."""
def test_history_initialization(self):
"""Test migration history record can be created."""
history = MigrationHistory(
version="20250124_001",
description="Test migration",
applied_at=datetime.now(),
execution_time_ms=1500,
success=True,
)
assert history.version == "20250124_001"
assert history.description == "Test migration"
assert history.execution_time_ms == 1500
assert history.success is True
assert history.error_message is None
def test_history_with_error(self):
"""Test migration history with error message."""
history = MigrationHistory(
version="20250124_001",
description="Failed migration",
applied_at=datetime.now(),
execution_time_ms=500,
success=False,
error_message="Test error",
)
assert history.success is False
assert history.error_message == "Test error"
class TestMigrationValidator:
"""Tests for MigrationValidator class."""
def test_validator_initialization(self):
"""Test validator can be initialized."""
validator = MigrationValidator()
assert isinstance(validator.errors, list)
assert isinstance(validator.warnings, list)
assert len(validator.errors) == 0
def test_validate_version_format_valid(self):
"""Test validation of valid version formats."""
validator = MigrationValidator()
assert validator._validate_version_format("20250124_001")
assert validator._validate_version_format("20231201_099")
assert validator._validate_version_format("20250124_001_description")
def test_validate_version_format_invalid(self):
"""Test validation of invalid version formats."""
validator = MigrationValidator()
assert not validator._validate_version_format("")
assert not validator._validate_version_format("20250124")
assert not validator._validate_version_format("invalid_001")
assert not validator._validate_version_format("202501_001")
def test_validate_migration_valid(self):
"""Test validation of valid migration."""
class TestMig(Migration):
async def upgrade(self, session):
return None
async def downgrade(self, session):
return None
mig = TestMig(
version="20250124_001",
description="Valid test migration",
)
validator = MigrationValidator()
assert validator.validate_migration(mig) is True
assert len(validator.errors) == 0
def test_validate_migration_invalid_version(self):
"""Test validation fails for invalid version."""
class TestMig(Migration):
async def upgrade(self, session):
return None
async def downgrade(self, session):
return None
mig = TestMig(
version="invalid",
description="Valid description",
)
validator = MigrationValidator()
assert validator.validate_migration(mig) is False
assert len(validator.errors) > 0
def test_validate_migration_missing_description(self):
"""Test validation fails for missing description."""
class TestMig(Migration):
async def upgrade(self, session):
return None
async def downgrade(self, session):
return None
mig = TestMig(version="20250124_001", description="")
validator = MigrationValidator()
assert validator.validate_migration(mig) is False
assert any("description" in e.lower() for e in validator.errors)
def test_validate_migrations_duplicate_version(self):
"""Test validation detects duplicate versions."""
class TestMig1(Migration):
async def upgrade(self, session):
return None
async def downgrade(self, session):
return None
class TestMig2(Migration):
async def upgrade(self, session):
return None
async def downgrade(self, session):
return None
mig1 = TestMig1(version="20250124_001", description="First")
mig2 = TestMig2(version="20250124_001", description="Duplicate")
validator = MigrationValidator()
assert validator.validate_migrations([mig1, mig2]) is False
assert any("duplicate" in e.lower() for e in validator.errors)
def test_check_migration_conflicts(self):
"""Test detection of migration conflicts."""
class TestMig(Migration):
async def upgrade(self, session):
return None
async def downgrade(self, session):
return None
old_mig = TestMig(version="20250101_001", description="Old")
new_mig = TestMig(version="20250124_001", description="New")
validator = MigrationValidator()
# No conflict when pending is newer
conflict = validator.check_migration_conflicts(
[new_mig], ["20250101_001"]
)
assert conflict is None
# Conflict when pending is older
conflict = validator.check_migration_conflicts(
[old_mig], ["20250124_001"]
)
assert conflict is not None
assert "older" in conflict.lower()
def test_get_validation_report(self):
"""Test validation report generation."""
validator = MigrationValidator()
validator.errors.append("Test error")
validator.warnings.append("Test warning")
report = validator.get_validation_report()
assert "Test error" in report
assert "Test warning" in report
assert "Validation Errors:" in report
assert "Validation Warnings:" in report
def test_raise_if_invalid(self):
"""Test exception raising on validation failure."""
validator = MigrationValidator()
validator.errors.append("Test error")
with pytest.raises(MigrationError):
validator.raise_if_invalid()
@pytest.mark.asyncio
class TestMigrationRunner:
"""Tests for MigrationRunner class."""
@pytest.fixture
def mock_session(self):
"""Create mock database session."""
session = AsyncMock()
session.execute = AsyncMock()
session.commit = AsyncMock()
session.rollback = AsyncMock()
return session
@pytest.fixture
def migrations_dir(self, tmp_path):
"""Create temporary migrations directory."""
return tmp_path / "migrations"
async def test_runner_initialization(
self, migrations_dir, mock_session
):
"""Test migration runner can be initialized."""
runner = MigrationRunner(migrations_dir, mock_session)
assert runner.migrations_dir == migrations_dir
assert runner.session == mock_session
assert isinstance(runner._migrations, list)
async def test_initialize_creates_table(
self, migrations_dir, mock_session
):
"""Test initialization creates migration_history table."""
runner = MigrationRunner(migrations_dir, mock_session)
await runner.initialize()
mock_session.execute.assert_called()
mock_session.commit.assert_called()
async def test_load_migrations_empty_dir(
self, migrations_dir, mock_session
):
"""Test loading migrations from empty directory."""
runner = MigrationRunner(migrations_dir, mock_session)
runner.load_migrations()
assert len(runner._migrations) == 0
async def test_get_applied_migrations(
self, migrations_dir, mock_session
):
"""Test retrieving list of applied migrations."""
# Mock database response
mock_result = Mock()
mock_result.fetchall.return_value = [
("20250124_001",),
("20250124_002",),
]
mock_session.execute.return_value = mock_result
runner = MigrationRunner(migrations_dir, mock_session)
applied = await runner.get_applied_migrations()
assert len(applied) == 2
assert "20250124_001" in applied
assert "20250124_002" in applied
async def test_apply_migration_success(
self, migrations_dir, mock_session
):
"""Test successful migration application."""
class TestMig(Migration):
async def upgrade(self, session):
return None
async def downgrade(self, session):
return None
mig = TestMig(version="20250124_001", description="Test")
runner = MigrationRunner(migrations_dir, mock_session)
await runner.apply_migration(mig)
mock_session.commit.assert_called()
async def test_apply_migration_failure(
self, migrations_dir, mock_session
):
"""Test migration application handles failures."""
class FailingMig(Migration):
async def upgrade(self, session):
raise Exception("Test failure")
async def downgrade(self, session):
return None
mig = FailingMig(version="20250124_001", description="Failing")
runner = MigrationRunner(migrations_dir, mock_session)
with pytest.raises(MigrationError):
await runner.apply_migration(mig)
mock_session.rollback.assert_called()
async def test_get_pending_migrations(
self, migrations_dir, mock_session
):
"""Test retrieving pending migrations."""
class TestMig1(Migration):
async def upgrade(self, session):
return None
async def downgrade(self, session):
return None
class TestMig2(Migration):
async def upgrade(self, session):
return None
async def downgrade(self, session):
return None
mig1 = TestMig1(version="20250124_001", description="Applied")
mig2 = TestMig2(version="20250124_002", description="Pending")
runner = MigrationRunner(migrations_dir, mock_session)
runner._migrations = [mig1, mig2]
# Mock only mig1 as applied
mock_result = Mock()
mock_result.fetchall.return_value = [("20250124_001",)]
mock_session.execute.return_value = mock_result
pending = await runner.get_pending_migrations()
assert len(pending) == 1
assert pending[0].version == "20250124_002"
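
For context, the pieces exercised above compose into a startup routine roughly like this sketch. It uses only methods the tests touch; reading runner._migrations directly mirrors what the tests themselves do.

async def run_pending_migrations(migrations_dir, session) -> int:
    """Apply all unapplied migrations in version order (sketch)."""
    runner = MigrationRunner(migrations_dir, session)
    await runner.initialize()   # ensure the migration_history table exists
    runner.load_migrations()    # discover migrations on disk

    validator = MigrationValidator()
    if not validator.validate_migrations(runner._migrations):
        validator.raise_if_invalid()

    pending = await runner.get_pending_migrations()
    for migration in pending:
        await runner.apply_migration(migration)
    return len(pending)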


@@ -173,8 +173,6 @@ class TestSerieProperties:
def test_serie_save_and_load_from_file(self):
"""Test saving and loading Serie from file."""
serie = Serie(
key="test-key",
name="Test Series",
@@ -192,15 +190,11 @@ class TestSerieProperties:
temp_filename = f.name
try:
            # Save to file
            serie.save_to_file(temp_filename)
            # Load from file
            loaded_serie = Serie.load_from_file(temp_filename)
# Verify all properties match
assert loaded_serie.key == serie.key
@@ -248,75 +242,3 @@ class TestSerieDocumentation:
assert Serie.folder.fget.__doc__ is not None
assert "metadata" in Serie.folder.fget.__doc__.lower()
assert "not used for lookups" in Serie.folder.fget.__doc__.lower()
class TestSerieDeprecationWarnings:
"""Test deprecation warnings for file-based methods."""
def test_save_to_file_raises_deprecation_warning(self):
"""Test save_to_file() raises deprecation warning."""
import warnings
serie = Serie(
key="test-key",
name="Test Series",
site="https://example.com",
folder="Test Folder",
episodeDict={1: [1, 2, 3]}
)
with tempfile.NamedTemporaryFile(
mode='w', suffix='.json', delete=False
) as temp_file:
temp_filename = temp_file.name
try:
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
serie.save_to_file(temp_filename)
# Check deprecation warning was raised
assert len(w) == 1
assert issubclass(w[0].category, DeprecationWarning)
assert "deprecated" in str(w[0].message).lower()
assert "save_to_file" in str(w[0].message)
finally:
if os.path.exists(temp_filename):
os.remove(temp_filename)
def test_load_from_file_raises_deprecation_warning(self):
"""Test load_from_file() raises deprecation warning."""
import warnings
serie = Serie(
key="test-key",
name="Test Series",
site="https://example.com",
folder="Test Folder",
episodeDict={1: [1, 2, 3]}
)
with tempfile.NamedTemporaryFile(
mode='w', suffix='.json', delete=False
) as temp_file:
temp_filename = temp_file.name
try:
# Save first (suppress warning for this)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
serie.save_to_file(temp_filename)
# Now test loading
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
Serie.load_from_file(temp_filename)
# Check deprecation warning was raised
assert len(w) == 1
assert issubclass(w[0].category, DeprecationWarning)
assert "deprecated" in str(w[0].message).lower()
assert "load_from_file" in str(w[0].message)
finally:
if os.path.exists(temp_filename):
os.remove(temp_filename)
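
Both tests assert the same warning shape, so the emitting side is presumably a single warnings.warn() call at the top of each deprecated method, along these lines (the message wording and the to_dict() serializer are assumptions):

import json
import warnings

def save_to_file(self, filename: str) -> None:
    warnings.warn(
        "save_to_file() is deprecated; use the database-backed API instead.",
        DeprecationWarning,
        stacklevel=2,  # point the warning at the caller
    )
    with open(filename, "w", encoding="utf-8") as f:
        json.dump(self.to_dict(), f)  # serializer name assumed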


@@ -2,8 +2,6 @@
import os
import tempfile
import warnings
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
@@ -30,41 +28,6 @@ def sample_serie():
)
@pytest.fixture
def mock_db_session():
"""Create a mock async database session."""
session = AsyncMock()
return session
@pytest.fixture
def mock_anime_series():
"""Create a mock AnimeSeries database model."""
anime_series = MagicMock()
anime_series.key = "test-series"
anime_series.name = "Test Series"
anime_series.site = "https://aniworld.to/anime/stream/test-series"
anime_series.folder = "Test Series (2020)"
# Mock episodes relationship
mock_ep1 = MagicMock()
mock_ep1.season = 1
mock_ep1.episode_number = 1
mock_ep2 = MagicMock()
mock_ep2.season = 1
mock_ep2.episode_number = 2
mock_ep3 = MagicMock()
mock_ep3.season = 1
mock_ep3.episode_number = 3
mock_ep4 = MagicMock()
mock_ep4.season = 2
mock_ep4.episode_number = 1
mock_ep5 = MagicMock()
mock_ep5.season = 2
mock_ep5.episode_number = 2
anime_series.episodes = [mock_ep1, mock_ep2, mock_ep3, mock_ep4, mock_ep5]
return anime_series
class TestSerieListKeyBasedStorage:
"""Test SerieList uses key for internal storage."""
@@ -77,9 +40,7 @@ class TestSerieListKeyBasedStorage:
def test_add_stores_by_key(self, temp_directory, sample_serie):
"""Test add() stores series by key."""
serie_list = SerieList(temp_directory)
        serie_list.add(sample_serie)
# Verify stored by key, not folder
assert sample_serie.key in serie_list.keyDict
@@ -88,9 +49,7 @@ class TestSerieListKeyBasedStorage:
def test_contains_checks_by_key(self, temp_directory, sample_serie):
"""Test contains() checks by key."""
serie_list = SerieList(temp_directory)
        serie_list.add(sample_serie)
assert serie_list.contains(sample_serie.key)
assert not serie_list.contains("nonexistent-key")
@@ -101,13 +60,11 @@ class TestSerieListKeyBasedStorage:
"""Test add() prevents duplicates based on key."""
serie_list = SerieList(temp_directory)
        # Add same serie twice
        serie_list.add(sample_serie)
        initial_count = len(serie_list.keyDict)
        serie_list.add(sample_serie)
# Should still have only one entry
assert len(serie_list.keyDict) == initial_count
@@ -118,9 +75,7 @@ class TestSerieListKeyBasedStorage:
):
"""Test get_by_key() retrieves series correctly."""
serie_list = SerieList(temp_directory)
        serie_list.add(sample_serie)
result = serie_list.get_by_key(sample_serie.key)
assert result is not None
@@ -139,11 +94,9 @@ class TestSerieListKeyBasedStorage:
):
"""Test get_by_folder() provides backward compatibility."""
serie_list = SerieList(temp_directory)
        serie_list.add(sample_serie)
        result = serie_list.get_by_folder(sample_serie.folder)
assert result is not None
assert result.key == sample_serie.key
assert result.folder == sample_serie.folder
@@ -152,14 +105,13 @@ class TestSerieListKeyBasedStorage:
"""Test get_by_folder() returns None for nonexistent folder."""
serie_list = SerieList(temp_directory)
        result = serie_list.get_by_folder("Nonexistent Folder")
assert result is None
def test_get_all_returns_all_series(self, temp_directory, sample_serie):
"""Test get_all() returns all series from keyDict."""
serie_list = SerieList(temp_directory)
serie_list.add(sample_serie)
serie2 = Serie(
key="naruto",
@@ -168,11 +120,7 @@ class TestSerieListKeyBasedStorage:
folder="Naruto (2002)",
episodeDict={1: [1, 2]}
)
        serie_list.add(serie2)
all_series = serie_list.get_all()
assert len(all_series) == 2
@@ -203,10 +151,8 @@ class TestSerieListKeyBasedStorage:
episodeDict={}
)
        serie_list.add(serie_with_episodes)
        serie_list.add(serie_without_episodes)
missing = serie_list.get_missing_episodes()
assert len(missing) == 1
@@ -238,10 +184,8 @@ class TestSerieListPublicAPI:
"""Test that all public methods work correctly after refactoring."""
serie_list = SerieList(temp_directory)
        # Test add
        serie_list.add(sample_serie)
# Test contains
assert serie_list.contains(sample_serie.key)
@@ -256,296 +200,4 @@ class TestSerieListPublicAPI:
# Test new helper methods
assert serie_list.get_by_key(sample_serie.key) is not None
        assert serie_list.get_by_folder(sample_serie.folder) is not None
class TestSerieListDatabaseMode:
"""Test SerieList database-backed storage functionality."""
def test_init_with_db_session_skips_file_load(
self, temp_directory, mock_db_session
):
"""Test initialization with db_session skips file-based loading."""
# Create a data file that should NOT be loaded
folder_path = os.path.join(temp_directory, "Test Folder")
os.makedirs(folder_path, exist_ok=True)
data_path = os.path.join(folder_path, "data")
serie = Serie(
key="test-key",
name="Test",
site="https://test.com",
folder="Test Folder",
episodeDict={}
)
serie.save_to_file(data_path)
# Initialize with db_session - should skip file loading
serie_list = SerieList(
temp_directory,
db_session=mock_db_session
)
# Should have empty keyDict (file loading skipped)
assert len(serie_list.keyDict) == 0
def test_init_with_skip_load(self, temp_directory):
"""Test initialization with skip_load=True skips loading."""
serie_list = SerieList(temp_directory, skip_load=True)
assert len(serie_list.keyDict) == 0
def test_convert_from_db_basic(self, mock_anime_series):
"""Test _convert_from_db converts AnimeSeries to Serie correctly."""
serie = SerieList._convert_from_db(mock_anime_series)
assert serie.key == mock_anime_series.key
assert serie.name == mock_anime_series.name
assert serie.site == mock_anime_series.site
assert serie.folder == mock_anime_series.folder
# Season keys should be built from episodes relationship
assert 1 in serie.episodeDict
assert 2 in serie.episodeDict
assert serie.episodeDict[1] == [1, 2, 3]
assert serie.episodeDict[2] == [1, 2]
def test_convert_from_db_empty_episodes(self, mock_anime_series):
"""Test _convert_from_db handles empty episodes."""
mock_anime_series.episodes = []
serie = SerieList._convert_from_db(mock_anime_series)
assert serie.episodeDict == {}
def test_convert_from_db_none_episodes(self, mock_anime_series):
"""Test _convert_from_db handles None episodes."""
mock_anime_series.episodes = None
serie = SerieList._convert_from_db(mock_anime_series)
assert serie.episodeDict == {}
def test_convert_to_db_dict(self, sample_serie):
"""Test _convert_to_db_dict creates correct dictionary."""
result = SerieList._convert_to_db_dict(sample_serie)
assert result["key"] == sample_serie.key
assert result["name"] == sample_serie.name
assert result["site"] == sample_serie.site
assert result["folder"] == sample_serie.folder
# episode_dict should not be in result anymore
assert "episode_dict" not in result
def test_convert_to_db_dict_empty_episode_dict(self):
"""Test _convert_to_db_dict handles empty episode_dict."""
serie = Serie(
key="test",
name="Test",
site="https://test.com",
folder="Test",
episodeDict={}
)
result = SerieList._convert_to_db_dict(serie)
# episode_dict should not be in result anymore
assert "episode_dict" not in result
class TestSerieListDatabaseAsync:
"""Test async database methods of SerieList."""
@pytest.mark.asyncio
async def test_load_series_from_db(
self, temp_directory, mock_db_session, mock_anime_series
):
"""Test load_series_from_db loads from database."""
# Setup mock to return list of anime series
with patch(
'src.server.database.service.AnimeSeriesService'
) as mock_service:
mock_service.get_all = AsyncMock(return_value=[mock_anime_series])
serie_list = SerieList(temp_directory, skip_load=True)
count = await serie_list.load_series_from_db(mock_db_session)
assert count == 1
assert mock_anime_series.key in serie_list.keyDict
@pytest.mark.asyncio
async def test_load_series_from_db_clears_existing(
self, temp_directory, mock_db_session, mock_anime_series
):
"""Test load_series_from_db clears existing data."""
serie_list = SerieList(temp_directory, skip_load=True)
# Add an existing entry
serie_list.keyDict["old-key"] = MagicMock()
with patch(
'src.server.database.service.AnimeSeriesService'
) as mock_service:
mock_service.get_all = AsyncMock(return_value=[mock_anime_series])
await serie_list.load_series_from_db(mock_db_session)
# Old entry should be cleared
assert "old-key" not in serie_list.keyDict
assert mock_anime_series.key in serie_list.keyDict
@pytest.mark.asyncio
async def test_add_to_db_creates_new_series(
self, temp_directory, mock_db_session, sample_serie
):
"""Test add_to_db creates new series in database."""
with patch(
'src.server.database.service.AnimeSeriesService'
) as mock_service:
mock_service.get_by_key = AsyncMock(return_value=None)
mock_created = MagicMock()
mock_created.id = 1
mock_service.create = AsyncMock(return_value=mock_created)
serie_list = SerieList(temp_directory, skip_load=True)
result = await serie_list.add_to_db(sample_serie, mock_db_session)
assert result is mock_created
mock_service.create.assert_called_once()
# Should also add to in-memory collection
assert sample_serie.key in serie_list.keyDict
@pytest.mark.asyncio
async def test_add_to_db_skips_existing(
self, temp_directory, mock_db_session, sample_serie
):
"""Test add_to_db skips if series already exists."""
with patch(
'src.server.database.service.AnimeSeriesService'
) as mock_service:
existing = MagicMock()
mock_service.get_by_key = AsyncMock(return_value=existing)
serie_list = SerieList(temp_directory, skip_load=True)
result = await serie_list.add_to_db(sample_serie, mock_db_session)
assert result is None
mock_service.create.assert_not_called()
@pytest.mark.asyncio
async def test_contains_in_db_returns_true_when_exists(
self, temp_directory, mock_db_session
):
"""Test contains_in_db returns True when series exists."""
with patch(
'src.server.database.service.AnimeSeriesService'
) as mock_service:
mock_service.get_by_key = AsyncMock(return_value=MagicMock())
serie_list = SerieList(temp_directory, skip_load=True)
result = await serie_list.contains_in_db(
"test-key", mock_db_session
)
assert result is True
@pytest.mark.asyncio
async def test_contains_in_db_returns_false_when_not_exists(
self, temp_directory, mock_db_session
):
"""Test contains_in_db returns False when series doesn't exist."""
with patch(
'src.server.database.service.AnimeSeriesService'
) as mock_service:
mock_service.get_by_key = AsyncMock(return_value=None)
serie_list = SerieList(temp_directory, skip_load=True)
result = await serie_list.contains_in_db(
"nonexistent", mock_db_session
)
assert result is False
class TestSerieListDeprecationWarnings:
"""Test deprecation warnings are raised for file-based methods."""
def test_add_raises_deprecation_warning(
self, temp_directory, sample_serie
):
"""Test add() raises deprecation warning."""
serie_list = SerieList(temp_directory, skip_load=True)
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
serie_list.add(sample_serie)
# Check at least one deprecation warning was raised for add()
# (Note: save_to_file also raises a warning, so we may get 2)
deprecation_warnings = [
warning for warning in w
if issubclass(warning.category, DeprecationWarning)
]
assert len(deprecation_warnings) >= 1
# Check that one of them is from add()
add_warnings = [
warning for warning in deprecation_warnings
if "add_to_db()" in str(warning.message)
]
assert len(add_warnings) == 1
def test_get_by_folder_raises_deprecation_warning(
self, temp_directory, sample_serie
):
"""Test get_by_folder() raises deprecation warning."""
serie_list = SerieList(temp_directory, skip_load=True)
serie_list.keyDict[sample_serie.key] = sample_serie
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
serie_list.get_by_folder(sample_serie.folder)
# Check deprecation warning was raised
assert len(w) == 1
assert issubclass(w[0].category, DeprecationWarning)
assert "get_by_key()" in str(w[0].message)
class TestSerieListBackwardCompatibility:
"""Test backward compatibility of file-based operations."""
def test_file_based_mode_still_works(
self, temp_directory, sample_serie
):
"""Test file-based mode still works without db_session."""
serie_list = SerieList(temp_directory)
# Add should still work (with deprecation warning)
with warnings.catch_warnings():
warnings.simplefilter("ignore", DeprecationWarning)
serie_list.add(sample_serie)
# File should be created
data_path = os.path.join(
temp_directory, sample_serie.folder, "data"
)
assert os.path.isfile(data_path)
# Series should be in memory
assert serie_list.contains(sample_serie.key)
def test_load_from_file_still_works(
self, temp_directory, sample_serie
):
"""Test loading from files still works."""
# Create directory and save file
folder_path = os.path.join(temp_directory, sample_serie.folder)
os.makedirs(folder_path, exist_ok=True)
data_path = os.path.join(folder_path, "data")
sample_serie.save_to_file(data_path)
# New SerieList should load it
serie_list = SerieList(temp_directory)
assert serie_list.contains(sample_serie.key)
loaded = serie_list.get_by_key(sample_serie.key)
assert loaded.name == sample_serie.name
assert serie_list.get_by_folder(sample_serie.folder) is not None


@@ -1,471 +0,0 @@
"""Tests for SerieScanner class - database and file-based operations."""
import os
import tempfile
import warnings
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from src.core.entities.series import Serie
from src.core.SerieScanner import SerieScanner
@pytest.fixture
def temp_directory():
"""Create a temporary directory with subdirectories for testing."""
with tempfile.TemporaryDirectory() as tmpdir:
# Create an anime folder with an mp4 file
anime_folder = os.path.join(tmpdir, "Attack on Titan (2013)")
os.makedirs(anime_folder, exist_ok=True)
# Create a dummy mp4 file
mp4_path = os.path.join(
anime_folder, "Attack on Titan - S01E001 - (German Dub).mp4"
)
with open(mp4_path, "w") as f:
f.write("dummy mp4")
yield tmpdir
@pytest.fixture
def mock_loader():
"""Create a mock Loader instance."""
loader = MagicMock()
loader.get_season_episode_count = MagicMock(return_value={1: 25})
loader.is_language = MagicMock(return_value=True)
return loader
@pytest.fixture
def mock_db_session():
"""Create a mock async database session."""
session = AsyncMock()
return session
@pytest.fixture
def sample_serie():
"""Create a sample Serie for testing."""
return Serie(
key="attack-on-titan",
name="Attack on Titan",
site="aniworld.to",
folder="Attack on Titan (2013)",
episodeDict={1: [2, 3, 4]}
)
class TestSerieScannerInitialization:
"""Test SerieScanner initialization."""
def test_init_success(self, temp_directory, mock_loader):
"""Test successful initialization."""
scanner = SerieScanner(temp_directory, mock_loader)
assert scanner.directory == os.path.abspath(temp_directory)
assert scanner.loader == mock_loader
assert scanner.keyDict == {}
def test_init_with_db_session(
self, temp_directory, mock_loader, mock_db_session
):
"""Test initialization with database session."""
scanner = SerieScanner(
temp_directory,
mock_loader,
db_session=mock_db_session
)
assert scanner._db_session == mock_db_session
def test_init_empty_path_raises_error(self, mock_loader):
"""Test initialization with empty path raises ValueError."""
with pytest.raises(ValueError, match="empty"):
SerieScanner("", mock_loader)
def test_init_nonexistent_path_raises_error(self, mock_loader):
"""Test initialization with non-existent path raises ValueError."""
with pytest.raises(ValueError, match="does not exist"):
SerieScanner("/nonexistent/path", mock_loader)
class TestSerieScannerScanDeprecation:
"""Test scan() deprecation warning."""
def test_scan_raises_deprecation_warning(
self, temp_directory, mock_loader
):
"""Test that scan() raises a deprecation warning."""
scanner = SerieScanner(temp_directory, mock_loader)
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
# Mock the internal methods to avoid actual scanning
with patch.object(scanner, 'get_total_to_scan', return_value=0):
with patch.object(
scanner, '_SerieScanner__find_mp4_files',
return_value=iter([])
):
scanner.scan()
# Check deprecation warning was raised
assert len(w) >= 1
deprecation_warnings = [
warning for warning in w
if issubclass(warning.category, DeprecationWarning)
]
assert len(deprecation_warnings) >= 1
assert "scan_async()" in str(deprecation_warnings[0].message)
class TestSerieScannerAsyncScan:
"""Test async database scanning methods."""
@pytest.mark.asyncio
async def test_scan_async_saves_to_database(
self, temp_directory, mock_loader, mock_db_session, sample_serie
):
"""Test scan_async saves results to database."""
scanner = SerieScanner(temp_directory, mock_loader)
# Mock the internal methods
with patch.object(scanner, 'get_total_to_scan', return_value=1):
with patch.object(
scanner,
'_SerieScanner__find_mp4_files',
return_value=iter([
("Attack on Titan (2013)", ["S01E001.mp4"])
])
):
with patch.object(
scanner,
'_SerieScanner__read_data_from_file',
return_value=sample_serie
):
with patch.object(
scanner,
'_SerieScanner__get_missing_episodes_and_season',
return_value=({1: [2, 3]}, "aniworld.to")
):
with patch(
'src.server.database.service.AnimeSeriesService'
) as mock_service:
mock_service.get_by_key = AsyncMock(
return_value=None
)
mock_created = MagicMock()
mock_created.id = 1
mock_service.create = AsyncMock(
return_value=mock_created
)
await scanner.scan_async(mock_db_session)
# Verify database create was called
mock_service.create.assert_called_once()
@pytest.mark.asyncio
async def test_scan_async_updates_existing_series(
self, temp_directory, mock_loader, mock_db_session, sample_serie
):
"""Test scan_async updates existing series in database."""
scanner = SerieScanner(temp_directory, mock_loader)
# Mock existing series in database with different episodes
existing = MagicMock()
existing.id = 1
existing.folder = sample_serie.folder
# Mock episodes (different from sample_serie)
mock_existing_episodes = [
MagicMock(season=1, episode_number=5),
MagicMock(season=1, episode_number=6),
]
with patch.object(scanner, 'get_total_to_scan', return_value=1):
with patch.object(
scanner,
'_SerieScanner__find_mp4_files',
return_value=iter([
("Attack on Titan (2013)", ["S01E001.mp4"])
])
):
with patch.object(
scanner,
'_SerieScanner__read_data_from_file',
return_value=sample_serie
):
with patch.object(
scanner,
'_SerieScanner__get_missing_episodes_and_season',
return_value=({1: [2, 3]}, "aniworld.to")
):
with patch(
'src.server.database.service.AnimeSeriesService'
) as mock_service:
with patch(
'src.server.database.service.EpisodeService'
) as mock_ep_service:
mock_service.get_by_key = AsyncMock(
return_value=existing
)
mock_service.update = AsyncMock(
return_value=existing
)
mock_ep_service.get_by_series = AsyncMock(
return_value=mock_existing_episodes
)
mock_ep_service.create = AsyncMock()
await scanner.scan_async(mock_db_session)
# Verify episodes were created
assert mock_ep_service.create.called
@pytest.mark.asyncio
async def test_scan_async_handles_errors_gracefully(
self, temp_directory, mock_loader, mock_db_session
):
"""Test scan_async handles folder processing errors gracefully."""
scanner = SerieScanner(temp_directory, mock_loader)
with patch.object(scanner, 'get_total_to_scan', return_value=1):
with patch.object(
scanner,
'_SerieScanner__find_mp4_files',
return_value=iter([
("Error Folder", ["S01E001.mp4"])
])
):
with patch.object(
scanner,
'_SerieScanner__read_data_from_file',
side_effect=Exception("Test error")
):
# Should not raise, should continue
await scanner.scan_async(mock_db_session)
class TestSerieScannerDatabaseHelpers:
"""Test database helper methods."""
@pytest.mark.asyncio
async def test_save_serie_to_db_creates_new(
self, temp_directory, mock_loader, mock_db_session, sample_serie
):
"""Test _save_serie_to_db creates new series."""
scanner = SerieScanner(temp_directory, mock_loader)
with patch(
'src.server.database.service.AnimeSeriesService'
) as mock_service:
with patch(
'src.server.database.service.EpisodeService'
) as mock_ep_service:
mock_service.get_by_key = AsyncMock(return_value=None)
mock_created = MagicMock()
mock_created.id = 1
mock_service.create = AsyncMock(return_value=mock_created)
mock_ep_service.create = AsyncMock()
result = await scanner._save_serie_to_db(
sample_serie, mock_db_session
)
assert result is mock_created
mock_service.create.assert_called_once()
@pytest.mark.asyncio
async def test_save_serie_to_db_updates_existing(
self, temp_directory, mock_loader, mock_db_session, sample_serie
):
"""Test _save_serie_to_db updates existing series."""
scanner = SerieScanner(temp_directory, mock_loader)
existing = MagicMock()
existing.id = 1
existing.folder = sample_serie.folder
# Mock existing episodes (different from sample_serie)
mock_existing_episodes = [
MagicMock(season=1, episode_number=5),
MagicMock(season=1, episode_number=6),
]
with patch(
'src.server.database.service.AnimeSeriesService'
) as mock_service:
with patch(
'src.server.database.service.EpisodeService'
) as mock_ep_service:
mock_service.get_by_key = AsyncMock(return_value=existing)
mock_service.update = AsyncMock(return_value=existing)
mock_ep_service.get_by_series = AsyncMock(
return_value=mock_existing_episodes
)
mock_ep_service.create = AsyncMock()
result = await scanner._save_serie_to_db(
sample_serie, mock_db_session
)
assert result is existing
# Should have created new episodes
assert mock_ep_service.create.called
@pytest.mark.asyncio
async def test_save_serie_to_db_skips_unchanged(
self, temp_directory, mock_loader, mock_db_session, sample_serie
):
"""Test _save_serie_to_db skips update if unchanged."""
scanner = SerieScanner(temp_directory, mock_loader)
existing = MagicMock()
existing.id = 1
existing.folder = sample_serie.folder
# Mock episodes matching sample_serie.episodeDict
mock_existing_episodes = []
for season, ep_nums in sample_serie.episodeDict.items():
for ep_num in ep_nums:
mock_existing_episodes.append(
MagicMock(season=season, episode_number=ep_num)
)
with patch(
'src.server.database.service.AnimeSeriesService'
) as mock_service:
with patch(
'src.server.database.service.EpisodeService'
) as mock_ep_service:
mock_service.get_by_key = AsyncMock(return_value=existing)
mock_ep_service.get_by_series = AsyncMock(
return_value=mock_existing_episodes
)
result = await scanner._save_serie_to_db(
sample_serie, mock_db_session
)
assert result is None
mock_service.update.assert_not_called()
@pytest.mark.asyncio
async def test_update_serie_in_db_updates_existing(
self, temp_directory, mock_loader, mock_db_session, sample_serie
):
"""Test _update_serie_in_db updates existing series."""
scanner = SerieScanner(temp_directory, mock_loader)
existing = MagicMock()
existing.id = 1
with patch(
'src.server.database.service.AnimeSeriesService'
) as mock_service:
with patch(
'src.server.database.service.EpisodeService'
) as mock_ep_service:
mock_service.get_by_key = AsyncMock(return_value=existing)
mock_service.update = AsyncMock(return_value=existing)
mock_ep_service.get_by_series = AsyncMock(return_value=[])
mock_ep_service.create = AsyncMock()
result = await scanner._update_serie_in_db(
sample_serie, mock_db_session
)
assert result is existing
mock_service.update.assert_called_once()
@pytest.mark.asyncio
async def test_update_serie_in_db_returns_none_if_not_found(
self, temp_directory, mock_loader, mock_db_session, sample_serie
):
"""Test _update_serie_in_db returns None if series not found."""
scanner = SerieScanner(temp_directory, mock_loader)
with patch(
'src.server.database.service.AnimeSeriesService'
) as mock_service:
mock_service.get_by_key = AsyncMock(return_value=None)
result = await scanner._update_serie_in_db(
sample_serie, mock_db_session
)
assert result is None
class TestSerieScannerBackwardCompatibility:
"""Test backward compatibility of file-based operations."""
def test_file_based_scan_still_works(
self, temp_directory, mock_loader, sample_serie
):
"""Test file-based scan still works with deprecation warning."""
scanner = SerieScanner(temp_directory, mock_loader)
with warnings.catch_warnings():
warnings.simplefilter("ignore", DeprecationWarning)
with patch.object(scanner, 'get_total_to_scan', return_value=1):
with patch.object(
scanner,
'_SerieScanner__find_mp4_files',
return_value=iter([
("Attack on Titan (2013)", ["S01E001.mp4"])
])
):
with patch.object(
scanner,
'_SerieScanner__read_data_from_file',
return_value=sample_serie
):
with patch.object(
scanner,
'_SerieScanner__get_missing_episodes_and_season',
return_value=({1: [2, 3]}, "aniworld.to")
):
with patch.object(
sample_serie, 'save_to_file'
) as mock_save:
scanner.scan()
# Verify file was saved
mock_save.assert_called_once()
def test_keydict_populated_after_scan(
self, temp_directory, mock_loader, sample_serie
):
"""Test keyDict is populated after scan."""
scanner = SerieScanner(temp_directory, mock_loader)
with warnings.catch_warnings():
warnings.simplefilter("ignore", DeprecationWarning)
with patch.object(scanner, 'get_total_to_scan', return_value=1):
with patch.object(
scanner,
'_SerieScanner__find_mp4_files',
return_value=iter([
("Attack on Titan (2013)", ["S01E001.mp4"])
])
):
with patch.object(
scanner,
'_SerieScanner__read_data_from_file',
return_value=sample_serie
):
with patch.object(
scanner,
'_SerieScanner__get_missing_episodes_and_season',
return_value=({1: [2, 3]}, "aniworld.to")
):
with patch.object(sample_serie, 'save_to_file'):
scanner.scan()
assert sample_serie.key in scanner.keyDict
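
The five-deep with-blocks in the two tests above can be flattened without changing behavior; one idiomatic option is contextlib.ExitStack, sketched here for the scan path.

from contextlib import ExitStack
from unittest.mock import patch

def scan_with_patches(scanner, sample_serie):
    """Equivalent patching to the nested blocks above, flattened (sketch)."""
    with ExitStack() as stack:
        stack.enter_context(patch.object(
            scanner, 'get_total_to_scan', return_value=1))
        stack.enter_context(patch.object(
            scanner, '_SerieScanner__find_mp4_files',
            return_value=iter([("Attack on Titan (2013)", ["S01E001.mp4"])])))
        stack.enter_context(patch.object(
            scanner, '_SerieScanner__read_data_from_file',
            return_value=sample_serie))
        stack.enter_context(patch.object(
            scanner, '_SerieScanner__get_missing_episodes_and_season',
            return_value=({1: [2, 3]}, "aniworld.to")))
        stack.enter_context(patch.object(sample_serie, 'save_to_file'))
        scanner.scan()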


@@ -385,177 +385,3 @@ class TestSeriesAppGetters:
pass
class TestSeriesAppDatabaseInit:
"""Test SeriesApp database initialization."""
@patch('src.core.SeriesApp.Loaders')
@patch('src.core.SeriesApp.SerieScanner')
@patch('src.core.SeriesApp.SerieList')
def test_init_without_db_session(
self, mock_serie_list, mock_scanner, mock_loaders
):
"""Test SeriesApp initializes without database session."""
test_dir = "/test/anime"
# Create app without db_session
app = SeriesApp(test_dir)
# Verify db_session is None
assert app._db_session is None
assert app.db_session is None
# Verify SerieList was called with db_session=None
mock_serie_list.assert_called_once()
call_kwargs = mock_serie_list.call_args[1]
assert call_kwargs.get("db_session") is None
# Verify SerieScanner was called with db_session=None
call_kwargs = mock_scanner.call_args[1]
assert call_kwargs.get("db_session") is None
@patch('src.core.SeriesApp.Loaders')
@patch('src.core.SeriesApp.SerieScanner')
@patch('src.core.SeriesApp.SerieList')
def test_init_with_db_session(
self, mock_serie_list, mock_scanner, mock_loaders
):
"""Test SeriesApp initializes with database session."""
test_dir = "/test/anime"
mock_db = Mock()
# Create app with db_session
app = SeriesApp(test_dir, db_session=mock_db)
# Verify db_session is set
assert app._db_session is mock_db
assert app.db_session is mock_db
# Verify SerieList was called with db_session
call_kwargs = mock_serie_list.call_args[1]
assert call_kwargs.get("db_session") is mock_db
# Verify SerieScanner was called with db_session
call_kwargs = mock_scanner.call_args[1]
assert call_kwargs.get("db_session") is mock_db
class TestSeriesAppDatabaseSession:
"""Test SeriesApp database session management."""
@patch('src.core.SeriesApp.Loaders')
@patch('src.core.SeriesApp.SerieScanner')
@patch('src.core.SeriesApp.SerieList')
def test_set_db_session_updates_all_components(
self, mock_serie_list, mock_scanner, mock_loaders
):
"""Test set_db_session updates app, list, and scanner."""
test_dir = "/test/anime"
mock_list = Mock()
mock_list.GetMissingEpisode.return_value = []
mock_scan = Mock()
mock_serie_list.return_value = mock_list
mock_scanner.return_value = mock_scan
# Create app without db_session
app = SeriesApp(test_dir)
assert app.db_session is None
# Create mock database session
mock_db = Mock()
# Set database session
app.set_db_session(mock_db)
# Verify all components are updated
assert app._db_session is mock_db
assert app.db_session is mock_db
assert mock_list._db_session is mock_db
assert mock_scan._db_session is mock_db
@patch('src.core.SeriesApp.Loaders')
@patch('src.core.SeriesApp.SerieScanner')
@patch('src.core.SeriesApp.SerieList')
def test_set_db_session_to_none(
self, mock_serie_list, mock_scanner, mock_loaders
):
"""Test setting db_session to None."""
test_dir = "/test/anime"
mock_list = Mock()
mock_list.GetMissingEpisode.return_value = []
mock_scan = Mock()
mock_serie_list.return_value = mock_list
mock_scanner.return_value = mock_scan
mock_db = Mock()
# Create app with db_session
app = SeriesApp(test_dir, db_session=mock_db)
# Set database session to None
app.set_db_session(None)
# Verify all components are updated
assert app._db_session is None
assert app.db_session is None
assert mock_list._db_session is None
assert mock_scan._db_session is None
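
The two tests above fix the propagation contract: setting or clearing the session must reach both the list and the scanner. A sketch consistent with the assertions; the names of the attributes holding the SerieList and SerieScanner instances are assumptions, and only _db_session is fixed by the tests.

def set_db_session(self, db_session) -> None:
    """Propagate the database session to all components (sketch, on SeriesApp)."""
    self._db_session = db_session
    self.serie_list._db_session = db_session      # attribute name assumed
    self.serie_scanner._db_session = db_session   # attribute name assumed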
class TestSeriesAppAsyncDbInit:
"""Test SeriesApp async database initialization."""
@pytest.mark.asyncio
@patch('src.core.SeriesApp.Loaders')
@patch('src.core.SeriesApp.SerieScanner')
@patch('src.core.SeriesApp.SerieList')
async def test_init_from_db_async_loads_from_database(
self, mock_serie_list, mock_scanner, mock_loaders
):
"""Test init_from_db_async loads series from database."""
import warnings
test_dir = "/test/anime"
mock_list = Mock()
mock_list.load_series_from_db = AsyncMock()
mock_list.GetMissingEpisode.return_value = [{"name": "Test"}]
mock_serie_list.return_value = mock_list
mock_db = Mock()
# Create app with db_session
app = SeriesApp(test_dir, db_session=mock_db)
# Initialize from database
await app.init_from_db_async()
# Verify load_series_from_db was called
mock_list.load_series_from_db.assert_called_once_with(mock_db)
# Verify series_list is populated
assert len(app.series_list) == 1
@pytest.mark.asyncio
@patch('src.core.SeriesApp.Loaders')
@patch('src.core.SeriesApp.SerieScanner')
@patch('src.core.SeriesApp.SerieList')
async def test_init_from_db_async_without_session_warns(
self, mock_serie_list, mock_scanner, mock_loaders
):
"""Test init_from_db_async warns without db_session."""
import warnings
test_dir = "/test/anime"
mock_list = Mock()
mock_list.GetMissingEpisode.return_value = []
mock_serie_list.return_value = mock_list
# Create app without db_session
app = SeriesApp(test_dir)
# Initialize from database should warn
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
await app.init_from_db_async()
# Check warning was raised
assert len(w) == 1
assert "without db_session" in str(w[0].message)