better db model
This commit is contained in:
parent
942f14f746
commit
798461a1ea
880
instructions.md
880
instructions.md
@ -120,883 +120,3 @@ For each task completed:
|
||||
- Good foundation for future enhancements if needed
|
||||
|
||||
---
|
||||
|
||||
## ✅ Completed: Download Queue Migration to SQLite Database
|
||||
|
||||
The download queue has been successfully migrated from JSON file to SQLite database:
|
||||
|
||||
| Component | Status | Description |
|
||||
| --------------------- | ------- | ------------------------------------------------- |
|
||||
| QueueRepository | ✅ Done | `src/server/services/queue_repository.py` |
|
||||
| DownloadService | ✅ Done | Refactored to use repository pattern |
|
||||
| Application Startup | ✅ Done | Queue restored from database on startup |
|
||||
| API Endpoints | ✅ Done | All endpoints work with database-backed queue |
|
||||
| Tests Updated | ✅ Done | All 1104 tests passing with MockQueueRepository |
|
||||
| Documentation Updated | ✅ Done | `infrastructure.md` updated with new architecture |
|
||||
|
||||
**Key Changes:**
|
||||
|
||||
- `DownloadService` no longer uses `persistence_path` parameter
|
||||
- Queue state is persisted to SQLite via `QueueRepository`
|
||||
- In-memory cache maintained for performance
|
||||
- All tests use `MockQueueRepository` fixture
|
||||
|
||||
---
|
||||
|
||||
## 🧪 Tests for Download Queue Database Migration
|
||||
|
||||
### Unit Tests
|
||||
|
||||
**File:** `tests/unit/test_queue_repository.py`
|
||||
|
||||
```python
|
||||
"""Unit tests for QueueRepository database adapter."""
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from src.server.services.queue_repository import QueueRepository
|
||||
from src.server.models.download import DownloadItem, DownloadStatus, DownloadPriority
|
||||
from src.server.database.models import DownloadQueueItem as DBDownloadQueueItem
|
||||
|
||||
|
||||
class TestQueueRepository:
|
||||
"""Test suite for QueueRepository."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_db_session(self):
|
||||
"""Create mock database session."""
|
||||
session = AsyncMock()
|
||||
return session
|
||||
|
||||
@pytest.fixture
|
||||
def repository(self, mock_db_session):
|
||||
"""Create repository instance with mock session."""
|
||||
return QueueRepository(db_session_factory=lambda: mock_db_session)
|
||||
|
||||
@pytest.fixture
|
||||
def sample_download_item(self):
|
||||
"""Create sample DownloadItem for testing."""
|
||||
return DownloadItem(
|
||||
id="test-uuid-123",
|
||||
series_key="attack-on-titan",
|
||||
series_name="Attack on Titan",
|
||||
season=1,
|
||||
episode=5,
|
||||
status=DownloadStatus.PENDING,
|
||||
priority=DownloadPriority.NORMAL,
|
||||
progress_percent=0.0,
|
||||
downloaded_bytes=0,
|
||||
total_bytes=None,
|
||||
)
|
||||
|
||||
# === Conversion Tests ===
|
||||
|
||||
async def test_convert_to_db_model(self, repository, sample_download_item):
|
||||
"""Test converting DownloadItem to database model."""
|
||||
# Arrange
|
||||
series_id = 42
|
||||
|
||||
# Act
|
||||
db_item = repository._to_db_model(sample_download_item, series_id)
|
||||
|
||||
# Assert
|
||||
assert db_item.series_id == series_id
|
||||
assert db_item.season == sample_download_item.season
|
||||
assert db_item.episode_number == sample_download_item.episode
|
||||
assert db_item.status == sample_download_item.status
|
||||
assert db_item.priority == sample_download_item.priority
|
||||
|
||||
async def test_convert_from_db_model(self, repository):
|
||||
"""Test converting database model to DownloadItem."""
|
||||
# Arrange
|
||||
db_item = MagicMock()
|
||||
db_item.id = 1
|
||||
db_item.series_id = 42
|
||||
db_item.series.key = "attack-on-titan"
|
||||
db_item.series.name = "Attack on Titan"
|
||||
db_item.season = 1
|
||||
db_item.episode_number = 5
|
||||
db_item.status = DownloadStatus.PENDING
|
||||
db_item.priority = DownloadPriority.NORMAL
|
||||
db_item.progress_percent = 25.5
|
||||
db_item.downloaded_bytes = 1024000
|
||||
db_item.total_bytes = 4096000
|
||||
|
||||
# Act
|
||||
item = repository._from_db_model(db_item)
|
||||
|
||||
# Assert
|
||||
assert item.series_key == "attack-on-titan"
|
||||
assert item.series_name == "Attack on Titan"
|
||||
assert item.season == 1
|
||||
assert item.episode == 5
|
||||
assert item.progress_percent == 25.5
|
||||
|
||||
# === CRUD Operation Tests ===
|
||||
|
||||
async def test_save_item_creates_new_record(self, repository, mock_db_session, sample_download_item):
|
||||
"""Test saving a new download item to database."""
|
||||
# Arrange
|
||||
mock_db_session.execute.return_value.scalar_one_or_none.return_value = MagicMock(id=42)
|
||||
|
||||
# Act
|
||||
result = await repository.save_item(sample_download_item)
|
||||
|
||||
# Assert
|
||||
mock_db_session.add.assert_called_once()
|
||||
mock_db_session.flush.assert_called_once()
|
||||
assert result is not None
|
||||
|
||||
async def test_get_pending_items_returns_ordered_list(self, repository, mock_db_session):
|
||||
"""Test retrieving pending items ordered by priority."""
|
||||
# Arrange
|
||||
mock_items = [MagicMock(), MagicMock()]
|
||||
mock_db_session.execute.return_value.scalars.return_value.all.return_value = mock_items
|
||||
|
||||
# Act
|
||||
result = await repository.get_pending_items()
|
||||
|
||||
# Assert
|
||||
assert len(result) == 2
|
||||
mock_db_session.execute.assert_called_once()
|
||||
|
||||
async def test_update_status_success(self, repository, mock_db_session):
|
||||
"""Test updating item status."""
|
||||
# Arrange
|
||||
mock_item = MagicMock()
|
||||
mock_db_session.execute.return_value.scalar_one_or_none.return_value = mock_item
|
||||
|
||||
# Act
|
||||
result = await repository.update_status("test-id", DownloadStatus.DOWNLOADING)
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
assert mock_item.status == DownloadStatus.DOWNLOADING
|
||||
|
||||
async def test_update_status_item_not_found(self, repository, mock_db_session):
|
||||
"""Test updating status for non-existent item."""
|
||||
# Arrange
|
||||
mock_db_session.execute.return_value.scalar_one_or_none.return_value = None
|
||||
|
||||
# Act
|
||||
result = await repository.update_status("non-existent", DownloadStatus.DOWNLOADING)
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
async def test_update_progress(self, repository, mock_db_session):
|
||||
"""Test updating download progress."""
|
||||
# Arrange
|
||||
mock_item = MagicMock()
|
||||
mock_db_session.execute.return_value.scalar_one_or_none.return_value = mock_item
|
||||
|
||||
# Act
|
||||
result = await repository.update_progress(
|
||||
item_id="test-id",
|
||||
progress=50.0,
|
||||
downloaded=2048000,
|
||||
total=4096000,
|
||||
speed=1024000.0
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
assert mock_item.progress_percent == 50.0
|
||||
assert mock_item.downloaded_bytes == 2048000
|
||||
|
||||
async def test_delete_item_success(self, repository, mock_db_session):
|
||||
"""Test deleting download item."""
|
||||
# Arrange
|
||||
mock_db_session.execute.return_value.rowcount = 1
|
||||
|
||||
# Act
|
||||
result = await repository.delete_item("test-id")
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
|
||||
async def test_clear_completed_returns_count(self, repository, mock_db_session):
|
||||
"""Test clearing completed items returns count."""
|
||||
# Arrange
|
||||
mock_db_session.execute.return_value.rowcount = 5
|
||||
|
||||
# Act
|
||||
result = await repository.clear_completed()
|
||||
|
||||
# Assert
|
||||
assert result == 5
|
||||
|
||||
|
||||
class TestQueueRepositoryErrorHandling:
|
||||
"""Test error handling in QueueRepository."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_db_session(self):
|
||||
"""Create mock database session."""
|
||||
return AsyncMock()
|
||||
|
||||
@pytest.fixture
|
||||
def repository(self, mock_db_session):
|
||||
"""Create repository instance."""
|
||||
return QueueRepository(db_session_factory=lambda: mock_db_session)
|
||||
|
||||
async def test_save_item_handles_database_error(self, repository, mock_db_session):
|
||||
"""Test handling database errors on save."""
|
||||
# Arrange
|
||||
mock_db_session.execute.side_effect = Exception("Database connection failed")
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(Exception):
|
||||
await repository.save_item(MagicMock())
|
||||
|
||||
async def test_get_items_handles_database_error(self, repository, mock_db_session):
|
||||
"""Test handling database errors on query."""
|
||||
# Arrange
|
||||
mock_db_session.execute.side_effect = Exception("Query failed")
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(Exception):
|
||||
await repository.get_pending_items()
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
**File:** `tests/unit/test_download_service_database.py`
|
||||
|
||||
```python
|
||||
"""Unit tests for DownloadService with database persistence."""
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from src.server.services.download_service import DownloadService
|
||||
from src.server.models.download import DownloadItem, DownloadStatus, DownloadPriority
|
||||
|
||||
|
||||
class TestDownloadServiceDatabasePersistence:
|
||||
"""Test DownloadService database persistence."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_anime_service(self):
|
||||
"""Create mock anime service."""
|
||||
return AsyncMock()
|
||||
|
||||
@pytest.fixture
|
||||
def mock_queue_repository(self):
|
||||
"""Create mock queue repository."""
|
||||
repo = AsyncMock()
|
||||
repo.get_pending_items.return_value = []
|
||||
repo.get_active_item.return_value = None
|
||||
repo.get_completed_items.return_value = []
|
||||
repo.get_failed_items.return_value = []
|
||||
return repo
|
||||
|
||||
@pytest.fixture
|
||||
def download_service(self, mock_anime_service, mock_queue_repository):
|
||||
"""Create download service with mocked dependencies."""
|
||||
return DownloadService(
|
||||
anime_service=mock_anime_service,
|
||||
queue_repository=mock_queue_repository,
|
||||
)
|
||||
|
||||
# === Persistence Tests ===
|
||||
|
||||
async def test_add_to_queue_saves_to_database(
|
||||
self, download_service, mock_queue_repository
|
||||
):
|
||||
"""Test that adding to queue persists to database."""
|
||||
# Arrange
|
||||
mock_queue_repository.save_item.return_value = MagicMock(id="new-id")
|
||||
|
||||
# Act
|
||||
result = await download_service.add_to_queue(
|
||||
series_key="test-series",
|
||||
season=1,
|
||||
episode=1,
|
||||
)
|
||||
|
||||
# Assert
|
||||
mock_queue_repository.save_item.assert_called_once()
|
||||
|
||||
async def test_startup_loads_from_database(
|
||||
self, mock_anime_service, mock_queue_repository
|
||||
):
|
||||
"""Test that startup loads queue state from database."""
|
||||
# Arrange
|
||||
pending_items = [
|
||||
MagicMock(id="1", status=DownloadStatus.PENDING),
|
||||
MagicMock(id="2", status=DownloadStatus.PENDING),
|
||||
]
|
||||
mock_queue_repository.get_pending_items.return_value = pending_items
|
||||
|
||||
# Act
|
||||
service = DownloadService(
|
||||
anime_service=mock_anime_service,
|
||||
queue_repository=mock_queue_repository,
|
||||
)
|
||||
await service.initialize()
|
||||
|
||||
# Assert
|
||||
mock_queue_repository.get_pending_items.assert_called()
|
||||
|
||||
async def test_download_completion_updates_database(
|
||||
self, download_service, mock_queue_repository
|
||||
):
|
||||
"""Test that download completion updates database status."""
|
||||
# Arrange
|
||||
item = MagicMock(id="test-id")
|
||||
|
||||
# Act
|
||||
await download_service._mark_completed(item)
|
||||
|
||||
# Assert
|
||||
mock_queue_repository.update_status.assert_called_with(
|
||||
"test-id", DownloadStatus.COMPLETED, error=None
|
||||
)
|
||||
|
||||
async def test_download_failure_updates_database(
|
||||
self, download_service, mock_queue_repository
|
||||
):
|
||||
"""Test that download failure updates database with error."""
|
||||
# Arrange
|
||||
item = MagicMock(id="test-id")
|
||||
error_message = "Network timeout"
|
||||
|
||||
# Act
|
||||
await download_service._mark_failed(item, error_message)
|
||||
|
||||
# Assert
|
||||
mock_queue_repository.update_status.assert_called_with(
|
||||
"test-id", DownloadStatus.FAILED, error=error_message
|
||||
)
|
||||
|
||||
async def test_progress_update_persists_to_database(
|
||||
self, download_service, mock_queue_repository
|
||||
):
|
||||
"""Test that progress updates are persisted."""
|
||||
# Arrange
|
||||
item = MagicMock(id="test-id")
|
||||
|
||||
# Act
|
||||
await download_service._update_progress(
|
||||
item, progress=50.0, downloaded=2048, total=4096, speed=1024.0
|
||||
)
|
||||
|
||||
# Assert
|
||||
mock_queue_repository.update_progress.assert_called_with(
|
||||
item_id="test-id",
|
||||
progress=50.0,
|
||||
downloaded=2048,
|
||||
total=4096,
|
||||
speed=1024.0,
|
||||
)
|
||||
|
||||
async def test_remove_from_queue_deletes_from_database(
|
||||
self, download_service, mock_queue_repository
|
||||
):
|
||||
"""Test that removing from queue deletes from database."""
|
||||
# Arrange
|
||||
mock_queue_repository.delete_item.return_value = True
|
||||
|
||||
# Act
|
||||
result = await download_service.remove_from_queue("test-id")
|
||||
|
||||
# Assert
|
||||
mock_queue_repository.delete_item.assert_called_with("test-id")
|
||||
assert result is True
|
||||
|
||||
async def test_clear_completed_clears_database(
|
||||
self, download_service, mock_queue_repository
|
||||
):
|
||||
"""Test that clearing completed items updates database."""
|
||||
# Arrange
|
||||
mock_queue_repository.clear_completed.return_value = 5
|
||||
|
||||
# Act
|
||||
result = await download_service.clear_completed()
|
||||
|
||||
# Assert
|
||||
mock_queue_repository.clear_completed.assert_called_once()
|
||||
assert result == 5
|
||||
|
||||
|
||||
class TestDownloadServiceNoJsonFile:
|
||||
"""Verify DownloadService no longer uses JSON files."""
|
||||
|
||||
async def test_no_json_file_operations(self):
|
||||
"""Verify no JSON file read/write operations exist."""
|
||||
import inspect
|
||||
from src.server.services.download_service import DownloadService
|
||||
|
||||
source = inspect.getsource(DownloadService)
|
||||
|
||||
# Assert no JSON file operations
|
||||
assert "download_queue.json" not in source
|
||||
assert "_load_queue" not in source or "database" in source.lower()
|
||||
assert "_save_queue" not in source or "database" in source.lower()
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Integration Tests
|
||||
|
||||
**File:** `tests/integration/test_queue_database_integration.py`
|
||||
|
||||
```python
|
||||
"""Integration tests for download queue database operations."""
|
||||
import pytest
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from src.server.database.base import Base
|
||||
from src.server.database.models import AnimeSeries, DownloadQueueItem, DownloadStatus, DownloadPriority
|
||||
from src.server.database.service import DownloadQueueService, AnimeSeriesService
|
||||
from src.server.services.queue_repository import QueueRepository
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def async_engine():
|
||||
"""Create async test database engine."""
|
||||
engine = create_async_engine("sqlite+aiosqlite:///:memory:", echo=False)
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
yield engine
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def async_session(async_engine):
|
||||
"""Create async session for tests."""
|
||||
async_session_maker = sessionmaker(
|
||||
async_engine, class_=AsyncSession, expire_on_commit=False
|
||||
)
|
||||
async with async_session_maker() as session:
|
||||
yield session
|
||||
await session.rollback()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def test_series(async_session):
|
||||
"""Create test anime series."""
|
||||
series = await AnimeSeriesService.create(
|
||||
db=async_session,
|
||||
key="test-anime",
|
||||
name="Test Anime",
|
||||
site="https://example.com/test-anime",
|
||||
folder="Test Anime (2024)",
|
||||
)
|
||||
await async_session.commit()
|
||||
return series
|
||||
|
||||
|
||||
class TestQueueDatabaseIntegration:
|
||||
"""Integration tests for queue database operations."""
|
||||
|
||||
async def test_create_and_retrieve_queue_item(self, async_session, test_series):
|
||||
"""Test creating and retrieving a queue item."""
|
||||
# Create
|
||||
item = await DownloadQueueService.create(
|
||||
db=async_session,
|
||||
series_id=test_series.id,
|
||||
season=1,
|
||||
episode_number=5,
|
||||
priority=DownloadPriority.HIGH,
|
||||
)
|
||||
await async_session.commit()
|
||||
|
||||
# Retrieve
|
||||
retrieved = await DownloadQueueService.get_by_id(async_session, item.id)
|
||||
|
||||
# Assert
|
||||
assert retrieved is not None
|
||||
assert retrieved.series_id == test_series.id
|
||||
assert retrieved.season == 1
|
||||
assert retrieved.episode_number == 5
|
||||
assert retrieved.priority == DownloadPriority.HIGH
|
||||
assert retrieved.status == DownloadStatus.PENDING
|
||||
|
||||
async def test_update_download_progress(self, async_session, test_series):
|
||||
"""Test updating download progress."""
|
||||
# Create item
|
||||
item = await DownloadQueueService.create(
|
||||
db=async_session,
|
||||
series_id=test_series.id,
|
||||
season=1,
|
||||
episode_number=1,
|
||||
)
|
||||
await async_session.commit()
|
||||
|
||||
# Update progress
|
||||
updated = await DownloadQueueService.update_progress(
|
||||
db=async_session,
|
||||
item_id=item.id,
|
||||
progress_percent=75.5,
|
||||
downloaded_bytes=3072000,
|
||||
total_bytes=4096000,
|
||||
download_speed=1024000.0,
|
||||
)
|
||||
await async_session.commit()
|
||||
|
||||
# Assert
|
||||
assert updated.progress_percent == 75.5
|
||||
assert updated.downloaded_bytes == 3072000
|
||||
assert updated.total_bytes == 4096000
|
||||
assert updated.download_speed == 1024000.0
|
||||
|
||||
async def test_status_transitions(self, async_session, test_series):
|
||||
"""Test download status transitions."""
|
||||
# Create pending item
|
||||
item = await DownloadQueueService.create(
|
||||
db=async_session,
|
||||
series_id=test_series.id,
|
||||
season=1,
|
||||
episode_number=1,
|
||||
)
|
||||
await async_session.commit()
|
||||
assert item.status == DownloadStatus.PENDING
|
||||
|
||||
# Transition to downloading
|
||||
item = await DownloadQueueService.update_status(
|
||||
async_session, item.id, DownloadStatus.DOWNLOADING
|
||||
)
|
||||
await async_session.commit()
|
||||
assert item.status == DownloadStatus.DOWNLOADING
|
||||
assert item.started_at is not None
|
||||
|
||||
# Transition to completed
|
||||
item = await DownloadQueueService.update_status(
|
||||
async_session, item.id, DownloadStatus.COMPLETED
|
||||
)
|
||||
await async_session.commit()
|
||||
assert item.status == DownloadStatus.COMPLETED
|
||||
assert item.completed_at is not None
|
||||
|
||||
async def test_failed_download_with_retry(self, async_session, test_series):
|
||||
"""Test failed download with error message and retry count."""
|
||||
# Create item
|
||||
item = await DownloadQueueService.create(
|
||||
db=async_session,
|
||||
series_id=test_series.id,
|
||||
season=1,
|
||||
episode_number=1,
|
||||
)
|
||||
await async_session.commit()
|
||||
|
||||
# Mark as failed with error
|
||||
item = await DownloadQueueService.update_status(
|
||||
async_session,
|
||||
item.id,
|
||||
DownloadStatus.FAILED,
|
||||
error_message="Connection timeout",
|
||||
)
|
||||
await async_session.commit()
|
||||
|
||||
# Assert
|
||||
assert item.status == DownloadStatus.FAILED
|
||||
assert item.error_message == "Connection timeout"
|
||||
assert item.retry_count == 1
|
||||
|
||||
async def test_get_pending_items_ordered_by_priority(self, async_session, test_series):
|
||||
"""Test retrieving pending items ordered by priority."""
|
||||
# Create items with different priorities
|
||||
await DownloadQueueService.create(
|
||||
async_session, test_series.id, 1, 1, priority=DownloadPriority.LOW
|
||||
)
|
||||
await DownloadQueueService.create(
|
||||
async_session, test_series.id, 1, 2, priority=DownloadPriority.HIGH
|
||||
)
|
||||
await DownloadQueueService.create(
|
||||
async_session, test_series.id, 1, 3, priority=DownloadPriority.NORMAL
|
||||
)
|
||||
await async_session.commit()
|
||||
|
||||
# Get pending items
|
||||
pending = await DownloadQueueService.get_pending(async_session)
|
||||
|
||||
# Assert order: HIGH -> NORMAL -> LOW
|
||||
assert len(pending) == 3
|
||||
assert pending[0].priority == DownloadPriority.HIGH
|
||||
assert pending[1].priority == DownloadPriority.NORMAL
|
||||
assert pending[2].priority == DownloadPriority.LOW
|
||||
|
||||
async def test_clear_completed_items(self, async_session, test_series):
|
||||
"""Test clearing completed download items."""
|
||||
# Create items
|
||||
item1 = await DownloadQueueService.create(
|
||||
async_session, test_series.id, 1, 1
|
||||
)
|
||||
item2 = await DownloadQueueService.create(
|
||||
async_session, test_series.id, 1, 2
|
||||
)
|
||||
item3 = await DownloadQueueService.create(
|
||||
async_session, test_series.id, 1, 3
|
||||
)
|
||||
|
||||
# Complete first two
|
||||
await DownloadQueueService.update_status(
|
||||
async_session, item1.id, DownloadStatus.COMPLETED
|
||||
)
|
||||
await DownloadQueueService.update_status(
|
||||
async_session, item2.id, DownloadStatus.COMPLETED
|
||||
)
|
||||
await async_session.commit()
|
||||
|
||||
# Clear completed
|
||||
cleared = await DownloadQueueService.clear_completed(async_session)
|
||||
await async_session.commit()
|
||||
|
||||
# Assert
|
||||
assert cleared == 2
|
||||
|
||||
# Verify pending item remains
|
||||
remaining = await DownloadQueueService.get_all(async_session)
|
||||
assert len(remaining) == 1
|
||||
assert remaining[0].id == item3.id
|
||||
|
||||
async def test_cascade_delete_with_series(self, async_session, test_series):
|
||||
"""Test that queue items are deleted when series is deleted."""
|
||||
# Create queue items
|
||||
await DownloadQueueService.create(
|
||||
async_session, test_series.id, 1, 1
|
||||
)
|
||||
await DownloadQueueService.create(
|
||||
async_session, test_series.id, 1, 2
|
||||
)
|
||||
await async_session.commit()
|
||||
|
||||
# Delete series
|
||||
await AnimeSeriesService.delete(async_session, test_series.id)
|
||||
await async_session.commit()
|
||||
|
||||
# Verify queue items are gone
|
||||
all_items = await DownloadQueueService.get_all(async_session)
|
||||
assert len(all_items) == 0
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### API Tests
|
||||
|
||||
**File:** `tests/api/test_queue_endpoints_database.py`
|
||||
|
||||
```python
|
||||
"""API tests for queue endpoints with database persistence."""
|
||||
import pytest
|
||||
from httpx import AsyncClient
|
||||
from unittest.mock import patch, AsyncMock
|
||||
|
||||
|
||||
class TestQueueAPIWithDatabase:
|
||||
"""Test queue API endpoints with database backend."""
|
||||
|
||||
@pytest.fixture
|
||||
def auth_headers(self):
|
||||
"""Get authentication headers."""
|
||||
return {"Authorization": "Bearer test-token"}
|
||||
|
||||
async def test_get_queue_returns_database_items(
|
||||
self, client: AsyncClient, auth_headers
|
||||
):
|
||||
"""Test GET /api/queue returns items from database."""
|
||||
response = await client.get("/api/queue", headers=auth_headers)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "pending" in data
|
||||
assert "active" in data
|
||||
assert "completed" in data
|
||||
|
||||
async def test_add_to_queue_persists_to_database(
|
||||
self, client: AsyncClient, auth_headers
|
||||
):
|
||||
"""Test POST /api/queue persists item to database."""
|
||||
payload = {
|
||||
"series_key": "test-anime",
|
||||
"season": 1,
|
||||
"episode": 1,
|
||||
"priority": "normal",
|
||||
}
|
||||
|
||||
response = await client.post(
|
||||
"/api/queue",
|
||||
json=payload,
|
||||
headers=auth_headers,
|
||||
)
|
||||
|
||||
assert response.status_code == 201
|
||||
data = response.json()
|
||||
assert "id" in data
|
||||
|
||||
async def test_remove_from_queue_deletes_from_database(
|
||||
self, client: AsyncClient, auth_headers
|
||||
):
|
||||
"""Test DELETE /api/queue/{id} removes from database."""
|
||||
# First add an item
|
||||
add_response = await client.post(
|
||||
"/api/queue",
|
||||
json={"series_key": "test-anime", "season": 1, "episode": 1},
|
||||
headers=auth_headers,
|
||||
)
|
||||
item_id = add_response.json()["id"]
|
||||
|
||||
# Then delete it
|
||||
response = await client.delete(
|
||||
f"/api/queue/{item_id}",
|
||||
headers=auth_headers,
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
# Verify it's gone
|
||||
get_response = await client.get("/api/queue", headers=auth_headers)
|
||||
queue_data = get_response.json()
|
||||
item_ids = [item["id"] for item in queue_data.get("pending", [])]
|
||||
assert item_id not in item_ids
|
||||
|
||||
async def test_queue_survives_server_restart(
|
||||
self, client: AsyncClient, auth_headers
|
||||
):
|
||||
"""Test that queue items persist across simulated restart."""
|
||||
# Add item
|
||||
add_response = await client.post(
|
||||
"/api/queue",
|
||||
json={"series_key": "test-anime", "season": 1, "episode": 5},
|
||||
headers=auth_headers,
|
||||
)
|
||||
item_id = add_response.json()["id"]
|
||||
|
||||
# Simulate restart by clearing in-memory cache
|
||||
# (In real scenario, this would be a server restart)
|
||||
|
||||
# Verify item still exists
|
||||
response = await client.get("/api/queue", headers=auth_headers)
|
||||
queue_data = response.json()
|
||||
item_ids = [item["id"] for item in queue_data.get("pending", [])]
|
||||
assert item_id in item_ids
|
||||
|
||||
async def test_clear_completed_endpoint(
|
||||
self, client: AsyncClient, auth_headers
|
||||
):
|
||||
"""Test POST /api/queue/clear-completed endpoint."""
|
||||
response = await client.post(
|
||||
"/api/queue/clear-completed",
|
||||
headers=auth_headers,
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "cleared_count" in data
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Performance Tests
|
||||
|
||||
**File:** `tests/performance/test_queue_database_performance.py`
|
||||
|
||||
```python
|
||||
"""Performance tests for database-backed download queue."""
|
||||
import pytest
|
||||
import asyncio
|
||||
import time
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
class TestQueueDatabasePerformance:
|
||||
"""Performance tests for queue database operations."""
|
||||
|
||||
@pytest.mark.performance
|
||||
async def test_bulk_insert_performance(self, async_session, test_series):
|
||||
"""Test performance of bulk queue item insertion."""
|
||||
from src.server.database.service import DownloadQueueService
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
# Insert 100 queue items
|
||||
for i in range(100):
|
||||
await DownloadQueueService.create(
|
||||
async_session,
|
||||
test_series.id,
|
||||
season=1,
|
||||
episode_number=i + 1,
|
||||
)
|
||||
await async_session.commit()
|
||||
|
||||
elapsed = time.time() - start_time
|
||||
|
||||
# Should complete in under 2 seconds
|
||||
assert elapsed < 2.0, f"Bulk insert took {elapsed:.2f}s, expected < 2s"
|
||||
|
||||
@pytest.mark.performance
|
||||
async def test_query_performance_with_many_items(self, async_session, test_series):
|
||||
"""Test query performance with many queue items."""
|
||||
from src.server.database.service import DownloadQueueService
|
||||
|
||||
# Setup: Create 500 items
|
||||
for i in range(500):
|
||||
await DownloadQueueService.create(
|
||||
async_session,
|
||||
test_series.id,
|
||||
season=(i // 12) + 1,
|
||||
episode_number=(i % 12) + 1,
|
||||
)
|
||||
await async_session.commit()
|
||||
|
||||
# Test query performance
|
||||
start_time = time.time()
|
||||
|
||||
pending = await DownloadQueueService.get_pending(async_session)
|
||||
|
||||
elapsed = time.time() - start_time
|
||||
|
||||
# Query should complete in under 100ms
|
||||
assert elapsed < 0.1, f"Query took {elapsed*1000:.1f}ms, expected < 100ms"
|
||||
assert len(pending) == 500
|
||||
|
||||
@pytest.mark.performance
|
||||
async def test_progress_update_performance(self, async_session, test_series):
|
||||
"""Test performance of frequent progress updates."""
|
||||
from src.server.database.service import DownloadQueueService
|
||||
|
||||
# Create item
|
||||
item = await DownloadQueueService.create(
|
||||
async_session, test_series.id, 1, 1
|
||||
)
|
||||
await async_session.commit()
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
# Simulate 100 progress updates (like during download)
|
||||
for i in range(100):
|
||||
await DownloadQueueService.update_progress(
|
||||
async_session,
|
||||
item.id,
|
||||
progress_percent=i,
|
||||
downloaded_bytes=i * 10240,
|
||||
total_bytes=1024000,
|
||||
download_speed=102400.0,
|
||||
)
|
||||
await async_session.commit()
|
||||
|
||||
elapsed = time.time() - start_time
|
||||
|
||||
# 100 updates should complete in under 1 second
|
||||
assert elapsed < 1.0, f"Progress updates took {elapsed:.2f}s, expected < 1s"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Summary
|
||||
|
||||
These tasks will migrate the download queue from JSON file persistence to SQLite database, providing:
|
||||
|
||||
1. **Data Integrity**: ACID-compliant storage with proper relationships
|
||||
2. **Query Capability**: Efficient filtering, sorting, and pagination
|
||||
3. **Consistency**: Single source of truth for all application data
|
||||
4. **Scalability**: Better performance for large queues
|
||||
5. **Recovery**: Robust handling of crashes and restarts
|
||||
|
||||
The existing database infrastructure (`DownloadQueueItem` model and `DownloadQueueService`) is already in place, making this primarily an integration task rather than new development.
|
||||
|
||||
@ -540,7 +540,7 @@ class SerieScanner:
|
||||
Save or update a series in the database.
|
||||
|
||||
Creates a new record if the series doesn't exist, or updates
|
||||
the episode_dict if it has changed.
|
||||
the episodes if they have changed.
|
||||
|
||||
Args:
|
||||
serie: Serie instance to save
|
||||
@ -549,26 +549,53 @@ class SerieScanner:
|
||||
Returns:
|
||||
Created or updated AnimeSeries instance, or None if unchanged
|
||||
"""
|
||||
from src.server.database.service import AnimeSeriesService
|
||||
from src.server.database.service import AnimeSeriesService, EpisodeService
|
||||
|
||||
# Check if series already exists
|
||||
existing = await AnimeSeriesService.get_by_key(db, serie.key)
|
||||
|
||||
if existing:
|
||||
# Update episode_dict if changed
|
||||
if existing.episode_dict != serie.episodeDict:
|
||||
updated = await AnimeSeriesService.update(
|
||||
db,
|
||||
existing.id,
|
||||
episode_dict=serie.episodeDict,
|
||||
folder=serie.folder
|
||||
)
|
||||
# Build existing episode dict from episodes for comparison
|
||||
existing_episodes = await EpisodeService.get_by_series(
|
||||
db, existing.id
|
||||
)
|
||||
existing_dict: dict[int, list[int]] = {}
|
||||
for ep in existing_episodes:
|
||||
if ep.season not in existing_dict:
|
||||
existing_dict[ep.season] = []
|
||||
existing_dict[ep.season].append(ep.episode_number)
|
||||
for season in existing_dict:
|
||||
existing_dict[season].sort()
|
||||
|
||||
# Update episodes if changed
|
||||
if existing_dict != serie.episodeDict:
|
||||
# Add new episodes
|
||||
new_dict = serie.episodeDict or {}
|
||||
for season, episode_numbers in new_dict.items():
|
||||
existing_eps = set(existing_dict.get(season, []))
|
||||
for ep_num in episode_numbers:
|
||||
if ep_num not in existing_eps:
|
||||
await EpisodeService.create(
|
||||
db=db,
|
||||
series_id=existing.id,
|
||||
season=season,
|
||||
episode_number=ep_num,
|
||||
)
|
||||
|
||||
# Update folder if changed
|
||||
if existing.folder != serie.folder:
|
||||
await AnimeSeriesService.update(
|
||||
db,
|
||||
existing.id,
|
||||
folder=serie.folder
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Updated series in database: %s (key=%s)",
|
||||
serie.name,
|
||||
serie.key
|
||||
)
|
||||
return updated
|
||||
return existing
|
||||
else:
|
||||
logger.debug(
|
||||
"Series unchanged in database: %s (key=%s)",
|
||||
@ -584,8 +611,19 @@ class SerieScanner:
|
||||
name=serie.name,
|
||||
site=serie.site,
|
||||
folder=serie.folder,
|
||||
episode_dict=serie.episodeDict,
|
||||
)
|
||||
|
||||
# Create Episode records
|
||||
if serie.episodeDict:
|
||||
for season, episode_numbers in serie.episodeDict.items():
|
||||
for ep_num in episode_numbers:
|
||||
await EpisodeService.create(
|
||||
db=db,
|
||||
series_id=anime_series.id,
|
||||
season=season,
|
||||
episode_number=ep_num,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Created series in database: %s (key=%s)",
|
||||
serie.name,
|
||||
@ -608,7 +646,7 @@ class SerieScanner:
|
||||
Returns:
|
||||
Updated AnimeSeries instance, or None if not found
|
||||
"""
|
||||
from src.server.database.service import AnimeSeriesService
|
||||
from src.server.database.service import AnimeSeriesService, EpisodeService
|
||||
|
||||
existing = await AnimeSeriesService.get_by_key(db, serie.key)
|
||||
if not existing:
|
||||
@ -619,20 +657,43 @@ class SerieScanner:
|
||||
)
|
||||
return None
|
||||
|
||||
updated = await AnimeSeriesService.update(
|
||||
# Update basic fields
|
||||
await AnimeSeriesService.update(
|
||||
db,
|
||||
existing.id,
|
||||
name=serie.name,
|
||||
site=serie.site,
|
||||
folder=serie.folder,
|
||||
episode_dict=serie.episodeDict,
|
||||
)
|
||||
|
||||
# Update episodes - add any new ones
|
||||
if serie.episodeDict:
|
||||
existing_episodes = await EpisodeService.get_by_series(
|
||||
db, existing.id
|
||||
)
|
||||
existing_dict: dict[int, set[int]] = {}
|
||||
for ep in existing_episodes:
|
||||
if ep.season not in existing_dict:
|
||||
existing_dict[ep.season] = set()
|
||||
existing_dict[ep.season].add(ep.episode_number)
|
||||
|
||||
for season, episode_numbers in serie.episodeDict.items():
|
||||
existing_eps = existing_dict.get(season, set())
|
||||
for ep_num in episode_numbers:
|
||||
if ep_num not in existing_eps:
|
||||
await EpisodeService.create(
|
||||
db=db,
|
||||
series_id=existing.id,
|
||||
season=season,
|
||||
episode_number=ep_num,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Updated series in database: %s (key=%s)",
|
||||
serie.name,
|
||||
serie.key
|
||||
)
|
||||
return updated
|
||||
return existing
|
||||
|
||||
def __find_mp4_files(self) -> Iterator[tuple[str, list[str]]]:
|
||||
"""Find all .mp4 files in the directory structure."""
|
||||
|
||||
@ -147,7 +147,7 @@ class SerieList:
|
||||
if result:
|
||||
print(f"Added series: {result.name}")
|
||||
"""
|
||||
from src.server.database.service import AnimeSeriesService
|
||||
from src.server.database.service import AnimeSeriesService, EpisodeService
|
||||
|
||||
# Check if series already exists in DB
|
||||
existing = await AnimeSeriesService.get_by_key(db, serie.key)
|
||||
@ -166,9 +166,19 @@ class SerieList:
|
||||
name=serie.name,
|
||||
site=serie.site,
|
||||
folder=serie.folder,
|
||||
episode_dict=serie.episodeDict,
|
||||
)
|
||||
|
||||
# Create Episode records for each episode in episodeDict
|
||||
if serie.episodeDict:
|
||||
for season, episode_numbers in serie.episodeDict.items():
|
||||
for episode_number in episode_numbers:
|
||||
await EpisodeService.create(
|
||||
db=db,
|
||||
series_id=anime_series.id,
|
||||
season=season,
|
||||
episode_number=episode_number,
|
||||
)
|
||||
|
||||
# Also add to in-memory collection
|
||||
self.keyDict[serie.key] = serie
|
||||
|
||||
@ -267,8 +277,10 @@ class SerieList:
|
||||
# Clear existing in-memory data
|
||||
self.keyDict.clear()
|
||||
|
||||
# Load all series from database
|
||||
anime_series_list = await AnimeSeriesService.get_all(db)
|
||||
# Load all series from database (with episodes for episodeDict)
|
||||
anime_series_list = await AnimeSeriesService.get_all(
|
||||
db, with_episodes=True
|
||||
)
|
||||
|
||||
for anime_series in anime_series_list:
|
||||
serie = self._convert_from_db(anime_series)
|
||||
@ -288,23 +300,22 @@ class SerieList:
|
||||
|
||||
Args:
|
||||
anime_series: AnimeSeries model from database
|
||||
(must have episodes relationship loaded)
|
||||
|
||||
Returns:
|
||||
Serie entity instance
|
||||
"""
|
||||
# Convert episode_dict from JSON (string keys) to int keys
|
||||
# Build episode_dict from episodes relationship
|
||||
episode_dict: dict[int, list[int]] = {}
|
||||
if anime_series.episode_dict:
|
||||
for season_str, episodes in anime_series.episode_dict.items():
|
||||
try:
|
||||
season = int(season_str)
|
||||
episode_dict[season] = list(episodes)
|
||||
except (ValueError, TypeError):
|
||||
logger.warning(
|
||||
"Invalid season key '%s' in episode_dict for %s",
|
||||
season_str,
|
||||
anime_series.key
|
||||
)
|
||||
if anime_series.episodes:
|
||||
for episode in anime_series.episodes:
|
||||
season = episode.season
|
||||
if season not in episode_dict:
|
||||
episode_dict[season] = []
|
||||
episode_dict[season].append(episode.episode_number)
|
||||
# Sort episode numbers within each season
|
||||
for season in episode_dict:
|
||||
episode_dict[season].sort()
|
||||
|
||||
return Serie(
|
||||
key=anime_series.key,
|
||||
@ -325,19 +336,11 @@ class SerieList:
|
||||
Returns:
|
||||
Dictionary suitable for AnimeSeriesService.create()
|
||||
"""
|
||||
# Convert episode_dict keys to strings for JSON storage
|
||||
episode_dict = None
|
||||
if serie.episodeDict:
|
||||
episode_dict = {
|
||||
str(k): list(v) for k, v in serie.episodeDict.items()
|
||||
}
|
||||
|
||||
return {
|
||||
"key": serie.key,
|
||||
"name": serie.name,
|
||||
"site": serie.site,
|
||||
"folder": serie.folder,
|
||||
"episode_dict": episode_dict,
|
||||
}
|
||||
|
||||
async def contains_in_db(self, key: str, db: "AsyncSession") -> bool:
|
||||
|
||||
@ -229,37 +229,6 @@ class DatabaseIntegrityChecker:
|
||||
logger.warning(msg)
|
||||
issues_found += count
|
||||
|
||||
# Check for invalid progress percentages
|
||||
stmt = select(DownloadQueueItem).where(
|
||||
(DownloadQueueItem.progress < 0) |
|
||||
(DownloadQueueItem.progress > 100)
|
||||
)
|
||||
invalid_progress = self.session.execute(stmt).scalars().all()
|
||||
|
||||
if invalid_progress:
|
||||
count = len(invalid_progress)
|
||||
msg = (
|
||||
f"Found {count} queue items with invalid progress "
|
||||
f"percentages"
|
||||
)
|
||||
self.issues.append(msg)
|
||||
logger.warning(msg)
|
||||
issues_found += count
|
||||
|
||||
# Check for queue items with invalid status
|
||||
valid_statuses = {'pending', 'downloading', 'completed', 'failed'}
|
||||
stmt = select(DownloadQueueItem).where(
|
||||
~DownloadQueueItem.status.in_(valid_statuses)
|
||||
)
|
||||
invalid_status = self.session.execute(stmt).scalars().all()
|
||||
|
||||
if invalid_status:
|
||||
count = len(invalid_status)
|
||||
msg = f"Found {count} queue items with invalid status"
|
||||
self.issues.append(msg)
|
||||
logger.warning(msg)
|
||||
issues_found += count
|
||||
|
||||
if issues_found == 0:
|
||||
logger.info("No data consistency issues found")
|
||||
|
||||
|
||||
@ -669,7 +669,6 @@ async def add_series(
|
||||
name=request.name.strip(),
|
||||
site="aniworld.to",
|
||||
folder=folder,
|
||||
episode_dict={}, # Empty for new series
|
||||
)
|
||||
db_id = anime_series.id
|
||||
|
||||
|
||||
@ -1,479 +0,0 @@
|
||||
"""Example integration of database service with existing services.
|
||||
|
||||
This file demonstrates how to integrate the database service layer with
|
||||
existing application services like AnimeService and DownloadService.
|
||||
|
||||
These examples show patterns for:
|
||||
- Persisting scan results to database
|
||||
- Loading queue from database on startup
|
||||
- Syncing download progress to database
|
||||
- Maintaining consistency between in-memory state and database
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import List, Optional
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from src.core.entities.series import Serie
|
||||
from src.server.database.models import DownloadPriority, DownloadStatus
|
||||
from src.server.database.service import (
|
||||
AnimeSeriesService,
|
||||
DownloadQueueService,
|
||||
EpisodeService,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Example 1: Persist Scan Results
|
||||
# ============================================================================
|
||||
|
||||
|
||||
async def persist_scan_results(
|
||||
db: AsyncSession,
|
||||
series_list: List[Serie],
|
||||
) -> None:
|
||||
"""Persist scan results to database.
|
||||
|
||||
Updates or creates anime series and their episodes based on
|
||||
scan results from SerieScanner.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
series_list: List of Serie objects from scan
|
||||
"""
|
||||
logger.info(f"Persisting {len(series_list)} series to database")
|
||||
|
||||
for serie in series_list:
|
||||
# Check if series exists
|
||||
existing = await AnimeSeriesService.get_by_key(db, serie.key)
|
||||
|
||||
if existing:
|
||||
# Update existing series
|
||||
await AnimeSeriesService.update(
|
||||
db,
|
||||
existing.id,
|
||||
name=serie.name,
|
||||
site=serie.site,
|
||||
folder=serie.folder,
|
||||
episode_dict=serie.episode_dict,
|
||||
)
|
||||
series_id = existing.id
|
||||
else:
|
||||
# Create new series
|
||||
new_series = await AnimeSeriesService.create(
|
||||
db,
|
||||
key=serie.key,
|
||||
name=serie.name,
|
||||
site=serie.site,
|
||||
folder=serie.folder,
|
||||
episode_dict=serie.episode_dict,
|
||||
)
|
||||
series_id = new_series.id
|
||||
|
||||
# Update episodes for this series
|
||||
await _update_episodes(db, series_id, serie)
|
||||
|
||||
await db.commit()
|
||||
logger.info("Scan results persisted successfully")
|
||||
|
||||
|
||||
async def _update_episodes(
|
||||
db: AsyncSession,
|
||||
series_id: int,
|
||||
serie: Serie,
|
||||
) -> None:
|
||||
"""Update episodes for a series.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
series_id: Series ID in database
|
||||
serie: Serie object with episode information
|
||||
"""
|
||||
# Get existing episodes
|
||||
existing_episodes = await EpisodeService.get_by_series(db, series_id)
|
||||
existing_map = {
|
||||
(ep.season, ep.episode_number): ep
|
||||
for ep in existing_episodes
|
||||
}
|
||||
|
||||
# Iterate through episode_dict to create/update episodes
|
||||
for season, episodes in serie.episode_dict.items():
|
||||
for ep_num in episodes:
|
||||
key = (int(season), int(ep_num))
|
||||
|
||||
if key in existing_map:
|
||||
# Episode exists, check if downloaded
|
||||
episode = existing_map[key]
|
||||
# Update if needed (e.g., file path changed)
|
||||
if not episode.is_downloaded:
|
||||
# Check if file exists locally
|
||||
# This would be done by checking serie.local_episodes
|
||||
pass
|
||||
else:
|
||||
# Create new episode
|
||||
await EpisodeService.create(
|
||||
db,
|
||||
series_id=series_id,
|
||||
season=int(season),
|
||||
episode_number=int(ep_num),
|
||||
is_downloaded=False,
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Example 2: Load Queue from Database
|
||||
# ============================================================================
|
||||
|
||||
|
||||
async def load_queue_from_database(
|
||||
db: AsyncSession,
|
||||
) -> List[dict]:
|
||||
"""Load download queue from database.
|
||||
|
||||
Retrieves pending and active download items from database and
|
||||
converts them to format suitable for DownloadService.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
List of download items as dictionaries
|
||||
"""
|
||||
logger.info("Loading download queue from database")
|
||||
|
||||
# Get pending and active items
|
||||
pending = await DownloadQueueService.get_pending(db)
|
||||
active = await DownloadQueueService.get_active(db)
|
||||
|
||||
all_items = pending + active
|
||||
|
||||
# Convert to dictionary format for DownloadService
|
||||
queue_items = []
|
||||
for item in all_items:
|
||||
queue_items.append({
|
||||
"id": item.id,
|
||||
"series_id": item.series_id,
|
||||
"season": item.season,
|
||||
"episode_number": item.episode_number,
|
||||
"status": item.status.value,
|
||||
"priority": item.priority.value,
|
||||
"progress_percent": item.progress_percent,
|
||||
"downloaded_bytes": item.downloaded_bytes,
|
||||
"total_bytes": item.total_bytes,
|
||||
"download_speed": item.download_speed,
|
||||
"error_message": item.error_message,
|
||||
"retry_count": item.retry_count,
|
||||
})
|
||||
|
||||
logger.info(f"Loaded {len(queue_items)} items from database")
|
||||
return queue_items
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Example 3: Sync Download Progress to Database
|
||||
# ============================================================================
|
||||
|
||||
|
||||
async def sync_download_progress(
|
||||
db: AsyncSession,
|
||||
item_id: int,
|
||||
progress_percent: float,
|
||||
downloaded_bytes: int,
|
||||
total_bytes: Optional[int] = None,
|
||||
download_speed: Optional[float] = None,
|
||||
) -> None:
|
||||
"""Sync download progress to database.
|
||||
|
||||
Updates download queue item progress in database. This would be called
|
||||
from the download progress callback.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
item_id: Download queue item ID
|
||||
progress_percent: Progress percentage (0-100)
|
||||
downloaded_bytes: Bytes downloaded
|
||||
total_bytes: Optional total file size
|
||||
download_speed: Optional current speed (bytes/sec)
|
||||
"""
|
||||
await DownloadQueueService.update_progress(
|
||||
db,
|
||||
item_id,
|
||||
progress_percent,
|
||||
downloaded_bytes,
|
||||
total_bytes,
|
||||
download_speed,
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
|
||||
async def mark_download_complete(
|
||||
db: AsyncSession,
|
||||
item_id: int,
|
||||
file_path: str,
|
||||
file_size: int,
|
||||
) -> None:
|
||||
"""Mark download as complete in database.
|
||||
|
||||
Updates download queue item status and marks episode as downloaded.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
item_id: Download queue item ID
|
||||
file_path: Path to downloaded file
|
||||
file_size: File size in bytes
|
||||
"""
|
||||
# Get download item
|
||||
item = await DownloadQueueService.get_by_id(db, item_id)
|
||||
if not item:
|
||||
logger.error(f"Download item {item_id} not found")
|
||||
return
|
||||
|
||||
# Update download status
|
||||
await DownloadQueueService.update_status(
|
||||
db,
|
||||
item_id,
|
||||
DownloadStatus.COMPLETED,
|
||||
)
|
||||
|
||||
# Find or create episode and mark as downloaded
|
||||
episode = await EpisodeService.get_by_episode(
|
||||
db,
|
||||
item.series_id,
|
||||
item.season,
|
||||
item.episode_number,
|
||||
)
|
||||
|
||||
if episode:
|
||||
await EpisodeService.mark_downloaded(
|
||||
db,
|
||||
episode.id,
|
||||
file_path,
|
||||
file_size,
|
||||
)
|
||||
else:
|
||||
# Create episode
|
||||
episode = await EpisodeService.create(
|
||||
db,
|
||||
series_id=item.series_id,
|
||||
season=item.season,
|
||||
episode_number=item.episode_number,
|
||||
file_path=file_path,
|
||||
file_size=file_size,
|
||||
is_downloaded=True,
|
||||
)
|
||||
|
||||
await db.commit()
|
||||
logger.info(
|
||||
f"Marked download complete: S{item.season:02d}E{item.episode_number:02d}"
|
||||
)
|
||||
|
||||
|
||||
async def mark_download_failed(
|
||||
db: AsyncSession,
|
||||
item_id: int,
|
||||
error_message: str,
|
||||
) -> None:
|
||||
"""Mark download as failed in database.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
item_id: Download queue item ID
|
||||
error_message: Error description
|
||||
"""
|
||||
await DownloadQueueService.update_status(
|
||||
db,
|
||||
item_id,
|
||||
DownloadStatus.FAILED,
|
||||
error_message=error_message,
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Example 4: Add Episodes to Download Queue
|
||||
# ============================================================================
|
||||
|
||||
|
||||
async def add_episodes_to_queue(
|
||||
db: AsyncSession,
|
||||
series_key: str,
|
||||
episodes: List[tuple[int, int]], # List of (season, episode) tuples
|
||||
priority: DownloadPriority = DownloadPriority.NORMAL,
|
||||
) -> int:
|
||||
"""Add multiple episodes to download queue.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
series_key: Series provider key
|
||||
episodes: List of (season, episode_number) tuples
|
||||
priority: Download priority
|
||||
|
||||
Returns:
|
||||
Number of episodes added to queue
|
||||
"""
|
||||
# Get series
|
||||
series = await AnimeSeriesService.get_by_key(db, series_key)
|
||||
if not series:
|
||||
logger.error(f"Series not found: {series_key}")
|
||||
return 0
|
||||
|
||||
added_count = 0
|
||||
for season, episode_number in episodes:
|
||||
# Check if already in queue
|
||||
existing_items = await DownloadQueueService.get_all(db)
|
||||
already_queued = any(
|
||||
item.series_id == series.id
|
||||
and item.season == season
|
||||
and item.episode_number == episode_number
|
||||
and item.status in (DownloadStatus.PENDING, DownloadStatus.DOWNLOADING)
|
||||
for item in existing_items
|
||||
)
|
||||
|
||||
if not already_queued:
|
||||
await DownloadQueueService.create(
|
||||
db,
|
||||
series_id=series.id,
|
||||
season=season,
|
||||
episode_number=episode_number,
|
||||
priority=priority,
|
||||
)
|
||||
added_count += 1
|
||||
|
||||
await db.commit()
|
||||
logger.info(f"Added {added_count} episodes to download queue")
|
||||
return added_count
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Example 5: Integration with AnimeService
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class EnhancedAnimeService:
|
||||
"""Enhanced AnimeService with database persistence.
|
||||
|
||||
This is an example of how to wrap the existing AnimeService with
|
||||
database persistence capabilities.
|
||||
"""
|
||||
|
||||
def __init__(self, db_session_factory):
|
||||
"""Initialize enhanced anime service.
|
||||
|
||||
Args:
|
||||
db_session_factory: Async session factory for database access
|
||||
"""
|
||||
self.db_session_factory = db_session_factory
|
||||
|
||||
async def rescan_with_persistence(self, directory: str) -> dict:
|
||||
"""Rescan directory and persist results.
|
||||
|
||||
Args:
|
||||
directory: Directory to scan
|
||||
|
||||
Returns:
|
||||
Scan results dictionary
|
||||
"""
|
||||
# Import here to avoid circular dependencies
|
||||
from src.core.SeriesApp import SeriesApp
|
||||
|
||||
# Perform scan
|
||||
app = SeriesApp(directory)
|
||||
series_list = app.ReScan()
|
||||
|
||||
# Persist to database
|
||||
async with self.db_session_factory() as db:
|
||||
await persist_scan_results(db, series_list)
|
||||
|
||||
return {
|
||||
"total_series": len(series_list),
|
||||
"message": "Scan completed and persisted to database",
|
||||
}
|
||||
|
||||
async def get_series_with_missing_episodes(self) -> List[dict]:
|
||||
"""Get series with missing episodes from database.
|
||||
|
||||
Returns:
|
||||
List of series with missing episodes
|
||||
"""
|
||||
async with self.db_session_factory() as db:
|
||||
# Get all series
|
||||
all_series = await AnimeSeriesService.get_all(
|
||||
db,
|
||||
with_episodes=True,
|
||||
)
|
||||
|
||||
# Filter series with missing episodes
|
||||
series_with_missing = []
|
||||
for series in all_series:
|
||||
if series.episode_dict:
|
||||
total_episodes = sum(
|
||||
len(eps) for eps in series.episode_dict.values()
|
||||
)
|
||||
downloaded_episodes = sum(
|
||||
1 for ep in series.episodes if ep.is_downloaded
|
||||
)
|
||||
|
||||
if downloaded_episodes < total_episodes:
|
||||
series_with_missing.append({
|
||||
"id": series.id,
|
||||
"key": series.key,
|
||||
"name": series.name,
|
||||
"total_episodes": total_episodes,
|
||||
"downloaded_episodes": downloaded_episodes,
|
||||
"missing_episodes": total_episodes - downloaded_episodes,
|
||||
})
|
||||
|
||||
return series_with_missing
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Usage Example
|
||||
# ============================================================================
|
||||
|
||||
|
||||
async def example_usage():
|
||||
"""Example usage of database service integration."""
|
||||
from src.server.database import get_db_session
|
||||
|
||||
# Get database session
|
||||
async with get_db_session() as db:
|
||||
# Example 1: Add episodes to queue
|
||||
added = await add_episodes_to_queue(
|
||||
db,
|
||||
series_key="attack-on-titan",
|
||||
episodes=[(1, 1), (1, 2), (1, 3)],
|
||||
priority=DownloadPriority.HIGH,
|
||||
)
|
||||
print(f"Added {added} episodes to queue")
|
||||
|
||||
# Example 2: Load queue
|
||||
queue_items = await load_queue_from_database(db)
|
||||
print(f"Queue has {len(queue_items)} items")
|
||||
|
||||
# Example 3: Update progress
|
||||
if queue_items:
|
||||
await sync_download_progress(
|
||||
db,
|
||||
item_id=queue_items[0]["id"],
|
||||
progress_percent=50.0,
|
||||
downloaded_bytes=500000,
|
||||
total_bytes=1000000,
|
||||
)
|
||||
|
||||
# Example 4: Mark complete
|
||||
if queue_items:
|
||||
await mark_download_complete(
|
||||
db,
|
||||
item_id=queue_items[0]["id"],
|
||||
file_path="/path/to/file.mp4",
|
||||
file_size=1000000,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
asyncio.run(example_usage())
|
||||
@ -47,7 +47,7 @@ EXPECTED_INDEXES = {
|
||||
"episodes": ["ix_episodes_series_id"],
|
||||
"download_queue": [
|
||||
"ix_download_queue_series_id",
|
||||
"ix_download_queue_status",
|
||||
"ix_download_queue_episode_id",
|
||||
],
|
||||
"user_sessions": [
|
||||
"ix_user_sessions_session_id",
|
||||
|
||||
@ -15,18 +15,7 @@ from datetime import datetime, timezone
|
||||
from enum import Enum
|
||||
from typing import List, Optional
|
||||
|
||||
from sqlalchemy import (
|
||||
JSON,
|
||||
Boolean,
|
||||
DateTime,
|
||||
Float,
|
||||
ForeignKey,
|
||||
Integer,
|
||||
String,
|
||||
Text,
|
||||
func,
|
||||
)
|
||||
from sqlalchemy import Enum as SQLEnum
|
||||
from sqlalchemy import Boolean, DateTime, ForeignKey, Integer, String, Text, func
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship, validates
|
||||
|
||||
from src.server.database.base import Base, TimestampMixin
|
||||
@ -51,10 +40,6 @@ class AnimeSeries(Base, TimestampMixin):
|
||||
name: Display name of the series
|
||||
site: Provider site URL
|
||||
folder: Filesystem folder name (metadata only, not for lookups)
|
||||
description: Optional series description
|
||||
status: Current status (ongoing, completed, etc.)
|
||||
total_episodes: Total number of episodes
|
||||
cover_url: URL to series cover image
|
||||
episodes: Relationship to Episode models (via id foreign key)
|
||||
download_items: Relationship to DownloadQueueItem models (via id foreign key)
|
||||
created_at: Creation timestamp (from TimestampMixin)
|
||||
@ -89,30 +74,6 @@ class AnimeSeries(Base, TimestampMixin):
|
||||
doc="Filesystem folder name - METADATA ONLY, not for lookups"
|
||||
)
|
||||
|
||||
# Metadata
|
||||
description: Mapped[Optional[str]] = mapped_column(
|
||||
Text, nullable=True,
|
||||
doc="Series description"
|
||||
)
|
||||
status: Mapped[Optional[str]] = mapped_column(
|
||||
String(50), nullable=True,
|
||||
doc="Series status (ongoing, completed, etc.)"
|
||||
)
|
||||
total_episodes: Mapped[Optional[int]] = mapped_column(
|
||||
Integer, nullable=True,
|
||||
doc="Total number of episodes"
|
||||
)
|
||||
cover_url: Mapped[Optional[str]] = mapped_column(
|
||||
String(1000), nullable=True,
|
||||
doc="URL to cover image"
|
||||
)
|
||||
|
||||
# JSON field for episode dictionary (season -> [episodes])
|
||||
episode_dict: Mapped[Optional[dict]] = mapped_column(
|
||||
JSON, nullable=True,
|
||||
doc="Episode dictionary {season: [episodes]}"
|
||||
)
|
||||
|
||||
# Relationships
|
||||
episodes: Mapped[List["Episode"]] = relationship(
|
||||
"Episode",
|
||||
@ -161,22 +122,6 @@ class AnimeSeries(Base, TimestampMixin):
|
||||
raise ValueError("Folder path must be 1000 characters or less")
|
||||
return value.strip()
|
||||
|
||||
@validates('cover_url')
|
||||
def validate_cover_url(self, key: str, value: Optional[str]) -> Optional[str]:
|
||||
"""Validate cover URL length."""
|
||||
if value is not None and len(value) > 1000:
|
||||
raise ValueError("Cover URL must be 1000 characters or less")
|
||||
return value
|
||||
|
||||
@validates('total_episodes')
|
||||
def validate_total_episodes(self, key: str, value: Optional[int]) -> Optional[int]:
|
||||
"""Validate total episodes is positive."""
|
||||
if value is not None and value < 0:
|
||||
raise ValueError("Total episodes must be non-negative")
|
||||
if value is not None and value > 10000:
|
||||
raise ValueError("Total episodes must be 10000 or less")
|
||||
return value
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<AnimeSeries(id={self.id}, key='{self.key}', name='{self.name}')>"
|
||||
|
||||
@ -194,9 +139,7 @@ class Episode(Base, TimestampMixin):
|
||||
episode_number: Episode number within season
|
||||
title: Episode title
|
||||
file_path: Local file path if downloaded
|
||||
file_size: File size in bytes
|
||||
is_downloaded: Whether episode is downloaded
|
||||
download_date: When episode was downloaded
|
||||
series: Relationship to AnimeSeries
|
||||
created_at: Creation timestamp (from TimestampMixin)
|
||||
updated_at: Last update timestamp (from TimestampMixin)
|
||||
@ -234,18 +177,10 @@ class Episode(Base, TimestampMixin):
|
||||
String(1000), nullable=True,
|
||||
doc="Local file path"
|
||||
)
|
||||
file_size: Mapped[Optional[int]] = mapped_column(
|
||||
Integer, nullable=True,
|
||||
doc="File size in bytes"
|
||||
)
|
||||
is_downloaded: Mapped[bool] = mapped_column(
|
||||
Boolean, default=False, nullable=False,
|
||||
doc="Whether episode is downloaded"
|
||||
)
|
||||
download_date: Mapped[Optional[datetime]] = mapped_column(
|
||||
DateTime(timezone=True), nullable=True,
|
||||
doc="When episode was downloaded"
|
||||
)
|
||||
|
||||
# Relationship
|
||||
series: Mapped["AnimeSeries"] = relationship(
|
||||
@ -287,13 +222,6 @@ class Episode(Base, TimestampMixin):
|
||||
raise ValueError("File path must be 1000 characters or less")
|
||||
return value
|
||||
|
||||
@validates('file_size')
|
||||
def validate_file_size(self, key: str, value: Optional[int]) -> Optional[int]:
|
||||
"""Validate file size is non-negative."""
|
||||
if value is not None and value < 0:
|
||||
raise ValueError("File size must be non-negative")
|
||||
return value
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"<Episode(id={self.id}, series_id={self.series_id}, "
|
||||
@ -321,27 +249,20 @@ class DownloadPriority(str, Enum):
|
||||
class DownloadQueueItem(Base, TimestampMixin):
|
||||
"""SQLAlchemy model for download queue items.
|
||||
|
||||
Tracks download queue with status, progress, and error information.
|
||||
Tracks download queue with error information.
|
||||
Provides persistence for the DownloadService queue state.
|
||||
|
||||
Attributes:
|
||||
id: Primary key
|
||||
series_id: Foreign key to AnimeSeries
|
||||
season: Season number
|
||||
episode_number: Episode number
|
||||
status: Current download status
|
||||
priority: Download priority
|
||||
progress_percent: Download progress (0-100)
|
||||
downloaded_bytes: Bytes downloaded
|
||||
total_bytes: Total file size
|
||||
download_speed: Current speed in bytes/sec
|
||||
episode_id: Foreign key to Episode
|
||||
error_message: Error description if failed
|
||||
retry_count: Number of retry attempts
|
||||
download_url: Provider download URL
|
||||
file_destination: Target file path
|
||||
started_at: When download started
|
||||
completed_at: When download completed
|
||||
series: Relationship to AnimeSeries
|
||||
episode: Relationship to Episode
|
||||
created_at: Creation timestamp (from TimestampMixin)
|
||||
updated_at: Last update timestamp (from TimestampMixin)
|
||||
"""
|
||||
@ -359,47 +280,11 @@ class DownloadQueueItem(Base, TimestampMixin):
|
||||
index=True
|
||||
)
|
||||
|
||||
# Episode identification
|
||||
season: Mapped[int] = mapped_column(
|
||||
Integer, nullable=False,
|
||||
doc="Season number"
|
||||
)
|
||||
episode_number: Mapped[int] = mapped_column(
|
||||
Integer, nullable=False,
|
||||
doc="Episode number"
|
||||
)
|
||||
|
||||
# Queue management
|
||||
status: Mapped[str] = mapped_column(
|
||||
SQLEnum(DownloadStatus),
|
||||
default=DownloadStatus.PENDING,
|
||||
# Foreign key to episode
|
||||
episode_id: Mapped[int] = mapped_column(
|
||||
ForeignKey("episodes.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
doc="Current download status"
|
||||
)
|
||||
priority: Mapped[str] = mapped_column(
|
||||
SQLEnum(DownloadPriority),
|
||||
default=DownloadPriority.NORMAL,
|
||||
nullable=False,
|
||||
doc="Download priority"
|
||||
)
|
||||
|
||||
# Progress tracking
|
||||
progress_percent: Mapped[float] = mapped_column(
|
||||
Float, default=0.0, nullable=False,
|
||||
doc="Progress percentage (0-100)"
|
||||
)
|
||||
downloaded_bytes: Mapped[int] = mapped_column(
|
||||
Integer, default=0, nullable=False,
|
||||
doc="Bytes downloaded"
|
||||
)
|
||||
total_bytes: Mapped[Optional[int]] = mapped_column(
|
||||
Integer, nullable=True,
|
||||
doc="Total file size"
|
||||
)
|
||||
download_speed: Mapped[Optional[float]] = mapped_column(
|
||||
Float, nullable=True,
|
||||
doc="Current download speed (bytes/sec)"
|
||||
index=True
|
||||
)
|
||||
|
||||
# Error handling
|
||||
@ -407,10 +292,6 @@ class DownloadQueueItem(Base, TimestampMixin):
|
||||
Text, nullable=True,
|
||||
doc="Error description"
|
||||
)
|
||||
retry_count: Mapped[int] = mapped_column(
|
||||
Integer, default=0, nullable=False,
|
||||
doc="Number of retry attempts"
|
||||
)
|
||||
|
||||
# Download details
|
||||
download_url: Mapped[Optional[str]] = mapped_column(
|
||||
@ -437,67 +318,9 @@ class DownloadQueueItem(Base, TimestampMixin):
|
||||
"AnimeSeries",
|
||||
back_populates="download_items"
|
||||
)
|
||||
|
||||
@validates('season')
|
||||
def validate_season(self, key: str, value: int) -> int:
|
||||
"""Validate season number is positive."""
|
||||
if value < 0:
|
||||
raise ValueError("Season number must be non-negative")
|
||||
if value > 1000:
|
||||
raise ValueError("Season number must be 1000 or less")
|
||||
return value
|
||||
|
||||
@validates('episode_number')
|
||||
def validate_episode_number(self, key: str, value: int) -> int:
|
||||
"""Validate episode number is positive."""
|
||||
if value < 0:
|
||||
raise ValueError("Episode number must be non-negative")
|
||||
if value > 10000:
|
||||
raise ValueError("Episode number must be 10000 or less")
|
||||
return value
|
||||
|
||||
@validates('progress_percent')
|
||||
def validate_progress_percent(self, key: str, value: float) -> float:
|
||||
"""Validate progress is between 0 and 100."""
|
||||
if value < 0.0:
|
||||
raise ValueError("Progress percent must be non-negative")
|
||||
if value > 100.0:
|
||||
raise ValueError("Progress percent cannot exceed 100")
|
||||
return value
|
||||
|
||||
@validates('downloaded_bytes')
|
||||
def validate_downloaded_bytes(self, key: str, value: int) -> int:
|
||||
"""Validate downloaded bytes is non-negative."""
|
||||
if value < 0:
|
||||
raise ValueError("Downloaded bytes must be non-negative")
|
||||
return value
|
||||
|
||||
@validates('total_bytes')
|
||||
def validate_total_bytes(
|
||||
self, key: str, value: Optional[int]
|
||||
) -> Optional[int]:
|
||||
"""Validate total bytes is non-negative."""
|
||||
if value is not None and value < 0:
|
||||
raise ValueError("Total bytes must be non-negative")
|
||||
return value
|
||||
|
||||
@validates('download_speed')
|
||||
def validate_download_speed(
|
||||
self, key: str, value: Optional[float]
|
||||
) -> Optional[float]:
|
||||
"""Validate download speed is non-negative."""
|
||||
if value is not None and value < 0.0:
|
||||
raise ValueError("Download speed must be non-negative")
|
||||
return value
|
||||
|
||||
@validates('retry_count')
|
||||
def validate_retry_count(self, key: str, value: int) -> int:
|
||||
"""Validate retry count is non-negative."""
|
||||
if value < 0:
|
||||
raise ValueError("Retry count must be non-negative")
|
||||
if value > 100:
|
||||
raise ValueError("Retry count cannot exceed 100")
|
||||
return value
|
||||
episode: Mapped["Episode"] = relationship(
|
||||
"Episode"
|
||||
)
|
||||
|
||||
@validates('download_url')
|
||||
def validate_download_url(
|
||||
@ -523,8 +346,7 @@ class DownloadQueueItem(Base, TimestampMixin):
|
||||
return (
|
||||
f"<DownloadQueueItem(id={self.id}, "
|
||||
f"series_id={self.series_id}, "
|
||||
f"S{self.season:02d}E{self.episode_number:02d}, "
|
||||
f"status={self.status})>"
|
||||
f"episode_id={self.episode_id})>"
|
||||
)
|
||||
|
||||
|
||||
|
||||
@ -15,7 +15,7 @@ from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Dict, List, Optional
|
||||
from typing import List, Optional
|
||||
|
||||
from sqlalchemy import delete, select, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
@ -23,9 +23,7 @@ from sqlalchemy.orm import Session, selectinload
|
||||
|
||||
from src.server.database.models import (
|
||||
AnimeSeries,
|
||||
DownloadPriority,
|
||||
DownloadQueueItem,
|
||||
DownloadStatus,
|
||||
Episode,
|
||||
UserSession,
|
||||
)
|
||||
@ -57,11 +55,6 @@ class AnimeSeriesService:
|
||||
name: str,
|
||||
site: str,
|
||||
folder: str,
|
||||
description: Optional[str] = None,
|
||||
status: Optional[str] = None,
|
||||
total_episodes: Optional[int] = None,
|
||||
cover_url: Optional[str] = None,
|
||||
episode_dict: Optional[Dict] = None,
|
||||
) -> AnimeSeries:
|
||||
"""Create a new anime series.
|
||||
|
||||
@ -71,11 +64,6 @@ class AnimeSeriesService:
|
||||
name: Series name
|
||||
site: Provider site URL
|
||||
folder: Local filesystem path
|
||||
description: Optional series description
|
||||
status: Optional series status
|
||||
total_episodes: Optional total episode count
|
||||
cover_url: Optional cover image URL
|
||||
episode_dict: Optional episode dictionary
|
||||
|
||||
Returns:
|
||||
Created AnimeSeries instance
|
||||
@ -88,11 +76,6 @@ class AnimeSeriesService:
|
||||
name=name,
|
||||
site=site,
|
||||
folder=folder,
|
||||
description=description,
|
||||
status=status,
|
||||
total_episodes=total_episodes,
|
||||
cover_url=cover_url,
|
||||
episode_dict=episode_dict,
|
||||
)
|
||||
db.add(series)
|
||||
await db.flush()
|
||||
@ -262,7 +245,6 @@ class EpisodeService:
|
||||
episode_number: int,
|
||||
title: Optional[str] = None,
|
||||
file_path: Optional[str] = None,
|
||||
file_size: Optional[int] = None,
|
||||
is_downloaded: bool = False,
|
||||
) -> Episode:
|
||||
"""Create a new episode.
|
||||
@ -274,7 +256,6 @@ class EpisodeService:
|
||||
episode_number: Episode number within season
|
||||
title: Optional episode title
|
||||
file_path: Optional local file path
|
||||
file_size: Optional file size in bytes
|
||||
is_downloaded: Whether episode is downloaded
|
||||
|
||||
Returns:
|
||||
@ -286,9 +267,7 @@ class EpisodeService:
|
||||
episode_number=episode_number,
|
||||
title=title,
|
||||
file_path=file_path,
|
||||
file_size=file_size,
|
||||
is_downloaded=is_downloaded,
|
||||
download_date=datetime.now(timezone.utc) if is_downloaded else None,
|
||||
)
|
||||
db.add(episode)
|
||||
await db.flush()
|
||||
@ -372,7 +351,6 @@ class EpisodeService:
|
||||
db: AsyncSession,
|
||||
episode_id: int,
|
||||
file_path: str,
|
||||
file_size: int,
|
||||
) -> Optional[Episode]:
|
||||
"""Mark episode as downloaded.
|
||||
|
||||
@ -380,7 +358,6 @@ class EpisodeService:
|
||||
db: Database session
|
||||
episode_id: Episode primary key
|
||||
file_path: Local file path
|
||||
file_size: File size in bytes
|
||||
|
||||
Returns:
|
||||
Updated Episode instance or None if not found
|
||||
@ -391,8 +368,6 @@ class EpisodeService:
|
||||
|
||||
episode.is_downloaded = True
|
||||
episode.file_path = file_path
|
||||
episode.file_size = file_size
|
||||
episode.download_date = datetime.now(timezone.utc)
|
||||
|
||||
await db.flush()
|
||||
await db.refresh(episode)
|
||||
@ -427,17 +402,14 @@ class EpisodeService:
|
||||
class DownloadQueueService:
|
||||
"""Service for download queue CRUD operations.
|
||||
|
||||
Provides methods for managing the download queue with status tracking,
|
||||
priority management, and progress updates.
|
||||
Provides methods for managing the download queue.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
async def create(
|
||||
db: AsyncSession,
|
||||
series_id: int,
|
||||
season: int,
|
||||
episode_number: int,
|
||||
priority: DownloadPriority = DownloadPriority.NORMAL,
|
||||
episode_id: int,
|
||||
download_url: Optional[str] = None,
|
||||
file_destination: Optional[str] = None,
|
||||
) -> DownloadQueueItem:
|
||||
@ -446,9 +418,7 @@ class DownloadQueueService:
|
||||
Args:
|
||||
db: Database session
|
||||
series_id: Foreign key to AnimeSeries
|
||||
season: Season number
|
||||
episode_number: Episode number
|
||||
priority: Download priority
|
||||
episode_id: Foreign key to Episode
|
||||
download_url: Optional provider download URL
|
||||
file_destination: Optional target file path
|
||||
|
||||
@ -457,10 +427,7 @@ class DownloadQueueService:
|
||||
"""
|
||||
item = DownloadQueueItem(
|
||||
series_id=series_id,
|
||||
season=season,
|
||||
episode_number=episode_number,
|
||||
status=DownloadStatus.PENDING,
|
||||
priority=priority,
|
||||
episode_id=episode_id,
|
||||
download_url=download_url,
|
||||
file_destination=file_destination,
|
||||
)
|
||||
@ -468,8 +435,8 @@ class DownloadQueueService:
|
||||
await db.flush()
|
||||
await db.refresh(item)
|
||||
logger.info(
|
||||
f"Added to download queue: S{season:02d}E{episode_number:02d} "
|
||||
f"for series_id={series_id} with priority={priority}"
|
||||
f"Added to download queue: episode_id={episode_id} "
|
||||
f"for series_id={series_id}"
|
||||
)
|
||||
return item
|
||||
|
||||
@ -493,68 +460,25 @@ class DownloadQueueService:
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
@staticmethod
|
||||
async def get_by_status(
|
||||
async def get_by_episode(
|
||||
db: AsyncSession,
|
||||
status: DownloadStatus,
|
||||
limit: Optional[int] = None,
|
||||
) -> List[DownloadQueueItem]:
|
||||
"""Get download queue items by status.
|
||||
episode_id: int,
|
||||
) -> Optional[DownloadQueueItem]:
|
||||
"""Get download queue item by episode ID.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
status: Download status filter
|
||||
limit: Optional limit for results
|
||||
episode_id: Foreign key to Episode
|
||||
|
||||
Returns:
|
||||
List of DownloadQueueItem instances
|
||||
DownloadQueueItem instance or None if not found
|
||||
"""
|
||||
query = select(DownloadQueueItem).where(
|
||||
DownloadQueueItem.status == status
|
||||
)
|
||||
|
||||
# Order by priority (HIGH first) then creation time
|
||||
query = query.order_by(
|
||||
DownloadQueueItem.priority.desc(),
|
||||
DownloadQueueItem.created_at.asc(),
|
||||
)
|
||||
|
||||
if limit:
|
||||
query = query.limit(limit)
|
||||
|
||||
result = await db.execute(query)
|
||||
return list(result.scalars().all())
|
||||
|
||||
@staticmethod
|
||||
async def get_pending(
|
||||
db: AsyncSession,
|
||||
limit: Optional[int] = None,
|
||||
) -> List[DownloadQueueItem]:
|
||||
"""Get pending download queue items.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
limit: Optional limit for results
|
||||
|
||||
Returns:
|
||||
List of pending DownloadQueueItem instances ordered by priority
|
||||
"""
|
||||
return await DownloadQueueService.get_by_status(
|
||||
db, DownloadStatus.PENDING, limit
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def get_active(db: AsyncSession) -> List[DownloadQueueItem]:
|
||||
"""Get active download queue items.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
List of downloading DownloadQueueItem instances
|
||||
"""
|
||||
return await DownloadQueueService.get_by_status(
|
||||
db, DownloadStatus.DOWNLOADING
|
||||
result = await db.execute(
|
||||
select(DownloadQueueItem).where(
|
||||
DownloadQueueItem.episode_id == episode_id
|
||||
)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
@staticmethod
|
||||
async def get_all(
|
||||
@ -576,7 +500,6 @@ class DownloadQueueService:
|
||||
query = query.options(selectinload(DownloadQueueItem.series))
|
||||
|
||||
query = query.order_by(
|
||||
DownloadQueueItem.priority.desc(),
|
||||
DownloadQueueItem.created_at.asc(),
|
||||
)
|
||||
|
||||
@ -584,19 +507,17 @@ class DownloadQueueService:
|
||||
return list(result.scalars().all())
|
||||
|
||||
@staticmethod
|
||||
async def update_status(
|
||||
async def set_error(
|
||||
db: AsyncSession,
|
||||
item_id: int,
|
||||
status: DownloadStatus,
|
||||
error_message: Optional[str] = None,
|
||||
error_message: str,
|
||||
) -> Optional[DownloadQueueItem]:
|
||||
"""Update download queue item status.
|
||||
"""Set error message on download queue item.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
item_id: Item primary key
|
||||
status: New download status
|
||||
error_message: Optional error message for failed status
|
||||
error_message: Error description
|
||||
|
||||
Returns:
|
||||
Updated DownloadQueueItem instance or None if not found
|
||||
@ -605,61 +526,11 @@ class DownloadQueueService:
|
||||
if not item:
|
||||
return None
|
||||
|
||||
item.status = status
|
||||
|
||||
# Update timestamps based on status
|
||||
if status == DownloadStatus.DOWNLOADING and not item.started_at:
|
||||
item.started_at = datetime.now(timezone.utc)
|
||||
elif status in (DownloadStatus.COMPLETED, DownloadStatus.FAILED):
|
||||
item.completed_at = datetime.now(timezone.utc)
|
||||
|
||||
# Set error message for failed downloads
|
||||
if status == DownloadStatus.FAILED and error_message:
|
||||
item.error_message = error_message
|
||||
item.retry_count += 1
|
||||
|
||||
await db.flush()
|
||||
await db.refresh(item)
|
||||
logger.debug(f"Updated download queue item {item_id} status to {status}")
|
||||
return item
|
||||
|
||||
@staticmethod
|
||||
async def update_progress(
|
||||
db: AsyncSession,
|
||||
item_id: int,
|
||||
progress_percent: float,
|
||||
downloaded_bytes: int,
|
||||
total_bytes: Optional[int] = None,
|
||||
download_speed: Optional[float] = None,
|
||||
) -> Optional[DownloadQueueItem]:
|
||||
"""Update download progress.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
item_id: Item primary key
|
||||
progress_percent: Progress percentage (0-100)
|
||||
downloaded_bytes: Bytes downloaded
|
||||
total_bytes: Optional total file size
|
||||
download_speed: Optional current speed (bytes/sec)
|
||||
|
||||
Returns:
|
||||
Updated DownloadQueueItem instance or None if not found
|
||||
"""
|
||||
item = await DownloadQueueService.get_by_id(db, item_id)
|
||||
if not item:
|
||||
return None
|
||||
|
||||
item.progress_percent = progress_percent
|
||||
item.downloaded_bytes = downloaded_bytes
|
||||
|
||||
if total_bytes is not None:
|
||||
item.total_bytes = total_bytes
|
||||
|
||||
if download_speed is not None:
|
||||
item.download_speed = download_speed
|
||||
item.error_message = error_message
|
||||
|
||||
await db.flush()
|
||||
await db.refresh(item)
|
||||
logger.debug(f"Set error on download queue item {item_id}")
|
||||
return item
|
||||
|
||||
@staticmethod
|
||||
@ -682,57 +553,30 @@ class DownloadQueueService:
|
||||
return deleted
|
||||
|
||||
@staticmethod
|
||||
async def clear_completed(db: AsyncSession) -> int:
|
||||
"""Clear completed downloads from queue.
|
||||
async def delete_by_episode(
|
||||
db: AsyncSession,
|
||||
episode_id: int,
|
||||
) -> bool:
|
||||
"""Delete download queue item by episode ID.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
episode_id: Foreign key to Episode
|
||||
|
||||
Returns:
|
||||
Number of items cleared
|
||||
True if deleted, False if not found
|
||||
"""
|
||||
result = await db.execute(
|
||||
delete(DownloadQueueItem).where(
|
||||
DownloadQueueItem.status == DownloadStatus.COMPLETED
|
||||
DownloadQueueItem.episode_id == episode_id
|
||||
)
|
||||
)
|
||||
count = result.rowcount
|
||||
logger.info(f"Cleared {count} completed downloads from queue")
|
||||
return count
|
||||
|
||||
@staticmethod
|
||||
async def retry_failed(
|
||||
db: AsyncSession,
|
||||
max_retries: int = 3,
|
||||
) -> List[DownloadQueueItem]:
|
||||
"""Retry failed downloads that haven't exceeded max retries.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
max_retries: Maximum number of retry attempts
|
||||
|
||||
Returns:
|
||||
List of items marked for retry
|
||||
"""
|
||||
result = await db.execute(
|
||||
select(DownloadQueueItem).where(
|
||||
DownloadQueueItem.status == DownloadStatus.FAILED,
|
||||
DownloadQueueItem.retry_count < max_retries,
|
||||
deleted = result.rowcount > 0
|
||||
if deleted:
|
||||
logger.info(
|
||||
f"Deleted download queue item with episode_id={episode_id}"
|
||||
)
|
||||
)
|
||||
items = list(result.scalars().all())
|
||||
|
||||
for item in items:
|
||||
item.status = DownloadStatus.PENDING
|
||||
item.error_message = None
|
||||
item.progress_percent = 0.0
|
||||
item.downloaded_bytes = 0
|
||||
item.started_at = None
|
||||
item.completed_at = None
|
||||
|
||||
await db.flush()
|
||||
logger.info(f"Marked {len(items)} failed downloads for retry")
|
||||
return items
|
||||
return deleted
|
||||
|
||||
|
||||
# ============================================================================
|
||||
|
||||
@ -70,8 +70,6 @@ class AnimeSeriesResponse(BaseModel):
|
||||
)
|
||||
)
|
||||
alt_titles: List[str] = Field(default_factory=list, description="Alternative titles")
|
||||
description: Optional[str] = Field(None, description="Short series description")
|
||||
total_episodes: Optional[int] = Field(None, ge=0, description="Declared total episode count if known")
|
||||
episodes: List[EpisodeInfo] = Field(default_factory=list, description="Known episodes information")
|
||||
missing_episodes: List[MissingEpisodeInfo] = Field(default_factory=list, description="Detected missing episode ranges")
|
||||
thumbnail: Optional[HttpUrl] = Field(None, description="Optional thumbnail image URL")
|
||||
|
||||
@ -22,7 +22,7 @@ from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from src.core.entities.series import Serie
|
||||
from src.server.database.service import AnimeSeriesService
|
||||
from src.server.database.service import AnimeSeriesService, EpisodeService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -206,7 +206,7 @@ class DataMigrationService:
|
||||
|
||||
Reads the data file, checks if the series already exists in the
|
||||
database, and creates a new record if it doesn't exist. If the
|
||||
series exists, optionally updates the episode_dict if changed.
|
||||
series exists, optionally updates the episodes if changed.
|
||||
|
||||
Args:
|
||||
data_path: Path to the data file
|
||||
@ -229,41 +229,44 @@ class DataMigrationService:
|
||||
existing = await AnimeSeriesService.get_by_key(db, serie.key)
|
||||
|
||||
if existing is not None:
|
||||
# Check if episode_dict has changed
|
||||
existing_dict = existing.episode_dict or {}
|
||||
# Build episode dict from existing episodes for comparison
|
||||
existing_dict: dict[int, list[int]] = {}
|
||||
episodes = await EpisodeService.get_by_series(db, existing.id)
|
||||
for ep in episodes:
|
||||
if ep.season not in existing_dict:
|
||||
existing_dict[ep.season] = []
|
||||
existing_dict[ep.season].append(ep.episode_number)
|
||||
for season in existing_dict:
|
||||
existing_dict[season].sort()
|
||||
|
||||
new_dict = serie.episodeDict or {}
|
||||
|
||||
# Convert keys to strings for comparison (JSON stores keys as strings)
|
||||
new_dict_str_keys = {
|
||||
str(k): v for k, v in new_dict.items()
|
||||
}
|
||||
|
||||
if existing_dict == new_dict_str_keys:
|
||||
if existing_dict == new_dict:
|
||||
logger.debug(
|
||||
"Series '%s' already exists with same data, skipping",
|
||||
serie.key
|
||||
)
|
||||
return False
|
||||
|
||||
# Update episode_dict if different
|
||||
await AnimeSeriesService.update(
|
||||
db,
|
||||
existing.id,
|
||||
episode_dict=new_dict_str_keys
|
||||
)
|
||||
# Update episodes if different - add new episodes
|
||||
for season, episode_numbers in new_dict.items():
|
||||
existing_eps = set(existing_dict.get(season, []))
|
||||
for ep_num in episode_numbers:
|
||||
if ep_num not in existing_eps:
|
||||
await EpisodeService.create(
|
||||
db=db,
|
||||
series_id=existing.id,
|
||||
season=season,
|
||||
episode_number=ep_num,
|
||||
)
|
||||
logger.info(
|
||||
"Updated episode_dict for existing series '%s'",
|
||||
"Updated episodes for existing series '%s'",
|
||||
serie.key
|
||||
)
|
||||
return True
|
||||
|
||||
# Create new series in database
|
||||
try:
|
||||
# Convert episode_dict keys to strings for JSON storage
|
||||
episode_dict_for_db = {
|
||||
str(k): v for k, v in (serie.episodeDict or {}).items()
|
||||
}
|
||||
|
||||
# Use folder as fallback name if name is empty
|
||||
series_name = serie.name
|
||||
if not series_name or not series_name.strip():
|
||||
@ -274,14 +277,25 @@ class DataMigrationService:
|
||||
serie.key
|
||||
)
|
||||
|
||||
await AnimeSeriesService.create(
|
||||
anime_series = await AnimeSeriesService.create(
|
||||
db,
|
||||
key=serie.key,
|
||||
name=series_name,
|
||||
site=serie.site,
|
||||
folder=serie.folder,
|
||||
episode_dict=episode_dict_for_db,
|
||||
)
|
||||
|
||||
# Create Episode records for each episode in episodeDict
|
||||
if serie.episodeDict:
|
||||
for season, episode_numbers in serie.episodeDict.items():
|
||||
for episode_number in episode_numbers:
|
||||
await EpisodeService.create(
|
||||
db=db,
|
||||
series_id=anime_series.id,
|
||||
season=season,
|
||||
episode_number=episode_number,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Migrated series '%s' to database",
|
||||
serie.key
|
||||
|
||||
@ -153,29 +153,40 @@ class TestMigrationIdempotency:
|
||||
}
|
||||
(series_dir / "data").write_text(json.dumps(data))
|
||||
|
||||
# Mock existing series in database
|
||||
# Mock existing series in database with same episodes
|
||||
existing = MagicMock()
|
||||
existing.id = 1
|
||||
existing.episode_dict = {"1": [1, 2]} # Same data
|
||||
|
||||
# Mock episodes matching data file
|
||||
mock_episodes = [
|
||||
MagicMock(season=1, episode_number=1),
|
||||
MagicMock(season=1, episode_number=2),
|
||||
]
|
||||
|
||||
service = DataMigrationService()
|
||||
|
||||
with patch(
|
||||
'src.server.services.data_migration_service.AnimeSeriesService'
|
||||
) as MockService:
|
||||
MockService.get_by_key = AsyncMock(return_value=existing)
|
||||
|
||||
mock_db = AsyncMock()
|
||||
mock_db.commit = AsyncMock()
|
||||
|
||||
result = await service.migrate_all(tmp_dir, mock_db)
|
||||
|
||||
# Should skip since data is same
|
||||
assert result.total_found == 1
|
||||
assert result.skipped == 1
|
||||
assert result.migrated == 0
|
||||
# Should not call create
|
||||
MockService.create.assert_not_called()
|
||||
with patch(
|
||||
'src.server.services.data_migration_service.EpisodeService'
|
||||
) as MockEpisodeService:
|
||||
MockService.get_by_key = AsyncMock(return_value=existing)
|
||||
MockEpisodeService.get_by_series = AsyncMock(
|
||||
return_value=mock_episodes
|
||||
)
|
||||
|
||||
mock_db = AsyncMock()
|
||||
mock_db.commit = AsyncMock()
|
||||
|
||||
result = await service.migrate_all(tmp_dir, mock_db)
|
||||
|
||||
# Should skip since data is same
|
||||
assert result.total_found == 1
|
||||
assert result.skipped == 1
|
||||
assert result.migrated == 0
|
||||
# Should not call create
|
||||
MockService.create.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_migration_updates_changed_episodes(self):
|
||||
@ -196,25 +207,37 @@ class TestMigrationIdempotency:
|
||||
# Mock existing series with fewer episodes
|
||||
existing = MagicMock()
|
||||
existing.id = 1
|
||||
existing.episode_dict = {"1": [1, 2]} # Fewer episodes
|
||||
|
||||
# Mock existing episodes (fewer than data file)
|
||||
mock_episodes = [
|
||||
MagicMock(season=1, episode_number=1),
|
||||
MagicMock(season=1, episode_number=2),
|
||||
]
|
||||
|
||||
service = DataMigrationService()
|
||||
|
||||
with patch(
|
||||
'src.server.services.data_migration_service.AnimeSeriesService'
|
||||
) as MockService:
|
||||
MockService.get_by_key = AsyncMock(return_value=existing)
|
||||
MockService.update = AsyncMock()
|
||||
|
||||
mock_db = AsyncMock()
|
||||
mock_db.commit = AsyncMock()
|
||||
|
||||
result = await service.migrate_all(tmp_dir, mock_db)
|
||||
|
||||
# Should update since data changed
|
||||
assert result.total_found == 1
|
||||
assert result.migrated == 1
|
||||
MockService.update.assert_called_once()
|
||||
with patch(
|
||||
'src.server.services.data_migration_service.EpisodeService'
|
||||
) as MockEpisodeService:
|
||||
MockService.get_by_key = AsyncMock(return_value=existing)
|
||||
MockEpisodeService.get_by_series = AsyncMock(
|
||||
return_value=mock_episodes
|
||||
)
|
||||
MockEpisodeService.create = AsyncMock()
|
||||
|
||||
mock_db = AsyncMock()
|
||||
mock_db.commit = AsyncMock()
|
||||
|
||||
result = await service.migrate_all(tmp_dir, mock_db)
|
||||
|
||||
# Should update since data changed
|
||||
assert result.total_found == 1
|
||||
assert result.migrated == 1
|
||||
# Should create 3 new episodes (3, 4, 5)
|
||||
assert MockEpisodeService.create.call_count == 3
|
||||
|
||||
|
||||
class TestMigrationOnFreshStart:
|
||||
@ -348,13 +371,18 @@ class TestSerieListReadsFromDatabase:
|
||||
# Create mock series in database with spec to avoid mock attributes
|
||||
from dataclasses import dataclass
|
||||
|
||||
@dataclass
|
||||
class MockEpisode:
|
||||
season: int
|
||||
episode_number: int
|
||||
|
||||
@dataclass
|
||||
class MockAnimeSeries:
|
||||
key: str
|
||||
name: str
|
||||
site: str
|
||||
folder: str
|
||||
episode_dict: dict
|
||||
episodes: list
|
||||
|
||||
mock_series = [
|
||||
MockAnimeSeries(
|
||||
@ -362,14 +390,18 @@ class TestSerieListReadsFromDatabase:
|
||||
name="Anime 1",
|
||||
site="aniworld.to",
|
||||
folder="Anime 1",
|
||||
episode_dict={"1": [1, 2, 3]}
|
||||
episodes=[
|
||||
MockEpisode(1, 1), MockEpisode(1, 2), MockEpisode(1, 3)
|
||||
]
|
||||
),
|
||||
MockAnimeSeries(
|
||||
key="anime-2",
|
||||
name="Anime 2",
|
||||
site="aniworld.to",
|
||||
folder="Anime 2",
|
||||
episode_dict={"1": [1, 2], "2": [1]}
|
||||
episodes=[
|
||||
MockEpisode(1, 1), MockEpisode(1, 2), MockEpisode(2, 1)
|
||||
]
|
||||
)
|
||||
]
|
||||
|
||||
@ -389,8 +421,8 @@ class TestSerieListReadsFromDatabase:
|
||||
# Load from database
|
||||
await serie_list.load_series_from_db(mock_db)
|
||||
|
||||
# Verify service was called
|
||||
mock_get_all.assert_called_once_with(mock_db)
|
||||
# Verify service was called with with_episodes=True
|
||||
mock_get_all.assert_called_once_with(mock_db, with_episodes=True)
|
||||
|
||||
# Verify series were loaded
|
||||
all_series = serie_list.get_all()
|
||||
|
||||
@ -65,7 +65,6 @@ class TestAnimeSeriesResponse:
|
||||
title="Attack on Titan",
|
||||
folder="Attack on Titan (2013)",
|
||||
episodes=[ep],
|
||||
total_episodes=12,
|
||||
)
|
||||
|
||||
assert series.key == "attack-on-titan"
|
||||
|
||||
@ -304,10 +304,18 @@ class TestDataMigrationServiceMigrateSingle:
|
||||
"""Test migrating series that already exists with same data."""
|
||||
service = DataMigrationService()
|
||||
|
||||
# Create mock existing series with same episode_dict
|
||||
# Create mock existing series with same episodes
|
||||
existing = MagicMock()
|
||||
existing.id = 1
|
||||
existing.episode_dict = {"1": [1, 2, 3], "2": [1, 2]}
|
||||
|
||||
# Mock episodes matching sample_serie.episodeDict = {1: [1, 2, 3], 2: [1, 2]}
|
||||
mock_episodes = []
|
||||
for season, eps in {1: [1, 2, 3], 2: [1, 2]}.items():
|
||||
for ep_num in eps:
|
||||
mock_ep = MagicMock()
|
||||
mock_ep.season = season
|
||||
mock_ep.episode_number = ep_num
|
||||
mock_episodes.append(mock_ep)
|
||||
|
||||
with patch.object(
|
||||
service,
|
||||
@ -317,19 +325,25 @@ class TestDataMigrationServiceMigrateSingle:
|
||||
with patch(
|
||||
'src.server.services.data_migration_service.AnimeSeriesService'
|
||||
) as MockService:
|
||||
MockService.get_by_key = AsyncMock(return_value=existing)
|
||||
|
||||
result = await service.migrate_data_file(
|
||||
Path("/fake/data"),
|
||||
mock_db
|
||||
)
|
||||
|
||||
assert result is False
|
||||
MockService.create.assert_not_called()
|
||||
with patch(
|
||||
'src.server.services.data_migration_service.EpisodeService'
|
||||
) as MockEpisodeService:
|
||||
MockService.get_by_key = AsyncMock(return_value=existing)
|
||||
MockEpisodeService.get_by_series = AsyncMock(
|
||||
return_value=mock_episodes
|
||||
)
|
||||
|
||||
result = await service.migrate_data_file(
|
||||
Path("/fake/data"),
|
||||
mock_db
|
||||
)
|
||||
|
||||
assert result is False
|
||||
MockService.create.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_migrate_existing_series_different_data(self, mock_db):
|
||||
"""Test migrating series that exists with different episode_dict."""
|
||||
"""Test migrating series that exists with different episodes."""
|
||||
service = DataMigrationService()
|
||||
|
||||
# Serie with new episodes
|
||||
@ -344,7 +358,14 @@ class TestDataMigrationServiceMigrateSingle:
|
||||
# Existing series has fewer episodes
|
||||
existing = MagicMock()
|
||||
existing.id = 1
|
||||
existing.episode_dict = {"1": [1, 2, 3]}
|
||||
|
||||
# Mock episodes for existing (only 3 episodes)
|
||||
mock_episodes = []
|
||||
for ep_num in [1, 2, 3]:
|
||||
mock_ep = MagicMock()
|
||||
mock_ep.season = 1
|
||||
mock_ep.episode_number = ep_num
|
||||
mock_episodes.append(mock_ep)
|
||||
|
||||
with patch.object(
|
||||
service,
|
||||
@ -354,16 +375,23 @@ class TestDataMigrationServiceMigrateSingle:
|
||||
with patch(
|
||||
'src.server.services.data_migration_service.AnimeSeriesService'
|
||||
) as MockService:
|
||||
MockService.get_by_key = AsyncMock(return_value=existing)
|
||||
MockService.update = AsyncMock()
|
||||
|
||||
result = await service.migrate_data_file(
|
||||
Path("/fake/data"),
|
||||
mock_db
|
||||
)
|
||||
|
||||
assert result is True
|
||||
MockService.update.assert_called_once()
|
||||
with patch(
|
||||
'src.server.services.data_migration_service.EpisodeService'
|
||||
) as MockEpisodeService:
|
||||
MockService.get_by_key = AsyncMock(return_value=existing)
|
||||
MockEpisodeService.get_by_series = AsyncMock(
|
||||
return_value=mock_episodes
|
||||
)
|
||||
MockEpisodeService.create = AsyncMock()
|
||||
|
||||
result = await service.migrate_data_file(
|
||||
Path("/fake/data"),
|
||||
mock_db
|
||||
)
|
||||
|
||||
assert result is True
|
||||
# Should create 2 new episodes (4 and 5)
|
||||
assert MockEpisodeService.create.call_count == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_migrate_read_error(self, mock_db):
|
||||
@ -493,21 +521,26 @@ class TestDataMigrationServiceMigrateAll:
|
||||
# Mock: first series doesn't exist, second already exists
|
||||
existing = MagicMock()
|
||||
existing.id = 2
|
||||
existing.episode_dict = {}
|
||||
|
||||
with patch(
|
||||
'src.server.services.data_migration_service.AnimeSeriesService'
|
||||
) as MockService:
|
||||
MockService.get_by_key = AsyncMock(
|
||||
side_effect=[None, existing]
|
||||
)
|
||||
MockService.create = AsyncMock()
|
||||
|
||||
result = await service.migrate_all(tmp_dir, mock_db)
|
||||
|
||||
assert result.total_found == 2
|
||||
assert result.migrated == 1
|
||||
assert result.skipped == 1
|
||||
with patch(
|
||||
'src.server.services.data_migration_service.EpisodeService'
|
||||
) as MockEpisodeService:
|
||||
MockService.get_by_key = AsyncMock(
|
||||
side_effect=[None, existing]
|
||||
)
|
||||
MockService.create = AsyncMock(
|
||||
return_value=MagicMock(id=1)
|
||||
)
|
||||
MockEpisodeService.get_by_series = AsyncMock(return_value=[])
|
||||
|
||||
result = await service.migrate_all(tmp_dir, mock_db)
|
||||
|
||||
assert result.total_found == 2
|
||||
assert result.migrated == 1
|
||||
assert result.skipped == 1
|
||||
|
||||
|
||||
class TestDataMigrationServiceIsMigrationNeeded:
|
||||
|
||||
@ -14,9 +14,7 @@ from sqlalchemy.orm import Session, sessionmaker
|
||||
from src.server.database.base import Base, SoftDeleteMixin, TimestampMixin
|
||||
from src.server.database.models import (
|
||||
AnimeSeries,
|
||||
DownloadPriority,
|
||||
DownloadQueueItem,
|
||||
DownloadStatus,
|
||||
Episode,
|
||||
UserSession,
|
||||
)
|
||||
@ -49,11 +47,6 @@ class TestAnimeSeries:
|
||||
name="Attack on Titan",
|
||||
site="https://aniworld.to",
|
||||
folder="/anime/attack-on-titan",
|
||||
description="Epic anime about titans",
|
||||
status="completed",
|
||||
total_episodes=75,
|
||||
cover_url="https://example.com/cover.jpg",
|
||||
episode_dict={1: [1, 2, 3], 2: [1, 2, 3, 4]},
|
||||
)
|
||||
|
||||
db_session.add(series)
|
||||
@ -172,9 +165,7 @@ class TestEpisode:
|
||||
episode_number=5,
|
||||
title="The Fifth Episode",
|
||||
file_path="/anime/test/S01E05.mp4",
|
||||
file_size=524288000, # 500 MB
|
||||
is_downloaded=True,
|
||||
download_date=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
db_session.add(episode)
|
||||
@ -225,17 +216,17 @@ class TestDownloadQueueItem:
|
||||
db_session.add(series)
|
||||
db_session.commit()
|
||||
|
||||
item = DownloadQueueItem(
|
||||
episode = Episode(
|
||||
series_id=series.id,
|
||||
season=1,
|
||||
episode_number=3,
|
||||
status=DownloadStatus.DOWNLOADING,
|
||||
priority=DownloadPriority.HIGH,
|
||||
progress_percent=45.5,
|
||||
downloaded_bytes=250000000,
|
||||
total_bytes=550000000,
|
||||
download_speed=2500000.0,
|
||||
retry_count=0,
|
||||
)
|
||||
db_session.add(episode)
|
||||
db_session.commit()
|
||||
|
||||
item = DownloadQueueItem(
|
||||
series_id=series.id,
|
||||
episode_id=episode.id,
|
||||
download_url="https://example.com/download/ep3",
|
||||
file_destination="/anime/download/S01E03.mp4",
|
||||
)
|
||||
@ -245,37 +236,38 @@ class TestDownloadQueueItem:
|
||||
|
||||
# Verify saved
|
||||
assert item.id is not None
|
||||
assert item.status == DownloadStatus.DOWNLOADING
|
||||
assert item.priority == DownloadPriority.HIGH
|
||||
assert item.progress_percent == 45.5
|
||||
assert item.retry_count == 0
|
||||
assert item.episode_id == episode.id
|
||||
assert item.series_id == series.id
|
||||
|
||||
def test_download_item_status_enum(self, db_session: Session):
|
||||
"""Test download status enum values."""
|
||||
def test_download_item_episode_relationship(self, db_session: Session):
|
||||
"""Test download item episode relationship."""
|
||||
series = AnimeSeries(
|
||||
key="status-test",
|
||||
name="Status Test",
|
||||
key="relationship-test",
|
||||
name="Relationship Test",
|
||||
site="https://example.com",
|
||||
folder="/anime/status",
|
||||
folder="/anime/relationship",
|
||||
)
|
||||
db_session.add(series)
|
||||
db_session.commit()
|
||||
|
||||
item = DownloadQueueItem(
|
||||
episode = Episode(
|
||||
series_id=series.id,
|
||||
season=1,
|
||||
episode_number=1,
|
||||
status=DownloadStatus.PENDING,
|
||||
)
|
||||
db_session.add(episode)
|
||||
db_session.commit()
|
||||
|
||||
item = DownloadQueueItem(
|
||||
series_id=series.id,
|
||||
episode_id=episode.id,
|
||||
)
|
||||
db_session.add(item)
|
||||
db_session.commit()
|
||||
|
||||
# Update status
|
||||
item.status = DownloadStatus.COMPLETED
|
||||
db_session.commit()
|
||||
|
||||
# Verify status change
|
||||
assert item.status == DownloadStatus.COMPLETED
|
||||
# Verify relationship
|
||||
assert item.episode.id == episode.id
|
||||
assert item.series.id == series.id
|
||||
|
||||
def test_download_item_error_handling(self, db_session: Session):
|
||||
"""Test download item with error information."""
|
||||
@ -288,21 +280,24 @@ class TestDownloadQueueItem:
|
||||
db_session.add(series)
|
||||
db_session.commit()
|
||||
|
||||
item = DownloadQueueItem(
|
||||
episode = Episode(
|
||||
series_id=series.id,
|
||||
season=1,
|
||||
episode_number=1,
|
||||
status=DownloadStatus.FAILED,
|
||||
)
|
||||
db_session.add(episode)
|
||||
db_session.commit()
|
||||
|
||||
item = DownloadQueueItem(
|
||||
series_id=series.id,
|
||||
episode_id=episode.id,
|
||||
error_message="Network timeout after 30 seconds",
|
||||
retry_count=2,
|
||||
)
|
||||
db_session.add(item)
|
||||
db_session.commit()
|
||||
|
||||
# Verify error info
|
||||
assert item.status == DownloadStatus.FAILED
|
||||
assert item.error_message == "Network timeout after 30 seconds"
|
||||
assert item.retry_count == 2
|
||||
|
||||
|
||||
class TestUserSession:
|
||||
@ -502,32 +497,31 @@ class TestDatabaseQueries:
|
||||
db_session.add(series)
|
||||
db_session.commit()
|
||||
|
||||
# Create items with different statuses
|
||||
for i, status in enumerate([
|
||||
DownloadStatus.PENDING,
|
||||
DownloadStatus.DOWNLOADING,
|
||||
DownloadStatus.COMPLETED,
|
||||
]):
|
||||
item = DownloadQueueItem(
|
||||
# Create episodes and items
|
||||
for i in range(3):
|
||||
episode = Episode(
|
||||
series_id=series.id,
|
||||
season=1,
|
||||
episode_number=i + 1,
|
||||
status=status,
|
||||
)
|
||||
db_session.add(episode)
|
||||
db_session.commit()
|
||||
|
||||
item = DownloadQueueItem(
|
||||
series_id=series.id,
|
||||
episode_id=episode.id,
|
||||
)
|
||||
db_session.add(item)
|
||||
db_session.commit()
|
||||
|
||||
# Query pending items
|
||||
# Query all items
|
||||
result = db_session.execute(
|
||||
select(DownloadQueueItem).where(
|
||||
DownloadQueueItem.status == DownloadStatus.PENDING
|
||||
)
|
||||
select(DownloadQueueItem)
|
||||
)
|
||||
pending = result.scalars().all()
|
||||
items = result.scalars().all()
|
||||
|
||||
# Verify query
|
||||
assert len(pending) == 1
|
||||
assert pending[0].episode_number == 1
|
||||
assert len(items) == 3
|
||||
|
||||
def test_query_active_sessions(self, db_session: Session):
|
||||
"""Test querying active user sessions."""
|
||||
|
||||
@ -10,7 +10,6 @@ from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from src.server.database.base import Base
|
||||
from src.server.database.models import DownloadPriority, DownloadStatus
|
||||
from src.server.database.service import (
|
||||
AnimeSeriesService,
|
||||
DownloadQueueService,
|
||||
@ -65,17 +64,11 @@ async def test_create_anime_series(db_session):
|
||||
name="Test Anime",
|
||||
site="https://example.com",
|
||||
folder="/path/to/anime",
|
||||
description="A test anime",
|
||||
status="ongoing",
|
||||
total_episodes=12,
|
||||
cover_url="https://example.com/cover.jpg",
|
||||
)
|
||||
|
||||
assert series.id is not None
|
||||
assert series.key == "test-anime-1"
|
||||
assert series.name == "Test Anime"
|
||||
assert series.description == "A test anime"
|
||||
assert series.total_episodes == 12
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -160,13 +153,11 @@ async def test_update_anime_series(db_session):
|
||||
db_session,
|
||||
series.id,
|
||||
name="Updated Name",
|
||||
total_episodes=24,
|
||||
)
|
||||
await db_session.commit()
|
||||
|
||||
assert updated is not None
|
||||
assert updated.name == "Updated Name"
|
||||
assert updated.total_episodes == 24
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -308,14 +299,12 @@ async def test_mark_episode_downloaded(db_session):
|
||||
db_session,
|
||||
episode.id,
|
||||
file_path="/path/to/file.mp4",
|
||||
file_size=1024000,
|
||||
)
|
||||
await db_session.commit()
|
||||
|
||||
assert updated is not None
|
||||
assert updated.is_downloaded is True
|
||||
assert updated.file_path == "/path/to/file.mp4"
|
||||
assert updated.download_date is not None
|
||||
|
||||
|
||||
# ============================================================================
|
||||
@ -336,23 +325,30 @@ async def test_create_download_queue_item(db_session):
|
||||
)
|
||||
await db_session.commit()
|
||||
|
||||
# Add to queue
|
||||
item = await DownloadQueueService.create(
|
||||
# Create episode
|
||||
episode = await EpisodeService.create(
|
||||
db_session,
|
||||
series_id=series.id,
|
||||
season=1,
|
||||
episode_number=1,
|
||||
priority=DownloadPriority.HIGH,
|
||||
)
|
||||
await db_session.commit()
|
||||
|
||||
# Add to queue
|
||||
item = await DownloadQueueService.create(
|
||||
db_session,
|
||||
series_id=series.id,
|
||||
episode_id=episode.id,
|
||||
)
|
||||
|
||||
assert item.id is not None
|
||||
assert item.status == DownloadStatus.PENDING
|
||||
assert item.priority == DownloadPriority.HIGH
|
||||
assert item.episode_id == episode.id
|
||||
assert item.series_id == series.id
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_pending_downloads(db_session):
|
||||
"""Test retrieving pending downloads."""
|
||||
async def test_get_download_queue_item_by_episode(db_session):
|
||||
"""Test retrieving download queue item by episode."""
|
||||
# Create series
|
||||
series = await AnimeSeriesService.create(
|
||||
db_session,
|
||||
@ -362,29 +358,32 @@ async def test_get_pending_downloads(db_session):
|
||||
folder="/path/test5",
|
||||
)
|
||||
|
||||
# Add pending items
|
||||
await DownloadQueueService.create(
|
||||
# Create episode
|
||||
episode = await EpisodeService.create(
|
||||
db_session,
|
||||
series_id=series.id,
|
||||
season=1,
|
||||
episode_number=1,
|
||||
)
|
||||
await db_session.commit()
|
||||
|
||||
# Add to queue
|
||||
await DownloadQueueService.create(
|
||||
db_session,
|
||||
series_id=series.id,
|
||||
season=1,
|
||||
episode_number=2,
|
||||
episode_id=episode.id,
|
||||
)
|
||||
await db_session.commit()
|
||||
|
||||
# Retrieve pending
|
||||
pending = await DownloadQueueService.get_pending(db_session)
|
||||
assert len(pending) == 2
|
||||
# Retrieve by episode
|
||||
item = await DownloadQueueService.get_by_episode(db_session, episode.id)
|
||||
assert item is not None
|
||||
assert item.episode_id == episode.id
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_download_status(db_session):
|
||||
"""Test updating download status."""
|
||||
async def test_set_download_error(db_session):
|
||||
"""Test setting error on download queue item."""
|
||||
# Create series and queue item
|
||||
series = await AnimeSeriesService.create(
|
||||
db_session,
|
||||
@ -393,30 +392,34 @@ async def test_update_download_status(db_session):
|
||||
site="https://example.com",
|
||||
folder="/path/test6",
|
||||
)
|
||||
item = await DownloadQueueService.create(
|
||||
episode = await EpisodeService.create(
|
||||
db_session,
|
||||
series_id=series.id,
|
||||
season=1,
|
||||
episode_number=1,
|
||||
)
|
||||
item = await DownloadQueueService.create(
|
||||
db_session,
|
||||
series_id=series.id,
|
||||
episode_id=episode.id,
|
||||
)
|
||||
await db_session.commit()
|
||||
|
||||
# Update status
|
||||
updated = await DownloadQueueService.update_status(
|
||||
# Set error
|
||||
updated = await DownloadQueueService.set_error(
|
||||
db_session,
|
||||
item.id,
|
||||
DownloadStatus.DOWNLOADING,
|
||||
"Network error",
|
||||
)
|
||||
await db_session.commit()
|
||||
|
||||
assert updated is not None
|
||||
assert updated.status == DownloadStatus.DOWNLOADING
|
||||
assert updated.started_at is not None
|
||||
assert updated.error_message == "Network error"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_download_progress(db_session):
|
||||
"""Test updating download progress."""
|
||||
async def test_delete_download_queue_item_by_episode(db_session):
|
||||
"""Test deleting download queue item by episode."""
|
||||
# Create series and queue item
|
||||
series = await AnimeSeriesService.create(
|
||||
db_session,
|
||||
@ -425,109 +428,31 @@ async def test_update_download_progress(db_session):
|
||||
site="https://example.com",
|
||||
folder="/path/test7",
|
||||
)
|
||||
item = await DownloadQueueService.create(
|
||||
episode = await EpisodeService.create(
|
||||
db_session,
|
||||
series_id=series.id,
|
||||
season=1,
|
||||
episode_number=1,
|
||||
)
|
||||
await db_session.commit()
|
||||
|
||||
# Update progress
|
||||
updated = await DownloadQueueService.update_progress(
|
||||
db_session,
|
||||
item.id,
|
||||
progress_percent=50.0,
|
||||
downloaded_bytes=500000,
|
||||
total_bytes=1000000,
|
||||
download_speed=50000.0,
|
||||
)
|
||||
await db_session.commit()
|
||||
|
||||
assert updated is not None
|
||||
assert updated.progress_percent == 50.0
|
||||
assert updated.downloaded_bytes == 500000
|
||||
assert updated.total_bytes == 1000000
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_clear_completed_downloads(db_session):
|
||||
"""Test clearing completed downloads."""
|
||||
# Create series and completed items
|
||||
series = await AnimeSeriesService.create(
|
||||
db_session,
|
||||
key="test-series-8",
|
||||
name="Test Series 8",
|
||||
site="https://example.com",
|
||||
folder="/path/test8",
|
||||
)
|
||||
item1 = await DownloadQueueService.create(
|
||||
await DownloadQueueService.create(
|
||||
db_session,
|
||||
series_id=series.id,
|
||||
season=1,
|
||||
episode_number=1,
|
||||
)
|
||||
item2 = await DownloadQueueService.create(
|
||||
db_session,
|
||||
series_id=series.id,
|
||||
season=1,
|
||||
episode_number=2,
|
||||
)
|
||||
|
||||
# Mark items as completed
|
||||
await DownloadQueueService.update_status(
|
||||
db_session,
|
||||
item1.id,
|
||||
DownloadStatus.COMPLETED,
|
||||
)
|
||||
await DownloadQueueService.update_status(
|
||||
db_session,
|
||||
item2.id,
|
||||
DownloadStatus.COMPLETED,
|
||||
episode_id=episode.id,
|
||||
)
|
||||
await db_session.commit()
|
||||
|
||||
# Clear completed
|
||||
count = await DownloadQueueService.clear_completed(db_session)
|
||||
await db_session.commit()
|
||||
|
||||
assert count == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retry_failed_downloads(db_session):
|
||||
"""Test retrying failed downloads."""
|
||||
# Create series and failed item
|
||||
series = await AnimeSeriesService.create(
|
||||
# Delete by episode
|
||||
deleted = await DownloadQueueService.delete_by_episode(
|
||||
db_session,
|
||||
key="test-series-9",
|
||||
name="Test Series 9",
|
||||
site="https://example.com",
|
||||
folder="/path/test9",
|
||||
)
|
||||
item = await DownloadQueueService.create(
|
||||
db_session,
|
||||
series_id=series.id,
|
||||
season=1,
|
||||
episode_number=1,
|
||||
)
|
||||
|
||||
# Mark as failed
|
||||
await DownloadQueueService.update_status(
|
||||
db_session,
|
||||
item.id,
|
||||
DownloadStatus.FAILED,
|
||||
error_message="Network error",
|
||||
episode.id,
|
||||
)
|
||||
await db_session.commit()
|
||||
|
||||
# Retry
|
||||
retried = await DownloadQueueService.retry_failed(db_session)
|
||||
await db_session.commit()
|
||||
assert deleted is True
|
||||
|
||||
assert len(retried) == 1
|
||||
assert retried[0].status == DownloadStatus.PENDING
|
||||
assert retried[0].error_message is None
|
||||
# Verify deleted
|
||||
item = await DownloadQueueService.get_by_episode(db_session, episode.id)
|
||||
assert item is None
|
||||
|
||||
|
||||
# ============================================================================
|
||||
|
||||
@ -45,7 +45,23 @@ def mock_anime_series():
|
||||
anime_series.name = "Test Series"
|
||||
anime_series.site = "https://aniworld.to/anime/stream/test-series"
|
||||
anime_series.folder = "Test Series (2020)"
|
||||
anime_series.episode_dict = {"1": [1, 2, 3], "2": [1, 2]}
|
||||
# Mock episodes relationship
|
||||
mock_ep1 = MagicMock()
|
||||
mock_ep1.season = 1
|
||||
mock_ep1.episode_number = 1
|
||||
mock_ep2 = MagicMock()
|
||||
mock_ep2.season = 1
|
||||
mock_ep2.episode_number = 2
|
||||
mock_ep3 = MagicMock()
|
||||
mock_ep3.season = 1
|
||||
mock_ep3.episode_number = 3
|
||||
mock_ep4 = MagicMock()
|
||||
mock_ep4.season = 2
|
||||
mock_ep4.episode_number = 1
|
||||
mock_ep5 = MagicMock()
|
||||
mock_ep5.season = 2
|
||||
mock_ep5.episode_number = 2
|
||||
anime_series.episodes = [mock_ep1, mock_ep2, mock_ep3, mock_ep4, mock_ep5]
|
||||
return anime_series
|
||||
|
||||
|
||||
@ -288,37 +304,27 @@ class TestSerieListDatabaseMode:
|
||||
assert serie.name == mock_anime_series.name
|
||||
assert serie.site == mock_anime_series.site
|
||||
assert serie.folder == mock_anime_series.folder
|
||||
# Season keys should be converted from string to int
|
||||
# Season keys should be built from episodes relationship
|
||||
assert 1 in serie.episodeDict
|
||||
assert 2 in serie.episodeDict
|
||||
assert serie.episodeDict[1] == [1, 2, 3]
|
||||
assert serie.episodeDict[2] == [1, 2]
|
||||
|
||||
def test_convert_from_db_empty_episode_dict(self, mock_anime_series):
|
||||
"""Test _convert_from_db handles empty episode_dict."""
|
||||
mock_anime_series.episode_dict = None
|
||||
def test_convert_from_db_empty_episodes(self, mock_anime_series):
|
||||
"""Test _convert_from_db handles empty episodes."""
|
||||
mock_anime_series.episodes = []
|
||||
|
||||
serie = SerieList._convert_from_db(mock_anime_series)
|
||||
|
||||
assert serie.episodeDict == {}
|
||||
|
||||
def test_convert_from_db_handles_invalid_season_keys(
|
||||
self, mock_anime_series
|
||||
):
|
||||
"""Test _convert_from_db handles invalid season keys gracefully."""
|
||||
mock_anime_series.episode_dict = {
|
||||
"1": [1, 2],
|
||||
"invalid": [3, 4], # Invalid key - not an integer
|
||||
"2": [5, 6]
|
||||
}
|
||||
def test_convert_from_db_none_episodes(self, mock_anime_series):
|
||||
"""Test _convert_from_db handles None episodes."""
|
||||
mock_anime_series.episodes = None
|
||||
|
||||
serie = SerieList._convert_from_db(mock_anime_series)
|
||||
|
||||
# Valid keys should be converted
|
||||
assert 1 in serie.episodeDict
|
||||
assert 2 in serie.episodeDict
|
||||
# Invalid key should be skipped
|
||||
assert "invalid" not in serie.episodeDict
|
||||
assert serie.episodeDict == {}
|
||||
|
||||
def test_convert_to_db_dict(self, sample_serie):
|
||||
"""Test _convert_to_db_dict creates correct dictionary."""
|
||||
@ -328,9 +334,8 @@ class TestSerieListDatabaseMode:
|
||||
assert result["name"] == sample_serie.name
|
||||
assert result["site"] == sample_serie.site
|
||||
assert result["folder"] == sample_serie.folder
|
||||
# Keys should be converted to strings for JSON
|
||||
assert "1" in result["episode_dict"]
|
||||
assert result["episode_dict"]["1"] == [1, 2, 3]
|
||||
# episode_dict should not be in result anymore
|
||||
assert "episode_dict" not in result
|
||||
|
||||
def test_convert_to_db_dict_empty_episode_dict(self):
|
||||
"""Test _convert_to_db_dict handles empty episode_dict."""
|
||||
@ -344,7 +349,8 @@ class TestSerieListDatabaseMode:
|
||||
|
||||
result = SerieList._convert_to_db_dict(serie)
|
||||
|
||||
assert result["episode_dict"] is None
|
||||
# episode_dict should not be in result anymore
|
||||
assert "episode_dict" not in result
|
||||
|
||||
|
||||
class TestSerieListDatabaseAsync:
|
||||
|
||||
@ -174,10 +174,16 @@ class TestSerieScannerAsyncScan:
|
||||
"""Test scan_async updates existing series in database."""
|
||||
scanner = SerieScanner(temp_directory, mock_loader)
|
||||
|
||||
# Mock existing series in database
|
||||
# Mock existing series in database with different episodes
|
||||
existing = MagicMock()
|
||||
existing.id = 1
|
||||
existing.episode_dict = {1: [5, 6]} # Different from sample_serie
|
||||
existing.folder = sample_serie.folder
|
||||
|
||||
# Mock episodes (different from sample_serie)
|
||||
mock_existing_episodes = [
|
||||
MagicMock(season=1, episode_number=5),
|
||||
MagicMock(season=1, episode_number=6),
|
||||
]
|
||||
|
||||
with patch.object(scanner, 'get_total_to_scan', return_value=1):
|
||||
with patch.object(
|
||||
@ -200,17 +206,24 @@ class TestSerieScannerAsyncScan:
|
||||
with patch(
|
||||
'src.server.database.service.AnimeSeriesService'
|
||||
) as mock_service:
|
||||
mock_service.get_by_key = AsyncMock(
|
||||
return_value=existing
|
||||
)
|
||||
mock_service.update = AsyncMock(
|
||||
return_value=existing
|
||||
)
|
||||
|
||||
await scanner.scan_async(mock_db_session)
|
||||
|
||||
# Verify database update was called
|
||||
mock_service.update.assert_called_once()
|
||||
with patch(
|
||||
'src.server.database.service.EpisodeService'
|
||||
) as mock_ep_service:
|
||||
mock_service.get_by_key = AsyncMock(
|
||||
return_value=existing
|
||||
)
|
||||
mock_service.update = AsyncMock(
|
||||
return_value=existing
|
||||
)
|
||||
mock_ep_service.get_by_series = AsyncMock(
|
||||
return_value=mock_existing_episodes
|
||||
)
|
||||
mock_ep_service.create = AsyncMock()
|
||||
|
||||
await scanner.scan_async(mock_db_session)
|
||||
|
||||
# Verify episodes were created
|
||||
assert mock_ep_service.create.called
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scan_async_handles_errors_gracefully(
|
||||
@ -249,17 +262,21 @@ class TestSerieScannerDatabaseHelpers:
|
||||
with patch(
|
||||
'src.server.database.service.AnimeSeriesService'
|
||||
) as mock_service:
|
||||
mock_service.get_by_key = AsyncMock(return_value=None)
|
||||
mock_created = MagicMock()
|
||||
mock_created.id = 1
|
||||
mock_service.create = AsyncMock(return_value=mock_created)
|
||||
|
||||
result = await scanner._save_serie_to_db(
|
||||
sample_serie, mock_db_session
|
||||
)
|
||||
|
||||
assert result is mock_created
|
||||
mock_service.create.assert_called_once()
|
||||
with patch(
|
||||
'src.server.database.service.EpisodeService'
|
||||
) as mock_ep_service:
|
||||
mock_service.get_by_key = AsyncMock(return_value=None)
|
||||
mock_created = MagicMock()
|
||||
mock_created.id = 1
|
||||
mock_service.create = AsyncMock(return_value=mock_created)
|
||||
mock_ep_service.create = AsyncMock()
|
||||
|
||||
result = await scanner._save_serie_to_db(
|
||||
sample_serie, mock_db_session
|
||||
)
|
||||
|
||||
assert result is mock_created
|
||||
mock_service.create.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_save_serie_to_db_updates_existing(
|
||||
@ -270,20 +287,34 @@ class TestSerieScannerDatabaseHelpers:
|
||||
|
||||
existing = MagicMock()
|
||||
existing.id = 1
|
||||
existing.episode_dict = {1: [5, 6]} # Different episodes
|
||||
existing.folder = sample_serie.folder
|
||||
|
||||
# Mock existing episodes (different from sample_serie)
|
||||
mock_existing_episodes = [
|
||||
MagicMock(season=1, episode_number=5),
|
||||
MagicMock(season=1, episode_number=6),
|
||||
]
|
||||
|
||||
with patch(
|
||||
'src.server.database.service.AnimeSeriesService'
|
||||
) as mock_service:
|
||||
mock_service.get_by_key = AsyncMock(return_value=existing)
|
||||
mock_service.update = AsyncMock(return_value=existing)
|
||||
|
||||
result = await scanner._save_serie_to_db(
|
||||
sample_serie, mock_db_session
|
||||
)
|
||||
|
||||
assert result is existing
|
||||
mock_service.update.assert_called_once()
|
||||
with patch(
|
||||
'src.server.database.service.EpisodeService'
|
||||
) as mock_ep_service:
|
||||
mock_service.get_by_key = AsyncMock(return_value=existing)
|
||||
mock_service.update = AsyncMock(return_value=existing)
|
||||
mock_ep_service.get_by_series = AsyncMock(
|
||||
return_value=mock_existing_episodes
|
||||
)
|
||||
mock_ep_service.create = AsyncMock()
|
||||
|
||||
result = await scanner._save_serie_to_db(
|
||||
sample_serie, mock_db_session
|
||||
)
|
||||
|
||||
assert result is existing
|
||||
# Should have created new episodes
|
||||
assert mock_ep_service.create.called
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_save_serie_to_db_skips_unchanged(
|
||||
@ -294,19 +325,33 @@ class TestSerieScannerDatabaseHelpers:
|
||||
|
||||
existing = MagicMock()
|
||||
existing.id = 1
|
||||
existing.episode_dict = sample_serie.episodeDict # Same episodes
|
||||
existing.folder = sample_serie.folder
|
||||
|
||||
# Mock episodes matching sample_serie.episodeDict
|
||||
mock_existing_episodes = []
|
||||
for season, ep_nums in sample_serie.episodeDict.items():
|
||||
for ep_num in ep_nums:
|
||||
mock_existing_episodes.append(
|
||||
MagicMock(season=season, episode_number=ep_num)
|
||||
)
|
||||
|
||||
with patch(
|
||||
'src.server.database.service.AnimeSeriesService'
|
||||
) as mock_service:
|
||||
mock_service.get_by_key = AsyncMock(return_value=existing)
|
||||
|
||||
result = await scanner._save_serie_to_db(
|
||||
sample_serie, mock_db_session
|
||||
)
|
||||
|
||||
assert result is None
|
||||
mock_service.update.assert_not_called()
|
||||
with patch(
|
||||
'src.server.database.service.EpisodeService'
|
||||
) as mock_ep_service:
|
||||
mock_service.get_by_key = AsyncMock(return_value=existing)
|
||||
mock_ep_service.get_by_series = AsyncMock(
|
||||
return_value=mock_existing_episodes
|
||||
)
|
||||
|
||||
result = await scanner._save_serie_to_db(
|
||||
sample_serie, mock_db_session
|
||||
)
|
||||
|
||||
assert result is None
|
||||
mock_service.update.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_serie_in_db_updates_existing(
|
||||
@ -321,15 +366,20 @@ class TestSerieScannerDatabaseHelpers:
|
||||
with patch(
|
||||
'src.server.database.service.AnimeSeriesService'
|
||||
) as mock_service:
|
||||
mock_service.get_by_key = AsyncMock(return_value=existing)
|
||||
mock_service.update = AsyncMock(return_value=existing)
|
||||
|
||||
result = await scanner._update_serie_in_db(
|
||||
sample_serie, mock_db_session
|
||||
)
|
||||
|
||||
assert result is existing
|
||||
mock_service.update.assert_called_once()
|
||||
with patch(
|
||||
'src.server.database.service.EpisodeService'
|
||||
) as mock_ep_service:
|
||||
mock_service.get_by_key = AsyncMock(return_value=existing)
|
||||
mock_service.update = AsyncMock(return_value=existing)
|
||||
mock_ep_service.get_by_series = AsyncMock(return_value=[])
|
||||
mock_ep_service.create = AsyncMock()
|
||||
|
||||
result = await scanner._update_serie_in_db(
|
||||
sample_serie, mock_db_session
|
||||
)
|
||||
|
||||
assert result is existing
|
||||
mock_service.update.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_serie_in_db_returns_none_if_not_found(
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user