better db model

This commit is contained in:
Lukas 2025-12-04 19:22:42 +01:00
parent 942f14f746
commit 798461a1ea
18 changed files with 551 additions and 2161 deletions

View File

@ -120,883 +120,3 @@ For each task completed:
- Good foundation for future enhancements if needed
---
## ✅ Completed: Download Queue Migration to SQLite Database
The download queue has been successfully migrated from JSON file to SQLite database:
| Component | Status | Description |
| --------------------- | ------- | ------------------------------------------------- |
| QueueRepository | ✅ Done | `src/server/services/queue_repository.py` |
| DownloadService | ✅ Done | Refactored to use repository pattern |
| Application Startup | ✅ Done | Queue restored from database on startup |
| API Endpoints | ✅ Done | All endpoints work with database-backed queue |
| Tests Updated | ✅ Done | All 1104 tests passing with MockQueueRepository |
| Documentation Updated | ✅ Done | `infrastructure.md` updated with new architecture |
**Key Changes:**
- `DownloadService` no longer uses `persistence_path` parameter
- Queue state is persisted to SQLite via `QueueRepository`
- In-memory cache maintained for performance
- All tests use `MockQueueRepository` fixture
---
## 🧪 Tests for Download Queue Database Migration
### Unit Tests
**File:** `tests/unit/test_queue_repository.py`
```python
"""Unit tests for QueueRepository database adapter."""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from datetime import datetime, timezone
from src.server.services.queue_repository import QueueRepository
from src.server.models.download import DownloadItem, DownloadStatus, DownloadPriority
from src.server.database.models import DownloadQueueItem as DBDownloadQueueItem
class TestQueueRepository:
    """Test suite for QueueRepository."""

    @pytest.fixture
    def mock_db_session(self):
        """Create mock database session."""
        session = AsyncMock()
        return session

    @pytest.fixture
    def repository(self, mock_db_session):
        """Create repository instance with mock session."""
        return QueueRepository(db_session_factory=lambda: mock_db_session)

    @pytest.fixture
    def sample_download_item(self):
        """Create sample DownloadItem for testing."""
        return DownloadItem(
            id="test-uuid-123",
            series_key="attack-on-titan",
            series_name="Attack on Titan",
            season=1,
            episode=5,
            status=DownloadStatus.PENDING,
            priority=DownloadPriority.NORMAL,
            progress_percent=0.0,
            downloaded_bytes=0,
            total_bytes=None,
        )

    # === Conversion Tests ===

    async def test_convert_to_db_model(self, repository, sample_download_item):
        """Test converting DownloadItem to database model."""
        # Arrange
        series_id = 42
        # Act
        db_item = repository._to_db_model(sample_download_item, series_id)
        # Assert
        assert db_item.series_id == series_id
        assert db_item.season == sample_download_item.season
        assert db_item.episode_number == sample_download_item.episode
        assert db_item.status == sample_download_item.status
        assert db_item.priority == sample_download_item.priority

    async def test_convert_from_db_model(self, repository):
        """Test converting database model to DownloadItem."""
        # Arrange: a MagicMock stands in for an ORM row with a loaded
        # `series` relationship (key/name attributes).
        db_item = MagicMock()
        db_item.id = 1
        db_item.series_id = 42
        db_item.series.key = "attack-on-titan"
        db_item.series.name = "Attack on Titan"
        db_item.season = 1
        db_item.episode_number = 5
        db_item.status = DownloadStatus.PENDING
        db_item.priority = DownloadPriority.NORMAL
        db_item.progress_percent = 25.5
        db_item.downloaded_bytes = 1024000
        db_item.total_bytes = 4096000
        # Act
        item = repository._from_db_model(db_item)
        # Assert
        assert item.series_key == "attack-on-titan"
        assert item.series_name == "Attack on Titan"
        assert item.season == 1
        assert item.episode == 5
        assert item.progress_percent == 25.5

    # === CRUD Operation Tests ===

    async def test_save_item_creates_new_record(
        self, repository, mock_db_session, sample_download_item
    ):
        """Test saving a new download item to database."""
        # Arrange: series lookup inside save_item resolves to an id.
        mock_db_session.execute.return_value.scalar_one_or_none.return_value = (
            MagicMock(id=42)
        )
        # Act
        result = await repository.save_item(sample_download_item)
        # Assert
        mock_db_session.add.assert_called_once()
        mock_db_session.flush.assert_called_once()
        assert result is not None

    async def test_get_pending_items_returns_ordered_list(
        self, repository, mock_db_session
    ):
        """Test retrieving pending items ordered by priority."""
        # Arrange
        mock_items = [MagicMock(), MagicMock()]
        mock_db_session.execute.return_value.scalars.return_value.all.return_value = (
            mock_items
        )
        # Act
        result = await repository.get_pending_items()
        # Assert
        assert len(result) == 2
        mock_db_session.execute.assert_called_once()

    async def test_update_status_success(self, repository, mock_db_session):
        """Test updating item status."""
        # Arrange
        mock_item = MagicMock()
        mock_db_session.execute.return_value.scalar_one_or_none.return_value = mock_item
        # Act
        result = await repository.update_status("test-id", DownloadStatus.DOWNLOADING)
        # Assert
        assert result is True
        assert mock_item.status == DownloadStatus.DOWNLOADING

    async def test_update_status_item_not_found(self, repository, mock_db_session):
        """Test updating status for non-existent item."""
        # Arrange
        mock_db_session.execute.return_value.scalar_one_or_none.return_value = None
        # Act
        result = await repository.update_status(
            "non-existent", DownloadStatus.DOWNLOADING
        )
        # Assert
        assert result is False

    async def test_update_progress(self, repository, mock_db_session):
        """Test updating download progress."""
        # Arrange
        mock_item = MagicMock()
        mock_db_session.execute.return_value.scalar_one_or_none.return_value = mock_item
        # Act
        result = await repository.update_progress(
            item_id="test-id",
            progress=50.0,
            downloaded=2048000,
            total=4096000,
            speed=1024000.0,
        )
        # Assert
        assert result is True
        assert mock_item.progress_percent == 50.0
        assert mock_item.downloaded_bytes == 2048000

    async def test_delete_item_success(self, repository, mock_db_session):
        """Test deleting download item."""
        # Arrange: rowcount of 1 means one row was deleted.
        mock_db_session.execute.return_value.rowcount = 1
        # Act
        result = await repository.delete_item("test-id")
        # Assert
        assert result is True

    async def test_clear_completed_returns_count(self, repository, mock_db_session):
        """Test clearing completed items returns count."""
        # Arrange
        mock_db_session.execute.return_value.rowcount = 5
        # Act
        result = await repository.clear_completed()
        # Assert
        assert result == 5
class TestQueueRepositoryErrorHandling:
    """Test error handling in QueueRepository."""

    @pytest.fixture
    def mock_db_session(self):
        """Create mock database session."""
        return AsyncMock()

    @pytest.fixture
    def repository(self, mock_db_session):
        """Create repository instance."""
        return QueueRepository(db_session_factory=lambda: mock_db_session)

    async def test_save_item_handles_database_error(
        self, repository, mock_db_session
    ):
        """Test handling database errors on save."""
        # Arrange: any DB call raises to simulate a broken connection.
        mock_db_session.execute.side_effect = Exception("Database connection failed")
        # Act & Assert: the repository propagates the error to the caller.
        with pytest.raises(Exception):
            await repository.save_item(MagicMock())

    async def test_get_items_handles_database_error(
        self, repository, mock_db_session
    ):
        """Test handling database errors on query."""
        # Arrange
        mock_db_session.execute.side_effect = Exception("Query failed")
        # Act & Assert
        with pytest.raises(Exception):
            await repository.get_pending_items()
```
---
**File:** `tests/unit/test_download_service_database.py`
```python
"""Unit tests for DownloadService with database persistence."""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from datetime import datetime, timezone
from src.server.services.download_service import DownloadService
from src.server.models.download import DownloadItem, DownloadStatus, DownloadPriority
class TestDownloadServiceDatabasePersistence:
    """Test DownloadService database persistence."""

    @pytest.fixture
    def mock_anime_service(self):
        """Create mock anime service."""
        return AsyncMock()

    @pytest.fixture
    def mock_queue_repository(self):
        """Create mock queue repository with empty default queue state."""
        repo = AsyncMock()
        repo.get_pending_items.return_value = []
        repo.get_active_item.return_value = None
        repo.get_completed_items.return_value = []
        repo.get_failed_items.return_value = []
        return repo

    @pytest.fixture
    def download_service(self, mock_anime_service, mock_queue_repository):
        """Create download service with mocked dependencies."""
        return DownloadService(
            anime_service=mock_anime_service,
            queue_repository=mock_queue_repository,
        )

    # === Persistence Tests ===

    async def test_add_to_queue_saves_to_database(
        self, download_service, mock_queue_repository
    ):
        """Test that adding to queue persists to database."""
        # Arrange
        mock_queue_repository.save_item.return_value = MagicMock(id="new-id")
        # Act
        result = await download_service.add_to_queue(
            series_key="test-series",
            season=1,
            episode=1,
        )
        # Assert
        mock_queue_repository.save_item.assert_called_once()

    async def test_startup_loads_from_database(
        self, mock_anime_service, mock_queue_repository
    ):
        """Test that startup loads queue state from database."""
        # Arrange
        pending_items = [
            MagicMock(id="1", status=DownloadStatus.PENDING),
            MagicMock(id="2", status=DownloadStatus.PENDING),
        ]
        mock_queue_repository.get_pending_items.return_value = pending_items
        # Act: construct a fresh service so initialize() performs the load.
        service = DownloadService(
            anime_service=mock_anime_service,
            queue_repository=mock_queue_repository,
        )
        await service.initialize()
        # Assert
        mock_queue_repository.get_pending_items.assert_called()

    async def test_download_completion_updates_database(
        self, download_service, mock_queue_repository
    ):
        """Test that download completion updates database status."""
        # Arrange
        item = MagicMock(id="test-id")
        # Act
        await download_service._mark_completed(item)
        # Assert
        mock_queue_repository.update_status.assert_called_with(
            "test-id", DownloadStatus.COMPLETED, error=None
        )

    async def test_download_failure_updates_database(
        self, download_service, mock_queue_repository
    ):
        """Test that download failure updates database with error."""
        # Arrange
        item = MagicMock(id="test-id")
        error_message = "Network timeout"
        # Act
        await download_service._mark_failed(item, error_message)
        # Assert
        mock_queue_repository.update_status.assert_called_with(
            "test-id", DownloadStatus.FAILED, error=error_message
        )

    async def test_progress_update_persists_to_database(
        self, download_service, mock_queue_repository
    ):
        """Test that progress updates are persisted."""
        # Arrange
        item = MagicMock(id="test-id")
        # Act
        await download_service._update_progress(
            item, progress=50.0, downloaded=2048, total=4096, speed=1024.0
        )
        # Assert
        mock_queue_repository.update_progress.assert_called_with(
            item_id="test-id",
            progress=50.0,
            downloaded=2048,
            total=4096,
            speed=1024.0,
        )

    async def test_remove_from_queue_deletes_from_database(
        self, download_service, mock_queue_repository
    ):
        """Test that removing from queue deletes from database."""
        # Arrange
        mock_queue_repository.delete_item.return_value = True
        # Act
        result = await download_service.remove_from_queue("test-id")
        # Assert
        mock_queue_repository.delete_item.assert_called_with("test-id")
        assert result is True

    async def test_clear_completed_clears_database(
        self, download_service, mock_queue_repository
    ):
        """Test that clearing completed items updates database."""
        # Arrange
        mock_queue_repository.clear_completed.return_value = 5
        # Act
        result = await download_service.clear_completed()
        # Assert
        mock_queue_repository.clear_completed.assert_called_once()
        assert result == 5
class TestDownloadServiceNoJsonFile:
    """Verify DownloadService no longer uses JSON files."""

    async def test_no_json_file_operations(self):
        """Verify no JSON file read/write operations exist."""
        import inspect

        from src.server.services.download_service import DownloadService

        source = inspect.getsource(DownloadService)
        # Assert no JSON file operations remain in the class source;
        # legacy method names are allowed only if they now go to the database.
        assert "download_queue.json" not in source
        assert "_load_queue" not in source or "database" in source.lower()
        assert "_save_queue" not in source or "database" in source.lower()
```
---
### Integration Tests
**File:** `tests/integration/test_queue_database_integration.py`
```python
"""Integration tests for download queue database operations."""
import pytest
from datetime import datetime, timezone
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
from sqlalchemy.orm import sessionmaker
from src.server.database.base import Base
from src.server.database.models import AnimeSeries, DownloadQueueItem, DownloadStatus, DownloadPriority
from src.server.database.service import DownloadQueueService, AnimeSeriesService
from src.server.services.queue_repository import QueueRepository
@pytest.fixture
async def async_engine():
    """Create async test database engine.

    Uses an in-memory SQLite database; schema is created up-front and the
    engine is disposed after the test.
    """
    engine = create_async_engine("sqlite+aiosqlite:///:memory:", echo=False)
    async with engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)
    yield engine
    await engine.dispose()
@pytest.fixture
async def async_session(async_engine):
    """Create async session for tests.

    Rolls back any uncommitted state when the test finishes.
    """
    async_session_maker = sessionmaker(
        async_engine, class_=AsyncSession, expire_on_commit=False
    )
    async with async_session_maker() as session:
        yield session
        await session.rollback()
@pytest.fixture
async def test_series(async_session):
    """Create and commit a test anime series row."""
    series = await AnimeSeriesService.create(
        db=async_session,
        key="test-anime",
        name="Test Anime",
        site="https://example.com/test-anime",
        folder="Test Anime (2024)",
    )
    await async_session.commit()
    return series
class TestQueueDatabaseIntegration:
    """Integration tests for queue database operations."""

    async def test_create_and_retrieve_queue_item(self, async_session, test_series):
        """Test creating and retrieving a queue item."""
        # Create
        item = await DownloadQueueService.create(
            db=async_session,
            series_id=test_series.id,
            season=1,
            episode_number=5,
            priority=DownloadPriority.HIGH,
        )
        await async_session.commit()
        # Retrieve
        retrieved = await DownloadQueueService.get_by_id(async_session, item.id)
        # Assert
        assert retrieved is not None
        assert retrieved.series_id == test_series.id
        assert retrieved.season == 1
        assert retrieved.episode_number == 5
        assert retrieved.priority == DownloadPriority.HIGH
        assert retrieved.status == DownloadStatus.PENDING

    async def test_update_download_progress(self, async_session, test_series):
        """Test updating download progress."""
        # Create item
        item = await DownloadQueueService.create(
            db=async_session,
            series_id=test_series.id,
            season=1,
            episode_number=1,
        )
        await async_session.commit()
        # Update progress
        updated = await DownloadQueueService.update_progress(
            db=async_session,
            item_id=item.id,
            progress_percent=75.5,
            downloaded_bytes=3072000,
            total_bytes=4096000,
            download_speed=1024000.0,
        )
        await async_session.commit()
        # Assert
        assert updated.progress_percent == 75.5
        assert updated.downloaded_bytes == 3072000
        assert updated.total_bytes == 4096000
        assert updated.download_speed == 1024000.0

    async def test_status_transitions(self, async_session, test_series):
        """Test download status transitions set timestamps along the way."""
        # Create pending item
        item = await DownloadQueueService.create(
            db=async_session,
            series_id=test_series.id,
            season=1,
            episode_number=1,
        )
        await async_session.commit()
        assert item.status == DownloadStatus.PENDING
        # Transition to downloading
        item = await DownloadQueueService.update_status(
            async_session, item.id, DownloadStatus.DOWNLOADING
        )
        await async_session.commit()
        assert item.status == DownloadStatus.DOWNLOADING
        assert item.started_at is not None
        # Transition to completed
        item = await DownloadQueueService.update_status(
            async_session, item.id, DownloadStatus.COMPLETED
        )
        await async_session.commit()
        assert item.status == DownloadStatus.COMPLETED
        assert item.completed_at is not None

    async def test_failed_download_with_retry(self, async_session, test_series):
        """Test failed download with error message and retry count."""
        # Create item
        item = await DownloadQueueService.create(
            db=async_session,
            series_id=test_series.id,
            season=1,
            episode_number=1,
        )
        await async_session.commit()
        # Mark as failed with error
        item = await DownloadQueueService.update_status(
            async_session,
            item.id,
            DownloadStatus.FAILED,
            error_message="Connection timeout",
        )
        await async_session.commit()
        # Assert
        assert item.status == DownloadStatus.FAILED
        assert item.error_message == "Connection timeout"
        assert item.retry_count == 1

    async def test_get_pending_items_ordered_by_priority(
        self, async_session, test_series
    ):
        """Test retrieving pending items ordered by priority."""
        # Create items with different priorities
        await DownloadQueueService.create(
            async_session, test_series.id, 1, 1, priority=DownloadPriority.LOW
        )
        await DownloadQueueService.create(
            async_session, test_series.id, 1, 2, priority=DownloadPriority.HIGH
        )
        await DownloadQueueService.create(
            async_session, test_series.id, 1, 3, priority=DownloadPriority.NORMAL
        )
        await async_session.commit()
        # Get pending items
        pending = await DownloadQueueService.get_pending(async_session)
        # Assert order: HIGH -> NORMAL -> LOW
        assert len(pending) == 3
        assert pending[0].priority == DownloadPriority.HIGH
        assert pending[1].priority == DownloadPriority.NORMAL
        assert pending[2].priority == DownloadPriority.LOW

    async def test_clear_completed_items(self, async_session, test_series):
        """Test clearing completed download items."""
        # Create items
        item1 = await DownloadQueueService.create(
            async_session, test_series.id, 1, 1
        )
        item2 = await DownloadQueueService.create(
            async_session, test_series.id, 1, 2
        )
        item3 = await DownloadQueueService.create(
            async_session, test_series.id, 1, 3
        )
        # Complete first two
        await DownloadQueueService.update_status(
            async_session, item1.id, DownloadStatus.COMPLETED
        )
        await DownloadQueueService.update_status(
            async_session, item2.id, DownloadStatus.COMPLETED
        )
        await async_session.commit()
        # Clear completed
        cleared = await DownloadQueueService.clear_completed(async_session)
        await async_session.commit()
        # Assert
        assert cleared == 2
        # Verify pending item remains
        remaining = await DownloadQueueService.get_all(async_session)
        assert len(remaining) == 1
        assert remaining[0].id == item3.id

    async def test_cascade_delete_with_series(self, async_session, test_series):
        """Test that queue items are deleted when series is deleted."""
        # Create queue items
        await DownloadQueueService.create(
            async_session, test_series.id, 1, 1
        )
        await DownloadQueueService.create(
            async_session, test_series.id, 1, 2
        )
        await async_session.commit()
        # Delete series
        await AnimeSeriesService.delete(async_session, test_series.id)
        await async_session.commit()
        # Verify queue items are gone
        all_items = await DownloadQueueService.get_all(async_session)
        assert len(all_items) == 0
```
---
### API Tests
**File:** `tests/api/test_queue_endpoints_database.py`
```python
"""API tests for queue endpoints with database persistence."""
import pytest
from httpx import AsyncClient
from unittest.mock import patch, AsyncMock
class TestQueueAPIWithDatabase:
    """Test queue API endpoints with database backend."""

    @pytest.fixture
    def auth_headers(self):
        """Get authentication headers."""
        return {"Authorization": "Bearer test-token"}

    async def test_get_queue_returns_database_items(
        self, client: AsyncClient, auth_headers
    ):
        """Test GET /api/queue returns items from database."""
        response = await client.get("/api/queue", headers=auth_headers)
        assert response.status_code == 200
        data = response.json()
        assert "pending" in data
        assert "active" in data
        assert "completed" in data

    async def test_add_to_queue_persists_to_database(
        self, client: AsyncClient, auth_headers
    ):
        """Test POST /api/queue persists item to database."""
        payload = {
            "series_key": "test-anime",
            "season": 1,
            "episode": 1,
            "priority": "normal",
        }
        response = await client.post(
            "/api/queue",
            json=payload,
            headers=auth_headers,
        )
        assert response.status_code == 201
        data = response.json()
        assert "id" in data

    async def test_remove_from_queue_deletes_from_database(
        self, client: AsyncClient, auth_headers
    ):
        """Test DELETE /api/queue/{id} removes from database."""
        # First add an item
        add_response = await client.post(
            "/api/queue",
            json={"series_key": "test-anime", "season": 1, "episode": 1},
            headers=auth_headers,
        )
        item_id = add_response.json()["id"]
        # Then delete it
        response = await client.delete(
            f"/api/queue/{item_id}",
            headers=auth_headers,
        )
        assert response.status_code == 200
        # Verify it's gone
        get_response = await client.get("/api/queue", headers=auth_headers)
        queue_data = get_response.json()
        item_ids = [item["id"] for item in queue_data.get("pending", [])]
        assert item_id not in item_ids

    async def test_queue_survives_server_restart(
        self, client: AsyncClient, auth_headers
    ):
        """Test that queue items persist across simulated restart."""
        # Add item
        add_response = await client.post(
            "/api/queue",
            json={"series_key": "test-anime", "season": 1, "episode": 5},
            headers=auth_headers,
        )
        item_id = add_response.json()["id"]
        # Simulate restart by clearing in-memory cache
        # (In real scenario, this would be a server restart)
        # Verify item still exists
        response = await client.get("/api/queue", headers=auth_headers)
        queue_data = response.json()
        item_ids = [item["id"] for item in queue_data.get("pending", [])]
        assert item_id in item_ids

    async def test_clear_completed_endpoint(
        self, client: AsyncClient, auth_headers
    ):
        """Test POST /api/queue/clear-completed endpoint."""
        response = await client.post(
            "/api/queue/clear-completed",
            headers=auth_headers,
        )
        assert response.status_code == 200
        data = response.json()
        assert "cleared_count" in data
```
---
### Performance Tests
**File:** `tests/performance/test_queue_database_performance.py`
```python
"""Performance tests for database-backed download queue."""
import pytest
import asyncio
import time
from datetime import datetime
class TestQueueDatabasePerformance:
    """Performance tests for queue database operations."""

    @pytest.mark.performance
    async def test_bulk_insert_performance(self, async_session, test_series):
        """Test performance of bulk queue item insertion."""
        from src.server.database.service import DownloadQueueService

        start_time = time.time()
        # Insert 100 queue items
        for i in range(100):
            await DownloadQueueService.create(
                async_session,
                test_series.id,
                season=1,
                episode_number=i + 1,
            )
        await async_session.commit()
        elapsed = time.time() - start_time
        # Should complete in under 2 seconds
        assert elapsed < 2.0, f"Bulk insert took {elapsed:.2f}s, expected < 2s"

    @pytest.mark.performance
    async def test_query_performance_with_many_items(
        self, async_session, test_series
    ):
        """Test query performance with many queue items."""
        from src.server.database.service import DownloadQueueService

        # Setup: Create 500 items spread across seasons of 12 episodes.
        for i in range(500):
            await DownloadQueueService.create(
                async_session,
                test_series.id,
                season=(i // 12) + 1,
                episode_number=(i % 12) + 1,
            )
        await async_session.commit()
        # Test query performance
        start_time = time.time()
        pending = await DownloadQueueService.get_pending(async_session)
        elapsed = time.time() - start_time
        # Query should complete in under 100ms
        assert elapsed < 0.1, f"Query took {elapsed*1000:.1f}ms, expected < 100ms"
        assert len(pending) == 500

    @pytest.mark.performance
    async def test_progress_update_performance(self, async_session, test_series):
        """Test performance of frequent progress updates."""
        from src.server.database.service import DownloadQueueService

        # Create item
        item = await DownloadQueueService.create(
            async_session, test_series.id, 1, 1
        )
        await async_session.commit()
        start_time = time.time()
        # Simulate 100 progress updates (like during download)
        for i in range(100):
            await DownloadQueueService.update_progress(
                async_session,
                item.id,
                progress_percent=i,
                downloaded_bytes=i * 10240,
                total_bytes=1024000,
                download_speed=102400.0,
            )
        await async_session.commit()
        elapsed = time.time() - start_time
        # 100 updates should complete in under 1 second
        assert elapsed < 1.0, f"Progress updates took {elapsed:.2f}s, expected < 1s"
```
---
## Summary
These changes migrate the download queue from JSON file persistence to a SQLite database, providing:
1. **Data Integrity**: ACID-compliant storage with proper relationships
2. **Query Capability**: Efficient filtering, sorting, and pagination
3. **Consistency**: Single source of truth for all application data
4. **Scalability**: Better performance for large queues
5. **Recovery**: Robust handling of crashes and restarts
The existing database infrastructure (`DownloadQueueItem` model and `DownloadQueueService`) is already in place, making this primarily an integration task rather than new development.

