"""Unit tests for the database service layer.

Exercises the CRUD operations of every database service against an
in-memory SQLite database.
"""

import asyncio
from datetime import datetime, timedelta, timezone

import pytest
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker

from src.server.database.base import Base
from src.server.database.service import (
    AnimeSeriesService,
    DownloadQueueService,
    EpisodeService,
    UserSessionService,
)


@pytest.fixture
async def db_engine():
    """Yield an async engine bound to an in-memory SQLite database.

    Every table registered on ``Base.metadata`` is created before the
    engine is handed to the test. The engine is disposed during teardown.
    """
    engine = create_async_engine("sqlite+aiosqlite:///:memory:", echo=False)

    # Build the full schema inside a single transactional connection.
    async with engine.begin() as connection:
        await connection.run_sync(Base.metadata.create_all)

    yield engine

    # Teardown: release the engine's connection pool.
    await engine.dispose()


@pytest.fixture
async def db_session(db_engine):
    """Yield an ``AsyncSession`` bound to the test engine.

    ``expire_on_commit=False`` keeps attributes of committed objects
    loaded, so tests can keep reading them after ``commit()``. Whatever is
    still pending when the test finishes is rolled back.
    """
    session_factory = sessionmaker(
        db_engine, class_=AsyncSession, expire_on_commit=False
    )

    async with session_factory() as session:
        yield session
        await session.rollback()


# ============================================================================
# AnimeSeriesService Tests
# ============================================================================


@pytest.mark.asyncio
async def test_create_anime_series(db_session):
    """Creating a series yields an object with an id and the given key/name."""
    attrs = {
        "key": "test-anime-1",
        "name": "Test Anime",
        "site": "https://example.com",
        "folder": "/path/to/anime",
    }

    created = await AnimeSeriesService.create(db_session, **attrs)

    assert created.id is not None
    assert created.key == attrs["key"]
    assert created.name == attrs["name"]


@pytest.mark.asyncio
async def test_get_anime_series_by_id(db_session):
    """A committed series can be fetched back by its primary key."""
    created = await AnimeSeriesService.create(
        db_session,
        key="test-anime-2",
        name="Test Anime 2",
        site="https://example.com",
        folder="/path/to/anime2",
    )
    await db_session.commit()

    found = await AnimeSeriesService.get_by_id(db_session, created.id)

    assert found is not None
    assert found.id == created.id
    assert found.key == "test-anime-2"


@pytest.mark.asyncio
async def test_get_anime_series_by_key(db_session):
    """A committed series can be looked up by its provider key."""
    lookup_key = "unique-key"

    await AnimeSeriesService.create(
        db_session,
        key=lookup_key,
        name="Test Anime",
        site="https://example.com",
        folder="/path/to/anime",
    )
    await db_session.commit()

    found = await AnimeSeriesService.get_by_key(db_session, lookup_key)

    assert found is not None
    assert found.key == lookup_key


@pytest.mark.asyncio
async def test_get_all_anime_series(db_session):
    """get_all returns every series that was created (two here)."""
    rows = (
        ("anime-1", "Anime 1", "/path/1"),
        ("anime-2", "Anime 2", "/path/2"),
    )
    for key, name, folder in rows:
        await AnimeSeriesService.create(
            db_session,
            key=key,
            name=name,
            site="https://example.com",
            folder=folder,
        )
    await db_session.commit()

    everything = await AnimeSeriesService.get_all(db_session)

    assert len(everything) == 2


@pytest.mark.asyncio
async def test_update_anime_series(db_session):
    """update() returns the series carrying the new name."""
    original = await AnimeSeriesService.create(
        db_session,
        key="anime-update",
        name="Original Name",
        site="https://example.com",
        folder="/path/original",
    )
    await db_session.commit()

    result = await AnimeSeriesService.update(
        db_session, original.id, name="Updated Name"
    )
    await db_session.commit()

    assert result is not None
    assert result.name == "Updated Name"


@pytest.mark.asyncio
async def test_delete_anime_series(db_session):
    """delete() reports success and the series can no longer be fetched."""
    victim = await AnimeSeriesService.create(
        db_session,
        key="anime-delete",
        name="To Delete",
        site="https://example.com",
        folder="/path/delete",
    )
    await db_session.commit()

    was_deleted = await AnimeSeriesService.delete(db_session, victim.id)
    await db_session.commit()

    assert was_deleted is True

    # A follow-up lookup by the same id must come back empty.
    leftover = await AnimeSeriesService.get_by_id(db_session, victim.id)
    assert leftover is None