View File

@ -540,7 +540,7 @@ class SerieScanner:
Save or update a series in the database.
Creates a new record if the series doesn't exist, or updates
the episode_dict if it has changed.
the episodes if they have changed.
Args:
serie: Serie instance to save
@ -549,26 +549,53 @@ class SerieScanner:
Returns:
Created or updated AnimeSeries instance, or None if unchanged
"""
from src.server.database.service import AnimeSeriesService
from src.server.database.service import AnimeSeriesService, EpisodeService
# Check if series already exists
existing = await AnimeSeriesService.get_by_key(db, serie.key)
if existing:
# Update episode_dict if changed
if existing.episode_dict != serie.episodeDict:
updated = await AnimeSeriesService.update(
db,
existing.id,
episode_dict=serie.episodeDict,
folder=serie.folder
)
# Build existing episode dict from episodes for comparison
existing_episodes = await EpisodeService.get_by_series(
db, existing.id
)
existing_dict: dict[int, list[int]] = {}
for ep in existing_episodes:
if ep.season not in existing_dict:
existing_dict[ep.season] = []
existing_dict[ep.season].append(ep.episode_number)
for season in existing_dict:
existing_dict[season].sort()
# Update episodes if changed
if existing_dict != serie.episodeDict:
# Add new episodes
new_dict = serie.episodeDict or {}
for season, episode_numbers in new_dict.items():
existing_eps = set(existing_dict.get(season, []))
for ep_num in episode_numbers:
if ep_num not in existing_eps:
await EpisodeService.create(
db=db,
series_id=existing.id,
season=season,
episode_number=ep_num,
)
# Update folder if changed
if existing.folder != serie.folder:
await AnimeSeriesService.update(
db,
existing.id,
folder=serie.folder
)
logger.info(
"Updated series in database: %s (key=%s)",
serie.name,
serie.key
)
return updated
return existing
else:
logger.debug(
"Series unchanged in database: %s (key=%s)",
@ -584,8 +611,19 @@ class SerieScanner:
name=serie.name,
site=serie.site,
folder=serie.folder,
episode_dict=serie.episodeDict,
)
# Create Episode records
if serie.episodeDict:
for season, episode_numbers in serie.episodeDict.items():
for ep_num in episode_numbers:
await EpisodeService.create(
db=db,
series_id=anime_series.id,
season=season,
episode_number=ep_num,
)
logger.info(
"Created series in database: %s (key=%s)",
serie.name,
@ -608,7 +646,7 @@ class SerieScanner:
Returns:
Updated AnimeSeries instance, or None if not found
"""
from src.server.database.service import AnimeSeriesService
from src.server.database.service import AnimeSeriesService, EpisodeService
existing = await AnimeSeriesService.get_by_key(db, serie.key)
if not existing:
@ -619,20 +657,43 @@ class SerieScanner:
)
return None
updated = await AnimeSeriesService.update(
# Update basic fields
await AnimeSeriesService.update(
db,
existing.id,
name=serie.name,
site=serie.site,
folder=serie.folder,
episode_dict=serie.episodeDict,
)
# Update episodes - add any new ones
if serie.episodeDict:
existing_episodes = await EpisodeService.get_by_series(
db, existing.id
)
existing_dict: dict[int, set[int]] = {}
for ep in existing_episodes:
if ep.season not in existing_dict:
existing_dict[ep.season] = set()
existing_dict[ep.season].add(ep.episode_number)
for season, episode_numbers in serie.episodeDict.items():
existing_eps = existing_dict.get(season, set())
for ep_num in episode_numbers:
if ep_num not in existing_eps:
await EpisodeService.create(
db=db,
series_id=existing.id,
season=season,
episode_number=ep_num,
)
logger.info(
"Updated series in database: %s (key=%s)",
serie.name,
serie.key
)
return updated
return existing
def __find_mp4_files(self) -> Iterator[tuple[str, list[str]]]:
"""Find all .mp4 files in the directory structure."""

View File

@ -147,7 +147,7 @@ class SerieList:
if result:
print(f"Added series: {result.name}")
"""
from src.server.database.service import AnimeSeriesService
from src.server.database.service import AnimeSeriesService, EpisodeService
# Check if series already exists in DB
existing = await AnimeSeriesService.get_by_key(db, serie.key)
@ -166,9 +166,19 @@ class SerieList:
name=serie.name,
site=serie.site,
folder=serie.folder,
episode_dict=serie.episodeDict,
)
# Create Episode records for each episode in episodeDict
if serie.episodeDict:
for season, episode_numbers in serie.episodeDict.items():
for episode_number in episode_numbers:
await EpisodeService.create(
db=db,
series_id=anime_series.id,
season=season,
episode_number=episode_number,
)
# Also add to in-memory collection
self.keyDict[serie.key] = serie
@ -267,8 +277,10 @@ class SerieList:
# Clear existing in-memory data
self.keyDict.clear()
# Load all series from database
anime_series_list = await AnimeSeriesService.get_all(db)
# Load all series from database (with episodes for episodeDict)
anime_series_list = await AnimeSeriesService.get_all(
db, with_episodes=True
)
for anime_series in anime_series_list:
serie = self._convert_from_db(anime_series)
@ -288,23 +300,22 @@ class SerieList:
Args:
anime_series: AnimeSeries model from database
(must have episodes relationship loaded)
Returns:
Serie entity instance
"""
# Convert episode_dict from JSON (string keys) to int keys
# Build episode_dict from episodes relationship
episode_dict: dict[int, list[int]] = {}
if anime_series.episode_dict:
for season_str, episodes in anime_series.episode_dict.items():
try:
season = int(season_str)
episode_dict[season] = list(episodes)
except (ValueError, TypeError):
logger.warning(
"Invalid season key '%s' in episode_dict for %s",
season_str,
anime_series.key
)
if anime_series.episodes:
for episode in anime_series.episodes:
season = episode.season
if season not in episode_dict:
episode_dict[season] = []
episode_dict[season].append(episode.episode_number)
# Sort episode numbers within each season
for season in episode_dict:
episode_dict[season].sort()
return Serie(
key=anime_series.key,
@ -325,19 +336,11 @@ class SerieList:
Returns:
Dictionary suitable for AnimeSeriesService.create()
"""
# Convert episode_dict keys to strings for JSON storage
episode_dict = None
if serie.episodeDict:
episode_dict = {
str(k): list(v) for k, v in serie.episodeDict.items()
}
return {
"key": serie.key,
"name": serie.name,
"site": serie.site,
"folder": serie.folder,
"episode_dict": episode_dict,
}
async def contains_in_db(self, key: str, db: "AsyncSession") -> bool:

View File

@ -229,37 +229,6 @@ class DatabaseIntegrityChecker:
logger.warning(msg)
issues_found += count
# Check for invalid progress percentages
stmt = select(DownloadQueueItem).where(
(DownloadQueueItem.progress < 0) |
(DownloadQueueItem.progress > 100)
)
invalid_progress = self.session.execute(stmt).scalars().all()
if invalid_progress:
count = len(invalid_progress)
msg = (
f"Found {count} queue items with invalid progress "
f"percentages"
)
self.issues.append(msg)
logger.warning(msg)
issues_found += count
# Check for queue items with invalid status
valid_statuses = {'pending', 'downloading', 'completed', 'failed'}
stmt = select(DownloadQueueItem).where(
~DownloadQueueItem.status.in_(valid_statuses)
)
invalid_status = self.session.execute(stmt).scalars().all()
if invalid_status:
count = len(invalid_status)
msg = f"Found {count} queue items with invalid status"
self.issues.append(msg)
logger.warning(msg)
issues_found += count
if issues_found == 0:
logger.info("No data consistency issues found")

View File

@ -669,7 +669,6 @@ async def add_series(
name=request.name.strip(),
site="aniworld.to",
folder=folder,
episode_dict={}, # Empty for new series
)
db_id = anime_series.id

View File

@ -1,479 +0,0 @@
"""Example integration of database service with existing services.
This file demonstrates how to integrate the database service layer with
existing application services like AnimeService and DownloadService.
These examples show patterns for:
- Persisting scan results to database
- Loading queue from database on startup
- Syncing download progress to database
- Maintaining consistency between in-memory state and database
"""
from __future__ import annotations
import logging
from typing import List, Optional
from sqlalchemy.ext.asyncio import AsyncSession
from src.core.entities.series import Serie
from src.server.database.models import DownloadPriority, DownloadStatus
from src.server.database.service import (
AnimeSeriesService,
DownloadQueueService,
EpisodeService,
)
logger = logging.getLogger(__name__)
# ============================================================================
# Example 1: Persist Scan Results
# ============================================================================
async def persist_scan_results(
    db: AsyncSession,
    series_list: List[Serie],
) -> None:
    """Write scan results into the database.

    For every scanned series, either update the existing AnimeSeries row
    (matched by provider key) or create a new one, then reconcile its
    episode rows.

    Args:
        db: Database session
        series_list: Serie objects produced by SerieScanner
    """
    logger.info(f"Persisting {len(series_list)} series to database")
    for scanned in series_list:
        # NOTE(review): this module reads `serie.episode_dict`; other call
        # sites use `serie.episodeDict` — confirm the Serie attribute name.
        found = await AnimeSeriesService.get_by_key(db, scanned.key)
        if found is None:
            created = await AnimeSeriesService.create(
                db,
                key=scanned.key,
                name=scanned.name,
                site=scanned.site,
                folder=scanned.folder,
                episode_dict=scanned.episode_dict,
            )
            target_id = created.id
        else:
            await AnimeSeriesService.update(
                db,
                found.id,
                name=scanned.name,
                site=scanned.site,
                folder=scanned.folder,
                episode_dict=scanned.episode_dict,
            )
            target_id = found.id
        # Reconcile the episode rows for this series
        await _update_episodes(db, target_id, scanned)
    await db.commit()
    logger.info("Scan results persisted successfully")
async def _update_episodes(
    db: AsyncSession,
    series_id: int,
    serie: Serie,
) -> None:
    """Create any episode rows that the scan found but the DB lacks.

    Args:
        db: Database session
        series_id: Primary key of the series in the database
        serie: Serie object carrying the scanned episode information
    """
    current = await EpisodeService.get_by_series(db, series_id)
    # Index existing rows by (season, episode) for O(1) lookups
    known = {(ep.season, ep.episode_number): ep for ep in current}
    for season, episodes in serie.episode_dict.items():
        for ep_num in episodes:
            existing = known.get((int(season), int(ep_num)))
            if existing is not None:
                # Row already tracked; placeholder for a future
                # file-presence check against serie.local_episodes.
                if not existing.is_downloaded:
                    pass
            else:
                await EpisodeService.create(
                    db,
                    series_id=series_id,
                    season=int(season),
                    episode_number=int(ep_num),
                    is_downloaded=False,
                )
# ============================================================================
# Example 2: Load Queue from Database
# ============================================================================
async def load_queue_from_database(
    db: AsyncSession,
) -> List[dict]:
    """Load download queue from database.

    Fetches the pending and active queue rows and converts each one into a
    plain dictionary suitable for DownloadService.

    Args:
        db: Database session

    Returns:
        List of download items as dictionaries
    """
    logger.info("Loading download queue from database")
    rows = (
        await DownloadQueueService.get_pending(db)
        + await DownloadQueueService.get_active(db)
    )
    queue_items = [
        {
            "id": row.id,
            "series_id": row.series_id,
            "season": row.season,
            "episode_number": row.episode_number,
            "status": row.status.value,
            "priority": row.priority.value,
            "progress_percent": row.progress_percent,
            "downloaded_bytes": row.downloaded_bytes,
            "total_bytes": row.total_bytes,
            "download_speed": row.download_speed,
            "error_message": row.error_message,
            "retry_count": row.retry_count,
        }
        for row in rows
    ]
    logger.info(f"Loaded {len(queue_items)} items from database")
    return queue_items
# ============================================================================
# Example 3: Sync Download Progress to Database
# ============================================================================
async def sync_download_progress(
    db: AsyncSession,
    item_id: int,
    progress_percent: float,
    downloaded_bytes: int,
    total_bytes: Optional[int] = None,
    download_speed: Optional[float] = None,
) -> None:
    """Sync download progress to database.

    Thin wrapper intended for use as a download progress callback: it
    forwards the progress figures to the queue service and commits.

    Args:
        db: Database session
        item_id: Download queue item ID
        progress_percent: Progress percentage (0-100)
        downloaded_bytes: Bytes downloaded
        total_bytes: Optional total file size
        download_speed: Optional current speed (bytes/sec)
    """
    await DownloadQueueService.update_progress(
        db,
        item_id,
        progress_percent,
        downloaded_bytes,
        total_bytes=total_bytes,
        download_speed=download_speed,
    )
    # Persist immediately so progress survives a restart
    await db.commit()
async def mark_download_complete(
    db: AsyncSession,
    item_id: int,
    file_path: str,
    file_size: int,
) -> None:
    """Mark download as complete in database.

    Sets the queue item to COMPLETED and records the episode as
    downloaded, creating the episode row if it does not exist yet.

    Args:
        db: Database session
        item_id: Download queue item ID
        file_path: Path to downloaded file
        file_size: File size in bytes
    """
    item = await DownloadQueueService.get_by_id(db, item_id)
    if item is None:
        # Nothing to update — bail out without touching the DB
        logger.error(f"Download item {item_id} not found")
        return
    await DownloadQueueService.update_status(
        db,
        item_id,
        DownloadStatus.COMPLETED,
    )
    # Locate the matching episode row, or create one flagged as downloaded
    episode = await EpisodeService.get_by_episode(
        db,
        item.series_id,
        item.season,
        item.episode_number,
    )
    if episode is None:
        episode = await EpisodeService.create(
            db,
            series_id=item.series_id,
            season=item.season,
            episode_number=item.episode_number,
            file_path=file_path,
            file_size=file_size,
            is_downloaded=True,
        )
    else:
        await EpisodeService.mark_downloaded(
            db,
            episode.id,
            file_path,
            file_size,
        )
    await db.commit()
    logger.info(
        f"Marked download complete: S{item.season:02d}E{item.episode_number:02d}"
    )
async def mark_download_failed(
    db: AsyncSession,
    item_id: int,
    error_message: str,
) -> None:
    """Mark download as failed in database.

    Args:
        db: Database session
        item_id: Download queue item ID
        error_message: Error description
    """
    failed_status = DownloadStatus.FAILED
    await DownloadQueueService.update_status(
        db,
        item_id,
        failed_status,
        error_message=error_message,
    )
    # Commit so the failure is visible to other sessions right away
    await db.commit()
# ============================================================================
# Example 4: Add Episodes to Download Queue
# ============================================================================
async def add_episodes_to_queue(
    db: AsyncSession,
    series_key: str,
    episodes: List[tuple[int, int]],  # List of (season, episode) tuples
    priority: DownloadPriority = DownloadPriority.NORMAL,
) -> int:
    """Add multiple episodes to download queue.

    Skips episodes that already have a PENDING or DOWNLOADING queue item
    for this series (including duplicates within the same call).

    Args:
        db: Database session
        series_key: Series provider key
        episodes: List of (season, episode_number) tuples
        priority: Download priority

    Returns:
        Number of episodes added to queue
    """
    series = await AnimeSeriesService.get_by_key(db, series_key)
    if not series:
        logger.error(f"Series not found: {series_key}")
        return 0
    # Fetch the queue ONCE and build a lookup set. The previous version
    # re-queried DownloadQueueService.get_all() inside the loop, i.e. one
    # full-table fetch per requested episode.
    existing_items = await DownloadQueueService.get_all(db)
    queued_keys = {
        (item.season, item.episode_number)
        for item in existing_items
        if item.series_id == series.id
        and item.status in (DownloadStatus.PENDING, DownloadStatus.DOWNLOADING)
    }
    added_count = 0
    for season, episode_number in episodes:
        if (season, episode_number) in queued_keys:
            continue
        await DownloadQueueService.create(
            db,
            series_id=series.id,
            season=season,
            episode_number=episode_number,
            priority=priority,
        )
        # Track in-call additions so duplicate tuples are not queued twice
        queued_keys.add((season, episode_number))
        added_count += 1
    await db.commit()
    logger.info(f"Added {added_count} episodes to download queue")
    return added_count
# ============================================================================
# Example 5: Integration with AnimeService
# ============================================================================
class EnhancedAnimeService:
    """Enhanced AnimeService with database persistence.

    Example wrapper showing how the existing AnimeService can be combined
    with the database service layer.
    """

    def __init__(self, db_session_factory):
        """Initialize enhanced anime service.

        Args:
            db_session_factory: Async session factory for database access
        """
        self.db_session_factory = db_session_factory

    async def rescan_with_persistence(self, directory: str) -> dict:
        """Rescan directory and persist results.

        Args:
            directory: Directory to scan

        Returns:
            Scan results dictionary
        """
        # Imported lazily to avoid a circular dependency
        from src.core.SeriesApp import SeriesApp

        scan_results = SeriesApp(directory).ReScan()
        async with self.db_session_factory() as db:
            await persist_scan_results(db, scan_results)
        return {
            "total_series": len(scan_results),
            "message": "Scan completed and persisted to database",
        }

    async def get_series_with_missing_episodes(self) -> List[dict]:
        """Get series with missing episodes from database.

        Returns:
            List of series with missing episodes
        """
        async with self.db_session_factory() as db:
            all_series = await AnimeSeriesService.get_all(
                db,
                with_episodes=True,
            )
            report: List[dict] = []
            for entry in all_series:
                if not entry.episode_dict:
                    continue
                # Expected count comes from the scan data, downloaded
                # count from the episode rows
                expected = sum(
                    len(eps) for eps in entry.episode_dict.values()
                )
                have = sum(
                    1 for ep in entry.episodes if ep.is_downloaded
                )
                if have < expected:
                    report.append({
                        "id": entry.id,
                        "key": entry.key,
                        "name": entry.name,
                        "total_episodes": expected,
                        "downloaded_episodes": have,
                        "missing_episodes": expected - have,
                    })
            return report
# ============================================================================
# Usage Example
# ============================================================================
async def example_usage():
    """Example usage of database service integration."""
    from src.server.database import get_db_session

    async with get_db_session() as db:
        # Queue a few episodes at high priority
        added = await add_episodes_to_queue(
            db,
            series_key="attack-on-titan",
            episodes=[(1, 1), (1, 2), (1, 3)],
            priority=DownloadPriority.HIGH,
        )
        print(f"Added {added} episodes to queue")

        # Reload the queue state from the database
        queue_items = await load_queue_from_database(db)
        print(f"Queue has {len(queue_items)} items")

        if queue_items:
            first_id = queue_items[0]["id"]
            # Report halfway progress on the first item...
            await sync_download_progress(
                db,
                item_id=first_id,
                progress_percent=50.0,
                downloaded_bytes=500000,
                total_bytes=1000000,
            )
            # ...then mark it finished
            await mark_download_complete(
                db,
                item_id=first_id,
                file_path="/path/to/file.mp4",
                file_size=1000000,
            )


if __name__ == "__main__":
    import asyncio

    asyncio.run(example_usage())

View File

@ -47,7 +47,7 @@ EXPECTED_INDEXES = {
"episodes": ["ix_episodes_series_id"],
"download_queue": [
"ix_download_queue_series_id",
"ix_download_queue_status",
"ix_download_queue_episode_id",
],
"user_sessions": [
"ix_user_sessions_session_id",

View File

@ -15,18 +15,7 @@ from datetime import datetime, timezone
from enum import Enum
from typing import List, Optional
from sqlalchemy import (
JSON,
Boolean,
DateTime,
Float,
ForeignKey,
Integer,
String,
Text,
func,
)
from sqlalchemy import Enum as SQLEnum
from sqlalchemy import Boolean, DateTime, ForeignKey, Integer, String, Text, func
from sqlalchemy.orm import Mapped, mapped_column, relationship, validates
from src.server.database.base import Base, TimestampMixin
@ -51,10 +40,6 @@ class AnimeSeries(Base, TimestampMixin):
name: Display name of the series
site: Provider site URL
folder: Filesystem folder name (metadata only, not for lookups)
description: Optional series description
status: Current status (ongoing, completed, etc.)
total_episodes: Total number of episodes
cover_url: URL to series cover image
episodes: Relationship to Episode models (via id foreign key)
download_items: Relationship to DownloadQueueItem models (via id foreign key)
created_at: Creation timestamp (from TimestampMixin)
@ -89,30 +74,6 @@ class AnimeSeries(Base, TimestampMixin):
doc="Filesystem folder name - METADATA ONLY, not for lookups"
)
# Metadata
description: Mapped[Optional[str]] = mapped_column(
Text, nullable=True,
doc="Series description"
)
status: Mapped[Optional[str]] = mapped_column(
String(50), nullable=True,
doc="Series status (ongoing, completed, etc.)"
)
total_episodes: Mapped[Optional[int]] = mapped_column(
Integer, nullable=True,
doc="Total number of episodes"
)
cover_url: Mapped[Optional[str]] = mapped_column(
String(1000), nullable=True,
doc="URL to cover image"
)
# JSON field for episode dictionary (season -> [episodes])
episode_dict: Mapped[Optional[dict]] = mapped_column(
JSON, nullable=True,
doc="Episode dictionary {season: [episodes]}"
)
# Relationships
episodes: Mapped[List["Episode"]] = relationship(
"Episode",
@ -161,22 +122,6 @@ class AnimeSeries(Base, TimestampMixin):
raise ValueError("Folder path must be 1000 characters or less")
return value.strip()
@validates('cover_url')
def validate_cover_url(self, key: str, value: Optional[str]) -> Optional[str]:
"""Validate cover URL length."""
if value is not None and len(value) > 1000:
raise ValueError("Cover URL must be 1000 characters or less")
return value
@validates('total_episodes')
def validate_total_episodes(self, key: str, value: Optional[int]) -> Optional[int]:
"""Validate total episodes is positive."""
if value is not None and value < 0:
raise ValueError("Total episodes must be non-negative")
if value is not None and value > 10000:
raise ValueError("Total episodes must be 10000 or less")
return value
def __repr__(self) -> str:
return f"<AnimeSeries(id={self.id}, key='{self.key}', name='{self.name}')>"
@ -194,9 +139,7 @@ class Episode(Base, TimestampMixin):
episode_number: Episode number within season
title: Episode title
file_path: Local file path if downloaded
file_size: File size in bytes
is_downloaded: Whether episode is downloaded
download_date: When episode was downloaded
series: Relationship to AnimeSeries
created_at: Creation timestamp (from TimestampMixin)
updated_at: Last update timestamp (from TimestampMixin)
@ -234,18 +177,10 @@ class Episode(Base, TimestampMixin):
String(1000), nullable=True,
doc="Local file path"
)
file_size: Mapped[Optional[int]] = mapped_column(
Integer, nullable=True,
doc="File size in bytes"
)
is_downloaded: Mapped[bool] = mapped_column(
Boolean, default=False, nullable=False,
doc="Whether episode is downloaded"
)
download_date: Mapped[Optional[datetime]] = mapped_column(
DateTime(timezone=True), nullable=True,
doc="When episode was downloaded"
)
# Relationship
series: Mapped["AnimeSeries"] = relationship(
@ -287,13 +222,6 @@ class Episode(Base, TimestampMixin):
raise ValueError("File path must be 1000 characters or less")
return value
@validates('file_size')
def validate_file_size(self, key: str, value: Optional[int]) -> Optional[int]:
"""Validate file size is non-negative."""
if value is not None and value < 0:
raise ValueError("File size must be non-negative")
return value
def __repr__(self) -> str:
return (
f"<Episode(id={self.id}, series_id={self.series_id}, "
@ -321,27 +249,20 @@ class DownloadPriority(str, Enum):
class DownloadQueueItem(Base, TimestampMixin):
"""SQLAlchemy model for download queue items.
Tracks download queue with status, progress, and error information.
Tracks download queue with error information.
Provides persistence for the DownloadService queue state.
Attributes:
id: Primary key
series_id: Foreign key to AnimeSeries
season: Season number
episode_number: Episode number
status: Current download status
priority: Download priority
progress_percent: Download progress (0-100)
downloaded_bytes: Bytes downloaded
total_bytes: Total file size
download_speed: Current speed in bytes/sec
episode_id: Foreign key to Episode
error_message: Error description if failed
retry_count: Number of retry attempts
download_url: Provider download URL
file_destination: Target file path
started_at: When download started
completed_at: When download completed
series: Relationship to AnimeSeries
episode: Relationship to Episode
created_at: Creation timestamp (from TimestampMixin)
updated_at: Last update timestamp (from TimestampMixin)
"""
@ -359,47 +280,11 @@ class DownloadQueueItem(Base, TimestampMixin):
index=True
)
# Episode identification
season: Mapped[int] = mapped_column(
Integer, nullable=False,
doc="Season number"
)
episode_number: Mapped[int] = mapped_column(
Integer, nullable=False,
doc="Episode number"
)
# Queue management
status: Mapped[str] = mapped_column(
SQLEnum(DownloadStatus),
default=DownloadStatus.PENDING,
# Foreign key to episode
episode_id: Mapped[int] = mapped_column(
ForeignKey("episodes.id", ondelete="CASCADE"),
nullable=False,
index=True,
doc="Current download status"
)
priority: Mapped[str] = mapped_column(
SQLEnum(DownloadPriority),
default=DownloadPriority.NORMAL,
nullable=False,
doc="Download priority"
)
# Progress tracking
progress_percent: Mapped[float] = mapped_column(
Float, default=0.0, nullable=False,
doc="Progress percentage (0-100)"
)
downloaded_bytes: Mapped[int] = mapped_column(
Integer, default=0, nullable=False,
doc="Bytes downloaded"
)
total_bytes: Mapped[Optional[int]] = mapped_column(
Integer, nullable=True,
doc="Total file size"
)
download_speed: Mapped[Optional[float]] = mapped_column(
Float, nullable=True,
doc="Current download speed (bytes/sec)"
index=True
)
# Error handling
@ -407,10 +292,6 @@ class DownloadQueueItem(Base, TimestampMixin):
Text, nullable=True,
doc="Error description"
)
retry_count: Mapped[int] = mapped_column(
Integer, default=0, nullable=False,
doc="Number of retry attempts"
)
# Download details
download_url: Mapped[Optional[str]] = mapped_column(
@ -437,67 +318,9 @@ class DownloadQueueItem(Base, TimestampMixin):
"AnimeSeries",
back_populates="download_items"
)
@validates('season')
def validate_season(self, key: str, value: int) -> int:
"""Validate season number is positive."""
if value < 0:
raise ValueError("Season number must be non-negative")
if value > 1000:
raise ValueError("Season number must be 1000 or less")
return value
@validates('episode_number')
def validate_episode_number(self, key: str, value: int) -> int:
"""Validate episode number is positive."""
if value < 0:
raise ValueError("Episode number must be non-negative")
if value > 10000:
raise ValueError("Episode number must be 10000 or less")
return value
@validates('progress_percent')
def validate_progress_percent(self, key: str, value: float) -> float:
"""Validate progress is between 0 and 100."""
if value < 0.0:
raise ValueError("Progress percent must be non-negative")
if value > 100.0:
raise ValueError("Progress percent cannot exceed 100")
return value
@validates('downloaded_bytes')
def validate_downloaded_bytes(self, key: str, value: int) -> int:
"""Validate downloaded bytes is non-negative."""
if value < 0:
raise ValueError("Downloaded bytes must be non-negative")
return value
@validates('total_bytes')
def validate_total_bytes(
self, key: str, value: Optional[int]
) -> Optional[int]:
"""Validate total bytes is non-negative."""
if value is not None and value < 0:
raise ValueError("Total bytes must be non-negative")
return value
@validates('download_speed')
def validate_download_speed(
self, key: str, value: Optional[float]
) -> Optional[float]:
"""Validate download speed is non-negative."""
if value is not None and value < 0.0:
raise ValueError("Download speed must be non-negative")
return value
@validates('retry_count')
def validate_retry_count(self, key: str, value: int) -> int:
"""Validate retry count is non-negative."""
if value < 0:
raise ValueError("Retry count must be non-negative")
if value > 100:
raise ValueError("Retry count cannot exceed 100")
return value
episode: Mapped["Episode"] = relationship(
"Episode"
)
@validates('download_url')
def validate_download_url(
@ -523,8 +346,7 @@ class DownloadQueueItem(Base, TimestampMixin):
return (
f"<DownloadQueueItem(id={self.id}, "
f"series_id={self.series_id}, "
f"S{self.season:02d}E{self.episode_number:02d}, "
f"status={self.status})>"
f"episode_id={self.episode_id})>"
)

View File

@ -15,7 +15,7 @@ from __future__ import annotations
import logging
from datetime import datetime, timedelta, timezone
from typing import Dict, List, Optional
from typing import List, Optional
from sqlalchemy import delete, select, update
from sqlalchemy.ext.asyncio import AsyncSession
@ -23,9 +23,7 @@ from sqlalchemy.orm import Session, selectinload
from src.server.database.models import (
AnimeSeries,
DownloadPriority,
DownloadQueueItem,
DownloadStatus,
Episode,
UserSession,
)
@ -57,11 +55,6 @@ class AnimeSeriesService:
name: str,
site: str,
folder: str,
description: Optional[str] = None,
status: Optional[str] = None,
total_episodes: Optional[int] = None,
cover_url: Optional[str] = None,
episode_dict: Optional[Dict] = None,
) -> AnimeSeries:
"""Create a new anime series.
@ -71,11 +64,6 @@ class AnimeSeriesService:
name: Series name
site: Provider site URL
folder: Local filesystem path
description: Optional series description
status: Optional series status
total_episodes: Optional total episode count
cover_url: Optional cover image URL
episode_dict: Optional episode dictionary
Returns:
Created AnimeSeries instance
@ -88,11 +76,6 @@ class AnimeSeriesService:
name=name,
site=site,
folder=folder,
description=description,
status=status,
total_episodes=total_episodes,
cover_url=cover_url,
episode_dict=episode_dict,
)
db.add(series)
await db.flush()
@ -262,7 +245,6 @@ class EpisodeService:
episode_number: int,
title: Optional[str] = None,
file_path: Optional[str] = None,
file_size: Optional[int] = None,
is_downloaded: bool = False,
) -> Episode:
"""Create a new episode.
@ -274,7 +256,6 @@ class EpisodeService:
episode_number: Episode number within season
title: Optional episode title
file_path: Optional local file path
file_size: Optional file size in bytes
is_downloaded: Whether episode is downloaded
Returns:
@ -286,9 +267,7 @@ class EpisodeService:
episode_number=episode_number,
title=title,
file_path=file_path,
file_size=file_size,
is_downloaded=is_downloaded,
download_date=datetime.now(timezone.utc) if is_downloaded else None,
)
db.add(episode)
await db.flush()
@ -372,7 +351,6 @@ class EpisodeService:
db: AsyncSession,
episode_id: int,
file_path: str,
file_size: int,
) -> Optional[Episode]:
"""Mark episode as downloaded.
@ -380,7 +358,6 @@ class EpisodeService:
db: Database session
episode_id: Episode primary key
file_path: Local file path
file_size: File size in bytes
Returns:
Updated Episode instance or None if not found
@ -391,8 +368,6 @@ class EpisodeService:
episode.is_downloaded = True
episode.file_path = file_path
episode.file_size = file_size
episode.download_date = datetime.now(timezone.utc)
await db.flush()
await db.refresh(episode)
@ -427,17 +402,14 @@ class EpisodeService:
class DownloadQueueService:
"""Service for download queue CRUD operations.
Provides methods for managing the download queue with status tracking,
priority management, and progress updates.
Provides methods for managing the download queue.
"""
@staticmethod
async def create(
db: AsyncSession,
series_id: int,
season: int,
episode_number: int,
priority: DownloadPriority = DownloadPriority.NORMAL,
episode_id: int,
download_url: Optional[str] = None,
file_destination: Optional[str] = None,
) -> DownloadQueueItem:
@ -446,9 +418,7 @@ class DownloadQueueService:
Args:
db: Database session
series_id: Foreign key to AnimeSeries
season: Season number
episode_number: Episode number
priority: Download priority
episode_id: Foreign key to Episode
download_url: Optional provider download URL
file_destination: Optional target file path
@ -457,10 +427,7 @@ class DownloadQueueService:
"""
item = DownloadQueueItem(
series_id=series_id,
season=season,
episode_number=episode_number,
status=DownloadStatus.PENDING,
priority=priority,
episode_id=episode_id,
download_url=download_url,
file_destination=file_destination,
)
@ -468,8 +435,8 @@ class DownloadQueueService:
await db.flush()
await db.refresh(item)
logger.info(
f"Added to download queue: S{season:02d}E{episode_number:02d} "
f"for series_id={series_id} with priority={priority}"
f"Added to download queue: episode_id={episode_id} "
f"for series_id={series_id}"
)
return item
@ -493,68 +460,25 @@ class DownloadQueueService:
return result.scalar_one_or_none()
@staticmethod
async def get_by_status(
async def get_by_episode(
db: AsyncSession,
status: DownloadStatus,
limit: Optional[int] = None,
) -> List[DownloadQueueItem]:
"""Get download queue items by status.
episode_id: int,
) -> Optional[DownloadQueueItem]:
"""Get download queue item by episode ID.
Args:
db: Database session
status: Download status filter
limit: Optional limit for results
episode_id: Foreign key to Episode
Returns:
List of DownloadQueueItem instances
DownloadQueueItem instance or None if not found
"""
query = select(DownloadQueueItem).where(
DownloadQueueItem.status == status
)
# Order by priority (HIGH first) then creation time
query = query.order_by(
DownloadQueueItem.priority.desc(),
DownloadQueueItem.created_at.asc(),
)
if limit:
query = query.limit(limit)
result = await db.execute(query)
return list(result.scalars().all())
@staticmethod
async def get_pending(
db: AsyncSession,
limit: Optional[int] = None,
) -> List[DownloadQueueItem]:
"""Get pending download queue items.
Args:
db: Database session
limit: Optional limit for results
Returns:
List of pending DownloadQueueItem instances ordered by priority
"""
return await DownloadQueueService.get_by_status(
db, DownloadStatus.PENDING, limit
)
@staticmethod
async def get_active(db: AsyncSession) -> List[DownloadQueueItem]:
"""Get active download queue items.
Args:
db: Database session
Returns:
List of downloading DownloadQueueItem instances
"""
return await DownloadQueueService.get_by_status(
db, DownloadStatus.DOWNLOADING
result = await db.execute(
select(DownloadQueueItem).where(
DownloadQueueItem.episode_id == episode_id
)
)
return result.scalar_one_or_none()
@staticmethod
async def get_all(
@ -576,7 +500,6 @@ class DownloadQueueService:
query = query.options(selectinload(DownloadQueueItem.series))
query = query.order_by(
DownloadQueueItem.priority.desc(),
DownloadQueueItem.created_at.asc(),
)
@ -584,19 +507,17 @@ class DownloadQueueService:
return list(result.scalars().all())
@staticmethod
async def update_status(
async def set_error(
db: AsyncSession,
item_id: int,
status: DownloadStatus,
error_message: Optional[str] = None,
error_message: str,
) -> Optional[DownloadQueueItem]:
"""Update download queue item status.
"""Set error message on download queue item.
Args:
db: Database session
item_id: Item primary key
status: New download status
error_message: Optional error message for failed status
error_message: Error description
Returns:
Updated DownloadQueueItem instance or None if not found
@ -605,61 +526,11 @@ class DownloadQueueService:
if not item:
return None
item.status = status
# Update timestamps based on status
if status == DownloadStatus.DOWNLOADING and not item.started_at:
item.started_at = datetime.now(timezone.utc)
elif status in (DownloadStatus.COMPLETED, DownloadStatus.FAILED):
item.completed_at = datetime.now(timezone.utc)
# Set error message for failed downloads
if status == DownloadStatus.FAILED and error_message:
item.error_message = error_message
item.retry_count += 1
await db.flush()
await db.refresh(item)
logger.debug(f"Updated download queue item {item_id} status to {status}")
return item
@staticmethod
async def update_progress(
db: AsyncSession,
item_id: int,
progress_percent: float,
downloaded_bytes: int,
total_bytes: Optional[int] = None,
download_speed: Optional[float] = None,
) -> Optional[DownloadQueueItem]:
"""Update download progress.
Args:
db: Database session
item_id: Item primary key
progress_percent: Progress percentage (0-100)
downloaded_bytes: Bytes downloaded
total_bytes: Optional total file size
download_speed: Optional current speed (bytes/sec)
Returns:
Updated DownloadQueueItem instance or None if not found
"""
item = await DownloadQueueService.get_by_id(db, item_id)
if not item:
return None
item.progress_percent = progress_percent
item.downloaded_bytes = downloaded_bytes
if total_bytes is not None:
item.total_bytes = total_bytes
if download_speed is not None:
item.download_speed = download_speed
item.error_message = error_message
await db.flush()
await db.refresh(item)
logger.debug(f"Set error on download queue item {item_id}")
return item
@staticmethod
@ -682,57 +553,30 @@ class DownloadQueueService:
return deleted
@staticmethod
async def clear_completed(db: AsyncSession) -> int:
"""Clear completed downloads from queue.
async def delete_by_episode(
db: AsyncSession,
episode_id: int,
) -> bool:
"""Delete download queue item by episode ID.
Args:
db: Database session
episode_id: Foreign key to Episode
Returns:
Number of items cleared
True if deleted, False if not found
"""
result = await db.execute(
delete(DownloadQueueItem).where(
DownloadQueueItem.status == DownloadStatus.COMPLETED
DownloadQueueItem.episode_id == episode_id
)
)
count = result.rowcount
logger.info(f"Cleared {count} completed downloads from queue")
return count
@staticmethod
async def retry_failed(
db: AsyncSession,
max_retries: int = 3,
) -> List[DownloadQueueItem]:
"""Retry failed downloads that haven't exceeded max retries.
Args:
db: Database session
max_retries: Maximum number of retry attempts
Returns:
List of items marked for retry
"""
result = await db.execute(
select(DownloadQueueItem).where(
DownloadQueueItem.status == DownloadStatus.FAILED,
DownloadQueueItem.retry_count < max_retries,
deleted = result.rowcount > 0
if deleted:
logger.info(
f"Deleted download queue item with episode_id={episode_id}"
)
)
items = list(result.scalars().all())
for item in items:
item.status = DownloadStatus.PENDING
item.error_message = None
item.progress_percent = 0.0
item.downloaded_bytes = 0
item.started_at = None
item.completed_at = None
await db.flush()
logger.info(f"Marked {len(items)} failed downloads for retry")
return items
return deleted
# ============================================================================

View File

@ -70,8 +70,6 @@ class AnimeSeriesResponse(BaseModel):
)
)
alt_titles: List[str] = Field(default_factory=list, description="Alternative titles")
description: Optional[str] = Field(None, description="Short series description")
total_episodes: Optional[int] = Field(None, ge=0, description="Declared total episode count if known")
episodes: List[EpisodeInfo] = Field(default_factory=list, description="Known episodes information")
missing_episodes: List[MissingEpisodeInfo] = Field(default_factory=list, description="Detected missing episode ranges")
thumbnail: Optional[HttpUrl] = Field(None, description="Optional thumbnail image URL")

View File

@ -22,7 +22,7 @@ from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from src.core.entities.series import Serie
from src.server.database.service import AnimeSeriesService
from src.server.database.service import AnimeSeriesService, EpisodeService
logger = logging.getLogger(__name__)
@ -206,7 +206,7 @@ class DataMigrationService:
Reads the data file, checks if the series already exists in the
database, and creates a new record if it doesn't exist. If the
series exists, optionally updates the episode_dict if changed.
series exists, optionally updates the episodes if changed.
Args:
data_path: Path to the data file
@ -229,41 +229,44 @@ class DataMigrationService:
existing = await AnimeSeriesService.get_by_key(db, serie.key)
if existing is not None:
# Check if episode_dict has changed
existing_dict = existing.episode_dict or {}
# Build episode dict from existing episodes for comparison
existing_dict: dict[int, list[int]] = {}
episodes = await EpisodeService.get_by_series(db, existing.id)
for ep in episodes:
if ep.season not in existing_dict:
existing_dict[ep.season] = []
existing_dict[ep.season].append(ep.episode_number)
for season in existing_dict:
existing_dict[season].sort()
new_dict = serie.episodeDict or {}
# Convert keys to strings for comparison (JSON stores keys as strings)
new_dict_str_keys = {
str(k): v for k, v in new_dict.items()
}
if existing_dict == new_dict_str_keys:
if existing_dict == new_dict:
logger.debug(
"Series '%s' already exists with same data, skipping",
serie.key
)
return False
# Update episode_dict if different
await AnimeSeriesService.update(
db,
existing.id,
episode_dict=new_dict_str_keys
)
# Update episodes if different - add new episodes
for season, episode_numbers in new_dict.items():
existing_eps = set(existing_dict.get(season, []))
for ep_num in episode_numbers:
if ep_num not in existing_eps:
await EpisodeService.create(
db=db,
series_id=existing.id,
season=season,
episode_number=ep_num,
)
logger.info(
"Updated episode_dict for existing series '%s'",
"Updated episodes for existing series '%s'",
serie.key
)
return True
# Create new series in database
try:
# Convert episode_dict keys to strings for JSON storage
episode_dict_for_db = {
str(k): v for k, v in (serie.episodeDict or {}).items()
}
# Use folder as fallback name if name is empty
series_name = serie.name
if not series_name or not series_name.strip():
@ -274,14 +277,25 @@ class DataMigrationService:
serie.key
)
await AnimeSeriesService.create(
anime_series = await AnimeSeriesService.create(
db,
key=serie.key,
name=series_name,
site=serie.site,
folder=serie.folder,
episode_dict=episode_dict_for_db,
)
# Create Episode records for each episode in episodeDict
if serie.episodeDict:
for season, episode_numbers in serie.episodeDict.items():
for episode_number in episode_numbers:
await EpisodeService.create(
db=db,
series_id=anime_series.id,
season=season,
episode_number=episode_number,
)
logger.info(
"Migrated series '%s' to database",
serie.key

View File

@ -153,29 +153,40 @@ class TestMigrationIdempotency:
}
(series_dir / "data").write_text(json.dumps(data))
# Mock existing series in database
# Mock existing series in database with same episodes
existing = MagicMock()
existing.id = 1
existing.episode_dict = {"1": [1, 2]} # Same data
# Mock episodes matching data file
mock_episodes = [
MagicMock(season=1, episode_number=1),
MagicMock(season=1, episode_number=2),
]
service = DataMigrationService()
with patch(
'src.server.services.data_migration_service.AnimeSeriesService'
) as MockService:
MockService.get_by_key = AsyncMock(return_value=existing)
mock_db = AsyncMock()
mock_db.commit = AsyncMock()
result = await service.migrate_all(tmp_dir, mock_db)
# Should skip since data is same
assert result.total_found == 1
assert result.skipped == 1
assert result.migrated == 0
# Should not call create
MockService.create.assert_not_called()
with patch(
'src.server.services.data_migration_service.EpisodeService'
) as MockEpisodeService:
MockService.get_by_key = AsyncMock(return_value=existing)
MockEpisodeService.get_by_series = AsyncMock(
return_value=mock_episodes
)
mock_db = AsyncMock()
mock_db.commit = AsyncMock()
result = await service.migrate_all(tmp_dir, mock_db)
# Should skip since data is same
assert result.total_found == 1
assert result.skipped == 1
assert result.migrated == 0
# Should not call create
MockService.create.assert_not_called()
@pytest.mark.asyncio
async def test_migration_updates_changed_episodes(self):
@ -196,25 +207,37 @@ class TestMigrationIdempotency:
# Mock existing series with fewer episodes
existing = MagicMock()
existing.id = 1
existing.episode_dict = {"1": [1, 2]} # Fewer episodes
# Mock existing episodes (fewer than data file)
mock_episodes = [
MagicMock(season=1, episode_number=1),
MagicMock(season=1, episode_number=2),
]
service = DataMigrationService()
with patch(
'src.server.services.data_migration_service.AnimeSeriesService'
) as MockService:
MockService.get_by_key = AsyncMock(return_value=existing)
MockService.update = AsyncMock()
mock_db = AsyncMock()
mock_db.commit = AsyncMock()
result = await service.migrate_all(tmp_dir, mock_db)
# Should update since data changed
assert result.total_found == 1
assert result.migrated == 1
MockService.update.assert_called_once()
with patch(
'src.server.services.data_migration_service.EpisodeService'
) as MockEpisodeService:
MockService.get_by_key = AsyncMock(return_value=existing)
MockEpisodeService.get_by_series = AsyncMock(
return_value=mock_episodes
)
MockEpisodeService.create = AsyncMock()
mock_db = AsyncMock()
mock_db.commit = AsyncMock()
result = await service.migrate_all(tmp_dir, mock_db)
# Should update since data changed
assert result.total_found == 1
assert result.migrated == 1
# Should create 3 new episodes (3, 4, 5)
assert MockEpisodeService.create.call_count == 3
class TestMigrationOnFreshStart:
@ -348,13 +371,18 @@ class TestSerieListReadsFromDatabase:
# Create mock series in database with spec to avoid mock attributes
from dataclasses import dataclass
@dataclass
class MockEpisode:
season: int
episode_number: int
@dataclass
class MockAnimeSeries:
key: str
name: str
site: str
folder: str
episode_dict: dict
episodes: list
mock_series = [
MockAnimeSeries(
@ -362,14 +390,18 @@ class TestSerieListReadsFromDatabase:
name="Anime 1",
site="aniworld.to",
folder="Anime 1",
episode_dict={"1": [1, 2, 3]}
episodes=[
MockEpisode(1, 1), MockEpisode(1, 2), MockEpisode(1, 3)
]
),
MockAnimeSeries(
key="anime-2",
name="Anime 2",
site="aniworld.to",
folder="Anime 2",
episode_dict={"1": [1, 2], "2": [1]}
episodes=[
MockEpisode(1, 1), MockEpisode(1, 2), MockEpisode(2, 1)
]
)
]
@ -389,8 +421,8 @@ class TestSerieListReadsFromDatabase:
# Load from database
await serie_list.load_series_from_db(mock_db)
# Verify service was called
mock_get_all.assert_called_once_with(mock_db)
# Verify service was called with with_episodes=True
mock_get_all.assert_called_once_with(mock_db, with_episodes=True)
# Verify series were loaded
all_series = serie_list.get_all()

View File

@ -65,7 +65,6 @@ class TestAnimeSeriesResponse:
title="Attack on Titan",
folder="Attack on Titan (2013)",
episodes=[ep],
total_episodes=12,
)
assert series.key == "attack-on-titan"

View File

@ -304,10 +304,18 @@ class TestDataMigrationServiceMigrateSingle:
"""Test migrating series that already exists with same data."""
service = DataMigrationService()
# Create mock existing series with same episode_dict
# Create mock existing series with same episodes
existing = MagicMock()
existing.id = 1
existing.episode_dict = {"1": [1, 2, 3], "2": [1, 2]}
# Mock episodes matching sample_serie.episodeDict = {1: [1, 2, 3], 2: [1, 2]}
mock_episodes = []
for season, eps in {1: [1, 2, 3], 2: [1, 2]}.items():
for ep_num in eps:
mock_ep = MagicMock()
mock_ep.season = season
mock_ep.episode_number = ep_num
mock_episodes.append(mock_ep)
with patch.object(
service,
@ -317,19 +325,25 @@ class TestDataMigrationServiceMigrateSingle:
with patch(
'src.server.services.data_migration_service.AnimeSeriesService'
) as MockService:
MockService.get_by_key = AsyncMock(return_value=existing)
result = await service.migrate_data_file(
Path("/fake/data"),
mock_db
)
assert result is False
MockService.create.assert_not_called()
with patch(
'src.server.services.data_migration_service.EpisodeService'
) as MockEpisodeService:
MockService.get_by_key = AsyncMock(return_value=existing)
MockEpisodeService.get_by_series = AsyncMock(
return_value=mock_episodes
)
result = await service.migrate_data_file(
Path("/fake/data"),
mock_db
)
assert result is False
MockService.create.assert_not_called()
@pytest.mark.asyncio
async def test_migrate_existing_series_different_data(self, mock_db):
"""Test migrating series that exists with different episode_dict."""
"""Test migrating series that exists with different episodes."""
service = DataMigrationService()
# Serie with new episodes
@ -344,7 +358,14 @@ class TestDataMigrationServiceMigrateSingle:
# Existing series has fewer episodes
existing = MagicMock()
existing.id = 1
existing.episode_dict = {"1": [1, 2, 3]}
# Mock episodes for existing (only 3 episodes)
mock_episodes = []
for ep_num in [1, 2, 3]:
mock_ep = MagicMock()
mock_ep.season = 1
mock_ep.episode_number = ep_num
mock_episodes.append(mock_ep)
with patch.object(
service,
@ -354,16 +375,23 @@ class TestDataMigrationServiceMigrateSingle:
with patch(
'src.server.services.data_migration_service.AnimeSeriesService'
) as MockService:
MockService.get_by_key = AsyncMock(return_value=existing)
MockService.update = AsyncMock()
result = await service.migrate_data_file(
Path("/fake/data"),
mock_db
)
assert result is True
MockService.update.assert_called_once()
with patch(
'src.server.services.data_migration_service.EpisodeService'
) as MockEpisodeService:
MockService.get_by_key = AsyncMock(return_value=existing)
MockEpisodeService.get_by_series = AsyncMock(
return_value=mock_episodes
)
MockEpisodeService.create = AsyncMock()
result = await service.migrate_data_file(
Path("/fake/data"),
mock_db
)
assert result is True
# Should create 2 new episodes (4 and 5)
assert MockEpisodeService.create.call_count == 2
@pytest.mark.asyncio
async def test_migrate_read_error(self, mock_db):
@ -493,21 +521,26 @@ class TestDataMigrationServiceMigrateAll:
# Mock: first series doesn't exist, second already exists
existing = MagicMock()
existing.id = 2
existing.episode_dict = {}
with patch(
'src.server.services.data_migration_service.AnimeSeriesService'
) as MockService:
MockService.get_by_key = AsyncMock(
side_effect=[None, existing]
)
MockService.create = AsyncMock()
result = await service.migrate_all(tmp_dir, mock_db)
assert result.total_found == 2
assert result.migrated == 1
assert result.skipped == 1
with patch(
'src.server.services.data_migration_service.EpisodeService'
) as MockEpisodeService:
MockService.get_by_key = AsyncMock(
side_effect=[None, existing]
)
MockService.create = AsyncMock(
return_value=MagicMock(id=1)
)
MockEpisodeService.get_by_series = AsyncMock(return_value=[])
result = await service.migrate_all(tmp_dir, mock_db)
assert result.total_found == 2
assert result.migrated == 1
assert result.skipped == 1
class TestDataMigrationServiceIsMigrationNeeded:

View File

@ -14,9 +14,7 @@ from sqlalchemy.orm import Session, sessionmaker
from src.server.database.base import Base, SoftDeleteMixin, TimestampMixin
from src.server.database.models import (
AnimeSeries,
DownloadPriority,
DownloadQueueItem,
DownloadStatus,
Episode,
UserSession,
)
@ -49,11 +47,6 @@ class TestAnimeSeries:
name="Attack on Titan",
site="https://aniworld.to",
folder="/anime/attack-on-titan",
description="Epic anime about titans",
status="completed",
total_episodes=75,
cover_url="https://example.com/cover.jpg",
episode_dict={1: [1, 2, 3], 2: [1, 2, 3, 4]},
)
db_session.add(series)
@ -172,9 +165,7 @@ class TestEpisode:
episode_number=5,
title="The Fifth Episode",
file_path="/anime/test/S01E05.mp4",
file_size=524288000, # 500 MB
is_downloaded=True,
download_date=datetime.now(timezone.utc),
)
db_session.add(episode)
@ -225,17 +216,17 @@ class TestDownloadQueueItem:
db_session.add(series)
db_session.commit()
item = DownloadQueueItem(
episode = Episode(
series_id=series.id,
season=1,
episode_number=3,
status=DownloadStatus.DOWNLOADING,
priority=DownloadPriority.HIGH,
progress_percent=45.5,
downloaded_bytes=250000000,
total_bytes=550000000,
download_speed=2500000.0,
retry_count=0,
)
db_session.add(episode)
db_session.commit()
item = DownloadQueueItem(
series_id=series.id,
episode_id=episode.id,
download_url="https://example.com/download/ep3",
file_destination="/anime/download/S01E03.mp4",
)
@ -245,37 +236,38 @@ class TestDownloadQueueItem:
# Verify saved
assert item.id is not None
assert item.status == DownloadStatus.DOWNLOADING
assert item.priority == DownloadPriority.HIGH
assert item.progress_percent == 45.5
assert item.retry_count == 0
assert item.episode_id == episode.id
assert item.series_id == series.id
def test_download_item_status_enum(self, db_session: Session):
"""Test download status enum values."""
def test_download_item_episode_relationship(self, db_session: Session):
"""Test download item episode relationship."""
series = AnimeSeries(
key="status-test",
name="Status Test",
key="relationship-test",
name="Relationship Test",
site="https://example.com",
folder="/anime/status",
folder="/anime/relationship",
)
db_session.add(series)
db_session.commit()
item = DownloadQueueItem(
episode = Episode(
series_id=series.id,
season=1,
episode_number=1,
status=DownloadStatus.PENDING,
)
db_session.add(episode)
db_session.commit()
item = DownloadQueueItem(
series_id=series.id,
episode_id=episode.id,
)
db_session.add(item)
db_session.commit()
# Update status
item.status = DownloadStatus.COMPLETED
db_session.commit()
# Verify status change
assert item.status == DownloadStatus.COMPLETED
# Verify relationship
assert item.episode.id == episode.id
assert item.series.id == series.id
def test_download_item_error_handling(self, db_session: Session):
"""Test download item with error information."""
@ -288,21 +280,24 @@ class TestDownloadQueueItem:
db_session.add(series)
db_session.commit()
item = DownloadQueueItem(
episode = Episode(
series_id=series.id,
season=1,
episode_number=1,
status=DownloadStatus.FAILED,
)
db_session.add(episode)
db_session.commit()
item = DownloadQueueItem(
series_id=series.id,
episode_id=episode.id,
error_message="Network timeout after 30 seconds",
retry_count=2,
)
db_session.add(item)
db_session.commit()
# Verify error info
assert item.status == DownloadStatus.FAILED
assert item.error_message == "Network timeout after 30 seconds"
assert item.retry_count == 2
class TestUserSession:
@ -502,32 +497,31 @@ class TestDatabaseQueries:
db_session.add(series)
db_session.commit()
# Create items with different statuses
for i, status in enumerate([
DownloadStatus.PENDING,
DownloadStatus.DOWNLOADING,
DownloadStatus.COMPLETED,
]):
item = DownloadQueueItem(
# Create episodes and items
for i in range(3):
episode = Episode(
series_id=series.id,
season=1,
episode_number=i + 1,
status=status,
)
db_session.add(episode)
db_session.commit()
item = DownloadQueueItem(
series_id=series.id,
episode_id=episode.id,
)
db_session.add(item)
db_session.commit()
# Query pending items
# Query all items
result = db_session.execute(
select(DownloadQueueItem).where(
DownloadQueueItem.status == DownloadStatus.PENDING
)
select(DownloadQueueItem)
)
pending = result.scalars().all()
items = result.scalars().all()
# Verify query
assert len(pending) == 1
assert pending[0].episode_number == 1
assert len(items) == 3
def test_query_active_sessions(self, db_session: Session):
"""Test querying active user sessions."""

View File

@ -10,7 +10,6 @@ from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker
from src.server.database.base import Base
from src.server.database.models import DownloadPriority, DownloadStatus
from src.server.database.service import (
AnimeSeriesService,
DownloadQueueService,
@ -65,17 +64,11 @@ async def test_create_anime_series(db_session):
name="Test Anime",
site="https://example.com",
folder="/path/to/anime",
description="A test anime",
status="ongoing",
total_episodes=12,
cover_url="https://example.com/cover.jpg",
)
assert series.id is not None
assert series.key == "test-anime-1"
assert series.name == "Test Anime"
assert series.description == "A test anime"
assert series.total_episodes == 12
@pytest.mark.asyncio
@ -160,13 +153,11 @@ async def test_update_anime_series(db_session):
db_session,
series.id,
name="Updated Name",
total_episodes=24,
)
await db_session.commit()
assert updated is not None
assert updated.name == "Updated Name"
assert updated.total_episodes == 24
@pytest.mark.asyncio
@ -308,14 +299,12 @@ async def test_mark_episode_downloaded(db_session):
db_session,
episode.id,
file_path="/path/to/file.mp4",
file_size=1024000,
)
await db_session.commit()
assert updated is not None
assert updated.is_downloaded is True
assert updated.file_path == "/path/to/file.mp4"
assert updated.download_date is not None
# ============================================================================
@ -336,23 +325,30 @@ async def test_create_download_queue_item(db_session):
)
await db_session.commit()
# Add to queue
item = await DownloadQueueService.create(
# Create episode
episode = await EpisodeService.create(
db_session,
series_id=series.id,
season=1,
episode_number=1,
priority=DownloadPriority.HIGH,
)
await db_session.commit()
# Add to queue
item = await DownloadQueueService.create(
db_session,
series_id=series.id,
episode_id=episode.id,
)
assert item.id is not None
assert item.status == DownloadStatus.PENDING
assert item.priority == DownloadPriority.HIGH
assert item.episode_id == episode.id
assert item.series_id == series.id
@pytest.mark.asyncio
async def test_get_pending_downloads(db_session):
"""Test retrieving pending downloads."""
async def test_get_download_queue_item_by_episode(db_session):
"""Test retrieving download queue item by episode."""
# Create series
series = await AnimeSeriesService.create(
db_session,
@ -362,29 +358,32 @@ async def test_get_pending_downloads(db_session):
folder="/path/test5",
)
# Add pending items
await DownloadQueueService.create(
# Create episode
episode = await EpisodeService.create(
db_session,
series_id=series.id,
season=1,
episode_number=1,
)
await db_session.commit()
# Add to queue
await DownloadQueueService.create(
db_session,
series_id=series.id,
season=1,
episode_number=2,
episode_id=episode.id,
)
await db_session.commit()
# Retrieve pending
pending = await DownloadQueueService.get_pending(db_session)
assert len(pending) == 2
# Retrieve by episode
item = await DownloadQueueService.get_by_episode(db_session, episode.id)
assert item is not None
assert item.episode_id == episode.id
@pytest.mark.asyncio
async def test_update_download_status(db_session):
"""Test updating download status."""
async def test_set_download_error(db_session):
"""Test setting error on download queue item."""
# Create series and queue item
series = await AnimeSeriesService.create(
db_session,
@ -393,30 +392,34 @@ async def test_update_download_status(db_session):
site="https://example.com",
folder="/path/test6",
)
item = await DownloadQueueService.create(
episode = await EpisodeService.create(
db_session,
series_id=series.id,
season=1,
episode_number=1,
)
item = await DownloadQueueService.create(
db_session,
series_id=series.id,
episode_id=episode.id,
)
await db_session.commit()
# Update status
updated = await DownloadQueueService.update_status(
# Set error
updated = await DownloadQueueService.set_error(
db_session,
item.id,
DownloadStatus.DOWNLOADING,
"Network error",
)
await db_session.commit()
assert updated is not None
assert updated.status == DownloadStatus.DOWNLOADING
assert updated.started_at is not None
assert updated.error_message == "Network error"
@pytest.mark.asyncio
async def test_update_download_progress(db_session):
"""Test updating download progress."""
async def test_delete_download_queue_item_by_episode(db_session):
"""Test deleting download queue item by episode."""
# Create series and queue item
series = await AnimeSeriesService.create(
db_session,
@ -425,109 +428,31 @@ async def test_update_download_progress(db_session):
site="https://example.com",
folder="/path/test7",
)
item = await DownloadQueueService.create(
episode = await EpisodeService.create(
db_session,
series_id=series.id,
season=1,
episode_number=1,
)
await db_session.commit()
# Update progress
updated = await DownloadQueueService.update_progress(
db_session,
item.id,
progress_percent=50.0,
downloaded_bytes=500000,
total_bytes=1000000,
download_speed=50000.0,
)
await db_session.commit()
assert updated is not None
assert updated.progress_percent == 50.0
assert updated.downloaded_bytes == 500000
assert updated.total_bytes == 1000000
@pytest.mark.asyncio
async def test_clear_completed_downloads(db_session):
"""Test clearing completed downloads."""
# Create series and completed items
series = await AnimeSeriesService.create(
db_session,
key="test-series-8",
name="Test Series 8",
site="https://example.com",
folder="/path/test8",
)
item1 = await DownloadQueueService.create(
await DownloadQueueService.create(
db_session,
series_id=series.id,
season=1,
episode_number=1,
)
item2 = await DownloadQueueService.create(
db_session,
series_id=series.id,
season=1,
episode_number=2,
)
# Mark items as completed
await DownloadQueueService.update_status(
db_session,
item1.id,
DownloadStatus.COMPLETED,
)
await DownloadQueueService.update_status(
db_session,
item2.id,
DownloadStatus.COMPLETED,
episode_id=episode.id,
)
await db_session.commit()
# Clear completed
count = await DownloadQueueService.clear_completed(db_session)
await db_session.commit()
assert count == 2
@pytest.mark.asyncio
async def test_retry_failed_downloads(db_session):
"""Test retrying failed downloads."""
# Create series and failed item
series = await AnimeSeriesService.create(
# Delete by episode
deleted = await DownloadQueueService.delete_by_episode(
db_session,
key="test-series-9",
name="Test Series 9",
site="https://example.com",
folder="/path/test9",
)
item = await DownloadQueueService.create(
db_session,
series_id=series.id,
season=1,
episode_number=1,
)
# Mark as failed
await DownloadQueueService.update_status(
db_session,
item.id,
DownloadStatus.FAILED,
error_message="Network error",
episode.id,
)
await db_session.commit()
# Retry
retried = await DownloadQueueService.retry_failed(db_session)
await db_session.commit()
assert deleted is True
assert len(retried) == 1
assert retried[0].status == DownloadStatus.PENDING
assert retried[0].error_message is None
# Verify deleted
item = await DownloadQueueService.get_by_episode(db_session, episode.id)
assert item is None
# ============================================================================

View File

@ -45,7 +45,23 @@ def mock_anime_series():
anime_series.name = "Test Series"
anime_series.site = "https://aniworld.to/anime/stream/test-series"
anime_series.folder = "Test Series (2020)"
anime_series.episode_dict = {"1": [1, 2, 3], "2": [1, 2]}
# Mock episodes relationship
mock_ep1 = MagicMock()
mock_ep1.season = 1
mock_ep1.episode_number = 1
mock_ep2 = MagicMock()
mock_ep2.season = 1
mock_ep2.episode_number = 2
mock_ep3 = MagicMock()
mock_ep3.season = 1
mock_ep3.episode_number = 3
mock_ep4 = MagicMock()
mock_ep4.season = 2
mock_ep4.episode_number = 1
mock_ep5 = MagicMock()
mock_ep5.season = 2
mock_ep5.episode_number = 2
anime_series.episodes = [mock_ep1, mock_ep2, mock_ep3, mock_ep4, mock_ep5]
return anime_series
@ -288,37 +304,27 @@ class TestSerieListDatabaseMode:
assert serie.name == mock_anime_series.name
assert serie.site == mock_anime_series.site
assert serie.folder == mock_anime_series.folder
# Season keys should be converted from string to int
# Season keys should be built from episodes relationship
assert 1 in serie.episodeDict
assert 2 in serie.episodeDict
assert serie.episodeDict[1] == [1, 2, 3]
assert serie.episodeDict[2] == [1, 2]
def test_convert_from_db_empty_episode_dict(self, mock_anime_series):
"""Test _convert_from_db handles empty episode_dict."""
mock_anime_series.episode_dict = None
def test_convert_from_db_empty_episodes(self, mock_anime_series):
"""Test _convert_from_db handles empty episodes."""
mock_anime_series.episodes = []
serie = SerieList._convert_from_db(mock_anime_series)
assert serie.episodeDict == {}
def test_convert_from_db_handles_invalid_season_keys(
self, mock_anime_series
):
"""Test _convert_from_db handles invalid season keys gracefully."""
mock_anime_series.episode_dict = {
"1": [1, 2],
"invalid": [3, 4], # Invalid key - not an integer
"2": [5, 6]
}
def test_convert_from_db_none_episodes(self, mock_anime_series):
"""Test _convert_from_db handles None episodes."""
mock_anime_series.episodes = None
serie = SerieList._convert_from_db(mock_anime_series)
# Valid keys should be converted
assert 1 in serie.episodeDict
assert 2 in serie.episodeDict
# Invalid key should be skipped
assert "invalid" not in serie.episodeDict
assert serie.episodeDict == {}
def test_convert_to_db_dict(self, sample_serie):
"""Test _convert_to_db_dict creates correct dictionary."""
@ -328,9 +334,8 @@ class TestSerieListDatabaseMode:
assert result["name"] == sample_serie.name
assert result["site"] == sample_serie.site
assert result["folder"] == sample_serie.folder
# Keys should be converted to strings for JSON
assert "1" in result["episode_dict"]
assert result["episode_dict"]["1"] == [1, 2, 3]
# episode_dict should not be in result anymore
assert "episode_dict" not in result
def test_convert_to_db_dict_empty_episode_dict(self):
"""Test _convert_to_db_dict handles empty episode_dict."""
@ -344,7 +349,8 @@ class TestSerieListDatabaseMode:
result = SerieList._convert_to_db_dict(serie)
assert result["episode_dict"] is None
# episode_dict should not be in result anymore
assert "episode_dict" not in result
class TestSerieListDatabaseAsync:

View File

@ -174,10 +174,16 @@ class TestSerieScannerAsyncScan:
"""Test scan_async updates existing series in database."""
scanner = SerieScanner(temp_directory, mock_loader)
# Mock existing series in database
# Mock existing series in database with different episodes
existing = MagicMock()
existing.id = 1
existing.episode_dict = {1: [5, 6]} # Different from sample_serie
existing.folder = sample_serie.folder
# Mock episodes (different from sample_serie)
mock_existing_episodes = [
MagicMock(season=1, episode_number=5),
MagicMock(season=1, episode_number=6),
]
with patch.object(scanner, 'get_total_to_scan', return_value=1):
with patch.object(
@ -200,17 +206,24 @@ class TestSerieScannerAsyncScan:
with patch(
'src.server.database.service.AnimeSeriesService'
) as mock_service:
mock_service.get_by_key = AsyncMock(
return_value=existing
)
mock_service.update = AsyncMock(
return_value=existing
)
await scanner.scan_async(mock_db_session)
# Verify database update was called
mock_service.update.assert_called_once()
with patch(
'src.server.database.service.EpisodeService'
) as mock_ep_service:
mock_service.get_by_key = AsyncMock(
return_value=existing
)
mock_service.update = AsyncMock(
return_value=existing
)
mock_ep_service.get_by_series = AsyncMock(
return_value=mock_existing_episodes
)
mock_ep_service.create = AsyncMock()
await scanner.scan_async(mock_db_session)
# Verify episodes were created
assert mock_ep_service.create.called
@pytest.mark.asyncio
async def test_scan_async_handles_errors_gracefully(
@ -249,17 +262,21 @@ class TestSerieScannerDatabaseHelpers:
with patch(
'src.server.database.service.AnimeSeriesService'
) as mock_service:
mock_service.get_by_key = AsyncMock(return_value=None)
mock_created = MagicMock()
mock_created.id = 1
mock_service.create = AsyncMock(return_value=mock_created)
result = await scanner._save_serie_to_db(
sample_serie, mock_db_session
)
assert result is mock_created
mock_service.create.assert_called_once()
with patch(
'src.server.database.service.EpisodeService'
) as mock_ep_service:
mock_service.get_by_key = AsyncMock(return_value=None)
mock_created = MagicMock()
mock_created.id = 1
mock_service.create = AsyncMock(return_value=mock_created)
mock_ep_service.create = AsyncMock()
result = await scanner._save_serie_to_db(
sample_serie, mock_db_session
)
assert result is mock_created
mock_service.create.assert_called_once()
@pytest.mark.asyncio
async def test_save_serie_to_db_updates_existing(
@ -270,20 +287,34 @@ class TestSerieScannerDatabaseHelpers:
existing = MagicMock()
existing.id = 1
existing.episode_dict = {1: [5, 6]} # Different episodes
existing.folder = sample_serie.folder
# Mock existing episodes (different from sample_serie)
mock_existing_episodes = [
MagicMock(season=1, episode_number=5),
MagicMock(season=1, episode_number=6),
]
with patch(
'src.server.database.service.AnimeSeriesService'
) as mock_service:
mock_service.get_by_key = AsyncMock(return_value=existing)
mock_service.update = AsyncMock(return_value=existing)
result = await scanner._save_serie_to_db(
sample_serie, mock_db_session
)
assert result is existing
mock_service.update.assert_called_once()
with patch(
'src.server.database.service.EpisodeService'
) as mock_ep_service:
mock_service.get_by_key = AsyncMock(return_value=existing)
mock_service.update = AsyncMock(return_value=existing)
mock_ep_service.get_by_series = AsyncMock(
return_value=mock_existing_episodes
)
mock_ep_service.create = AsyncMock()
result = await scanner._save_serie_to_db(
sample_serie, mock_db_session
)
assert result is existing
# Should have created new episodes
assert mock_ep_service.create.called
@pytest.mark.asyncio
async def test_save_serie_to_db_skips_unchanged(
@ -294,19 +325,33 @@ class TestSerieScannerDatabaseHelpers:
existing = MagicMock()
existing.id = 1
existing.episode_dict = sample_serie.episodeDict # Same episodes
existing.folder = sample_serie.folder
# Mock episodes matching sample_serie.episodeDict
mock_existing_episodes = []
for season, ep_nums in sample_serie.episodeDict.items():
for ep_num in ep_nums:
mock_existing_episodes.append(
MagicMock(season=season, episode_number=ep_num)
)
with patch(
'src.server.database.service.AnimeSeriesService'
) as mock_service:
mock_service.get_by_key = AsyncMock(return_value=existing)
result = await scanner._save_serie_to_db(
sample_serie, mock_db_session
)
assert result is None
mock_service.update.assert_not_called()
with patch(
'src.server.database.service.EpisodeService'
) as mock_ep_service:
mock_service.get_by_key = AsyncMock(return_value=existing)
mock_ep_service.get_by_series = AsyncMock(
return_value=mock_existing_episodes
)
result = await scanner._save_serie_to_db(
sample_serie, mock_db_session
)
assert result is None
mock_service.update.assert_not_called()
@pytest.mark.asyncio
async def test_update_serie_in_db_updates_existing(
@ -321,15 +366,20 @@ class TestSerieScannerDatabaseHelpers:
with patch(
'src.server.database.service.AnimeSeriesService'
) as mock_service:
mock_service.get_by_key = AsyncMock(return_value=existing)
mock_service.update = AsyncMock(return_value=existing)
result = await scanner._update_serie_in_db(
sample_serie, mock_db_session
)
assert result is existing
mock_service.update.assert_called_once()
with patch(
'src.server.database.service.EpisodeService'
) as mock_ep_service:
mock_service.get_by_key = AsyncMock(return_value=existing)
mock_service.update = AsyncMock(return_value=existing)
mock_ep_service.get_by_series = AsyncMock(return_value=[])
mock_ep_service.create = AsyncMock()
result = await scanner._update_serie_in_db(
sample_serie, mock_db_session
)
assert result is existing
mock_service.update.assert_called_once()
@pytest.mark.asyncio
async def test_update_serie_in_db_returns_none_if_not_found(