@pytest.mark.asyncio
async def test_search_anime_series(db_session):
    """Searching for "naruto" matches only "Naruto Shippuden"."""
    catalogue = (
        ("naruto", "Naruto Shippuden", "/path/naruto"),
        ("bleach", "Bleach", "/path/bleach"),
    )
    for key, title, folder in catalogue:
        await AnimeSeriesService.create(
            db_session,
            key=key,
            name=title,
            site="https://example.com",
            folder=folder,
        )
    await db_session.commit()

    matches = await AnimeSeriesService.search(db_session, "naruto")

    assert len(matches) == 1
    assert matches[0].name == "Naruto Shippuden"


# ============================================================================
# EpisodeService Tests
# ============================================================================


@pytest.mark.asyncio
async def test_create_episode(db_session):
    """An episode created under a committed series references that series."""
    parent = await AnimeSeriesService.create(
        db_session,
        key="test-series",
        name="Test Series",
        site="https://example.com",
        folder="/path/test",
    )
    await db_session.commit()

    episode = await EpisodeService.create(
        db_session,
        series_id=parent.id,
        season=1,
        episode_number=1,
        title="Episode 1",
    )

    assert episode.id is not None
    assert episode.series_id == parent.id
    assert (episode.season, episode.episode_number) == (1, 1)


@pytest.mark.asyncio
async def test_get_episodes_by_series(db_session):
    """get_by_series returns both episodes added to the series."""
    show = await AnimeSeriesService.create(
        db_session,
        key="test-series-2",
        name="Test Series 2",
        site="https://example.com",
        folder="/path/test2",
    )

    # Season 1, episodes 1 and 2, created in order.
    for number in (1, 2):
        await EpisodeService.create(
            db_session,
            series_id=show.id,
            season=1,
            episode_number=number,
        )
    await db_session.commit()

    listed = await EpisodeService.get_by_series(db_session, show.id)

    assert len(listed) == 2


@pytest.mark.asyncio
async def test_mark_episode_downloaded(db_session):
    """mark_downloaded flags the episode and records its file path."""
    target_path = "/path/to/file.mp4"

    show = await AnimeSeriesService.create(
        db_session,
        key="test-series-3",
        name="Test Series 3",
        site="https://example.com",
        folder="/path/test3",
    )
    ep = await EpisodeService.create(
        db_session,
        series_id=show.id,
        season=1,
        episode_number=1,
    )
    await db_session.commit()

    marked = await EpisodeService.mark_downloaded(
        db_session, ep.id, file_path=target_path
    )
    await db_session.commit()

    assert marked is not None
    assert marked.is_downloaded is True
    assert marked.file_path == target_path


# ============================================================================
# DownloadQueueService Tests
# ============================================================================


@pytest.mark.asyncio
async def test_create_download_queue_item(db_session):
    """A queued item gets an id and references the given series and episode."""
    show = await AnimeSeriesService.create(
        db_session,
        key="test-series-4",
        name="Test Series 4",
        site="https://example.com",
        folder="/path/test4",
    )
    await db_session.commit()

    ep = await EpisodeService.create(
        db_session,
        series_id=show.id,
        season=1,
        episode_number=1,
    )
    await db_session.commit()

    queued = await DownloadQueueService.create(
        db_session, series_id=show.id, episode_id=ep.id
    )

    assert queued.id is not None
    assert queued.episode_id == ep.id
    assert queued.series_id == show.id


@pytest.mark.asyncio
async def test_get_download_queue_item_by_episode(db_session):
    """A queued item can be found again through its episode id."""
    show = await AnimeSeriesService.create(
        db_session,
        key="test-series-5",
        name="Test Series 5",
        site="https://example.com",
        folder="/path/test5",
    )
    ep = await EpisodeService.create(
        db_session,
        series_id=show.id,
        season=1,
        episode_number=1,
    )
    await db_session.commit()

    await DownloadQueueService.create(
        db_session, series_id=show.id, episode_id=ep.id
    )
    await db_session.commit()

    found = await DownloadQueueService.get_by_episode(db_session, ep.id)

    assert found is not None
    assert found.episode_id == ep.id


@pytest.mark.asyncio
async def test_set_download_error(db_session):
    """set_error stores the supplied message on the queue item."""
    failure_text = "Network error"

    show = await AnimeSeriesService.create(
        db_session,
        key="test-series-6",
        name="Test Series 6",
        site="https://example.com",
        folder="/path/test6",
    )
    ep = await EpisodeService.create(
        db_session,
        series_id=show.id,
        season=1,
        episode_number=1,
    )
    queued = await DownloadQueueService.create(
        db_session, series_id=show.id, episode_id=ep.id
    )
    await db_session.commit()

    errored = await DownloadQueueService.set_error(
        db_session, queued.id, failure_text
    )
    await db_session.commit()

    assert errored is not None
    assert errored.error_message == failure_text


@pytest.mark.asyncio
async def test_delete_download_queue_item_by_episode(db_session):
    """delete_by_episode reports success and the item is then gone."""
    show = await AnimeSeriesService.create(
        db_session,
        key="test-series-7",
        name="Test Series 7",
        site="https://example.com",
        folder="/path/test7",
    )
    ep = await EpisodeService.create(
        db_session,
        series_id=show.id,
        season=1,
        episode_number=1,
    )
    await DownloadQueueService.create(
        db_session, series_id=show.id, episode_id=ep.id
    )
    await db_session.commit()

    removed = await DownloadQueueService.delete_by_episode(db_session, ep.id)
    await db_session.commit()

    assert removed is True

    # Looking the item up again by episode must find nothing.
    remaining = await DownloadQueueService.get_by_episode(db_session, ep.id)
    assert remaining is None


# ============================================================================
# UserSessionService Tests
# ============================================================================


@pytest.mark.asyncio
async def test_create_user_session(db_session):
    """A freshly created user session has an id and is active."""
    valid_until = datetime.now(timezone.utc) + timedelta(hours=24)

    user_session = await UserSessionService.create(
        db_session,
        session_id="test-session-1",
        token_hash="hashed-token",
        expires_at=valid_until,
        user_id="user123",
        ip_address="127.0.0.1",
    )

    assert user_session.id is not None
    assert user_session.session_id == "test-session-1"
    assert user_session.is_active is True


@pytest.mark.asyncio
async def test_get_session_by_id(db_session):
    """Test retrieving session by ID.

    A committed user session can be fetched back by its ``session_id``.
    """
    expires_at = datetime.now(timezone.utc) + timedelta(hours=24)
    # The created object itself is not needed; the lookup below is what is
    # under test, so the return value is intentionally not bound.
    await UserSessionService.create(
        db_session,
        session_id="test-session-2",
        token_hash="hashed-token",
        expires_at=expires_at,
    )
    await db_session.commit()

    # Retrieve
    retrieved = await UserSessionService.get_by_session_id(
        db_session,
        "test-session-2",
    )

    assert retrieved is not None
    assert retrieved.session_id == "test-session-2"


@pytest.mark.asyncio
async def test_get_active_sessions(db_session):
    """Only the unexpired session is returned by get_active_sessions."""
    valid_until = datetime.now(timezone.utc) + timedelta(hours=24)

    # One session expiring a day from now...
    await UserSessionService.create(
        db_session,
        session_id="active-session",
        token_hash="hashed-token",
        expires_at=valid_until,
    )
    # ...and one that expired an hour ago.
    await UserSessionService.create(
        db_session,
        session_id="expired-session",
        token_hash="hashed-token",
        expires_at=datetime.now(timezone.utc) - timedelta(hours=1),
    )
    await db_session.commit()

    live = await UserSessionService.get_active_sessions(db_session)

    assert len(live) == 1
    assert live[0].session_id == "active-session"


@pytest.mark.asyncio
async def test_revoke_session(db_session):
    """Test revoking a session.

    Revoking a session returns ``True``, and a later lookup reports the
    session as inactive.
    """
    expires_at = datetime.now(timezone.utc) + timedelta(hours=24)
    await UserSessionService.create(
        db_session,
        session_id="test-session-3",
        token_hash="hashed-token",
        expires_at=expires_at,
    )
    await db_session.commit()

    # Revoke
    revoked = await UserSessionService.revoke(db_session, "test-session-3")
    await db_session.commit()

    assert revoked is True

    # Verify. Check for None first so a missing row fails as a clear
    # assertion instead of an AttributeError on ``is_active``.
    retrieved = await UserSessionService.get_by_session_id(
        db_session,
        "test-session-3",
    )
    assert retrieved is not None
    assert retrieved.is_active is False


@pytest.mark.asyncio
async def test_cleanup_expired_sessions(db_session):
    """cleanup_expired reports both already-expired sessions."""
    # Sessions that expired one and two hours ago, created in that order.
    for sid, hours_ago in (("expired-1", 1), ("expired-2", 2)):
        await UserSessionService.create(
            db_session,
            session_id=sid,
            token_hash="hashed-token",
            expires_at=datetime.now(timezone.utc) - timedelta(hours=hours_ago),
        )
    await db_session.commit()

    purged = await UserSessionService.cleanup_expired(db_session)
    await db_session.commit()

    assert purged == 2


@pytest.mark.asyncio
async def test_update_session_activity(db_session):
    """update_activity moves ``last_activity`` strictly forward in time."""
    valid_until = datetime.now(timezone.utc) + timedelta(hours=24)
    user_session = await UserSessionService.create(
        db_session,
        session_id="test-session-4",
        token_hash="hashed-token",
        expires_at=valid_until,
    )
    await db_session.commit()

    before = user_session.last_activity

    # Let some wall-clock time pass so the new timestamp can differ.
    await asyncio.sleep(0.1)

    refreshed = await UserSessionService.update_activity(
        db_session, "test-session-4"
    )
    await db_session.commit()

    assert refreshed is not None
    assert refreshed.last_activity > before