Aniworld/tests/unit/test_database_service.py
2025-12-04 19:22:42 +01:00

608 lines
16 KiB
Python

"""Unit tests for database service layer.
Tests CRUD operations for all database services using in-memory SQLite.
"""
import asyncio
from datetime import datetime, timedelta, timezone
import pytest
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker
from src.server.database.base import Base
from src.server.database.service import (
AnimeSeriesService,
DownloadQueueService,
EpisodeService,
UserSessionService,
)
@pytest.fixture
async def db_engine():
    """Yield an async engine bound to a fresh in-memory SQLite database.

    All tables registered on ``Base.metadata`` are created before the engine
    is handed to the test, and the engine is disposed after the test ends.
    """
    # NOTE(review): a plain ``@pytest.fixture`` on an async generator is only
    # driven by pytest-asyncio in "auto" mode; strict mode would require
    # ``@pytest_asyncio.fixture``. Confirm the project's asyncio_mode setting.
    memory_engine = create_async_engine("sqlite+aiosqlite:///:memory:", echo=False)

    # Build the full schema inside a single transaction.
    async with memory_engine.begin() as connection:
        await connection.run_sync(Base.metadata.create_all)

    yield memory_engine

    # Teardown: close the engine's pooled connections.
    await memory_engine.dispose()
@pytest.fixture
async def db_session(db_engine):
    """Yield an ``AsyncSession`` bound to ``db_engine``.

    Any work the test left uncommitted is rolled back before the session
    is closed.
    """
    # expire_on_commit=False keeps loaded attributes readable after commit
    # without triggering a (async-incompatible) lazy reload.
    session_factory = sessionmaker(
        db_engine,
        expire_on_commit=False,
        class_=AsyncSession,
    )
    async with session_factory() as active_session:
        yield active_session
        await active_session.rollback()
# ============================================================================
# AnimeSeriesService Tests
# ============================================================================
@pytest.mark.asyncio
async def test_create_anime_series(db_session):
    """Creating a series yields an object with an id, key and name set."""
    # No commit is issued here, so this relies on create() populating the
    # primary key itself (presumably via a flush — confirm in the service).
    created = await AnimeSeriesService.create(
        db_session,
        name="Test Anime",
        key="test-anime-1",
        folder="/path/to/anime",
        site="https://example.com",
    )

    assert created.id is not None
    assert created.key == "test-anime-1"
    assert created.name == "Test Anime"
@pytest.mark.asyncio
async def test_get_anime_series_by_id(db_session):
    """A committed series can be loaded again through its primary key."""
    stored = await AnimeSeriesService.create(
        db_session,
        name="Test Anime 2",
        key="test-anime-2",
        folder="/path/to/anime2",
        site="https://example.com",
    )
    await db_session.commit()

    fetched = await AnimeSeriesService.get_by_id(db_session, stored.id)

    assert fetched is not None
    assert fetched.id == stored.id
    assert fetched.key == "test-anime-2"
@pytest.mark.asyncio
async def test_get_anime_series_by_key(db_session):
    """A committed series can be looked up by its provider key."""
    await AnimeSeriesService.create(
        db_session,
        name="Test Anime",
        key="unique-key",
        folder="/path/to/anime",
        site="https://example.com",
    )
    await db_session.commit()

    match = await AnimeSeriesService.get_by_key(db_session, "unique-key")

    assert match is not None
    assert match.key == "unique-key"
@pytest.mark.asyncio
async def test_get_all_anime_series(db_session):
    """get_all returns every series that was created and committed."""
    # (key, name, folder) for each series to insert, in insertion order.
    rows = (
        ("anime-1", "Anime 1", "/path/1"),
        ("anime-2", "Anime 2", "/path/2"),
    )
    for key, name, folder in rows:
        await AnimeSeriesService.create(
            db_session,
            key=key,
            name=name,
            site="https://example.com",
            folder=folder,
        )
    await db_session.commit()

    everything = await AnimeSeriesService.get_all(db_session)

    assert len(everything) == 2
@pytest.mark.asyncio
async def test_update_anime_series(db_session):
    """Updating a series' name is reflected on the returned object."""
    original = await AnimeSeriesService.create(
        db_session,
        name="Original Name",
        key="anime-update",
        folder="/path/original",
        site="https://example.com",
    )
    await db_session.commit()

    changed = await AnimeSeriesService.update(
        db_session,
        original.id,
        name="Updated Name",
    )
    await db_session.commit()

    assert changed is not None
    assert changed.name == "Updated Name"
@pytest.mark.asyncio
async def test_delete_anime_series(db_session):
    """Deleting a series reports success and the row is no longer found."""
    doomed = await AnimeSeriesService.create(
        db_session,
        name="To Delete",
        key="anime-delete",
        folder="/path/delete",
        site="https://example.com",
    )
    await db_session.commit()

    was_deleted = await AnimeSeriesService.delete(db_session, doomed.id)
    await db_session.commit()
    assert was_deleted is True

    # A lookup by the old id must now come back empty.
    assert await AnimeSeriesService.get_by_id(db_session, doomed.id) is None
@pytest.mark.asyncio
async def test_search_anime_series(db_session):
    """Searching by a lowercase term matches only the series whose name contains it."""
    # (key, name, folder) for each series to insert, in insertion order.
    catalogue = (
        ("naruto", "Naruto Shippuden", "/path/naruto"),
        ("bleach", "Bleach", "/path/bleach"),
    )
    for key, name, folder in catalogue:
        await AnimeSeriesService.create(
            db_session,
            key=key,
            name=name,
            site="https://example.com",
            folder=folder,
        )
    await db_session.commit()

    hits = await AnimeSeriesService.search(db_session, "naruto")

    assert len(hits) == 1
    assert hits[0].name == "Naruto Shippuden"
# ============================================================================
# EpisodeService Tests
# ============================================================================
@pytest.mark.asyncio
async def test_create_episode(db_session):
    """Creating an episode links it to its parent series with season/number set."""
    parent = await AnimeSeriesService.create(
        db_session,
        name="Test Series",
        key="test-series",
        folder="/path/test",
        site="https://example.com",
    )
    await db_session.commit()

    new_episode = await EpisodeService.create(
        db_session,
        series_id=parent.id,
        season=1,
        episode_number=1,
        title="Episode 1",
    )

    assert new_episode.id is not None
    assert new_episode.series_id == parent.id
    assert new_episode.season == 1
    assert new_episode.episode_number == 1
@pytest.mark.asyncio
async def test_get_episodes_by_series(db_session):
    """get_by_series returns all episodes attached to the given series."""
    parent = await AnimeSeriesService.create(
        db_session,
        name="Test Series 2",
        key="test-series-2",
        folder="/path/test2",
        site="https://example.com",
    )

    # Two episodes of season 1, created in order before a single commit.
    for number in (1, 2):
        await EpisodeService.create(
            db_session,
            series_id=parent.id,
            season=1,
            episode_number=number,
        )
    await db_session.commit()

    listed = await EpisodeService.get_by_series(db_session, parent.id)

    assert len(listed) == 2
@pytest.mark.asyncio
async def test_mark_episode_downloaded(db_session):
    """Marking an episode downloaded sets the flag and records the file path."""
    parent = await AnimeSeriesService.create(
        db_session,
        name="Test Series 3",
        key="test-series-3",
        folder="/path/test3",
        site="https://example.com",
    )
    target = await EpisodeService.create(
        db_session,
        series_id=parent.id,
        season=1,
        episode_number=1,
    )
    await db_session.commit()

    result = await EpisodeService.mark_downloaded(
        db_session,
        target.id,
        file_path="/path/to/file.mp4",
    )
    await db_session.commit()

    assert result is not None
    assert result.is_downloaded is True
    assert result.file_path == "/path/to/file.mp4"
# ============================================================================
# DownloadQueueService Tests
# ============================================================================
@pytest.mark.asyncio
async def test_create_download_queue_item(db_session):
    """Queueing an episode creates an item referencing both series and episode."""
    parent = await AnimeSeriesService.create(
        db_session,
        name="Test Series 4",
        key="test-series-4",
        folder="/path/test4",
        site="https://example.com",
    )
    await db_session.commit()

    target = await EpisodeService.create(
        db_session,
        series_id=parent.id,
        season=1,
        episode_number=1,
    )
    await db_session.commit()

    queued = await DownloadQueueService.create(
        db_session,
        series_id=parent.id,
        episode_id=target.id,
    )

    assert queued.id is not None
    assert queued.episode_id == target.id
    assert queued.series_id == parent.id
@pytest.mark.asyncio
async def test_get_download_queue_item_by_episode(db_session):
    """A queued item can be found again through its episode id."""
    parent = await AnimeSeriesService.create(
        db_session,
        name="Test Series 5",
        key="test-series-5",
        folder="/path/test5",
        site="https://example.com",
    )
    target = await EpisodeService.create(
        db_session,
        series_id=parent.id,
        season=1,
        episode_number=1,
    )
    await db_session.commit()

    await DownloadQueueService.create(
        db_session,
        series_id=parent.id,
        episode_id=target.id,
    )
    await db_session.commit()

    found = await DownloadQueueService.get_by_episode(db_session, target.id)

    assert found is not None
    assert found.episode_id == target.id
@pytest.mark.asyncio
async def test_set_download_error(db_session):
    """Recording an error on a queue item stores the given message."""
    parent = await AnimeSeriesService.create(
        db_session,
        name="Test Series 6",
        key="test-series-6",
        folder="/path/test6",
        site="https://example.com",
    )
    target = await EpisodeService.create(
        db_session,
        series_id=parent.id,
        season=1,
        episode_number=1,
    )
    queued = await DownloadQueueService.create(
        db_session,
        series_id=parent.id,
        episode_id=target.id,
    )
    await db_session.commit()

    failed = await DownloadQueueService.set_error(
        db_session,
        queued.id,
        "Network error",
    )
    await db_session.commit()

    assert failed is not None
    assert failed.error_message == "Network error"
@pytest.mark.asyncio
async def test_delete_download_queue_item_by_episode(db_session):
    """Deleting by episode removes the queue item and reports success."""
    parent = await AnimeSeriesService.create(
        db_session,
        name="Test Series 7",
        key="test-series-7",
        folder="/path/test7",
        site="https://example.com",
    )
    target = await EpisodeService.create(
        db_session,
        series_id=parent.id,
        season=1,
        episode_number=1,
    )
    await DownloadQueueService.create(
        db_session,
        series_id=parent.id,
        episode_id=target.id,
    )
    await db_session.commit()

    removed = await DownloadQueueService.delete_by_episode(
        db_session,
        target.id,
    )
    await db_session.commit()
    assert removed is True

    # The episode must no longer have a queue entry.
    assert await DownloadQueueService.get_by_episode(db_session, target.id) is None
# ============================================================================
# UserSessionService Tests
# ============================================================================
@pytest.mark.asyncio
async def test_create_user_session(db_session):
    """A newly created user session gets an id and starts out active."""
    valid_until = datetime.now(timezone.utc) + timedelta(hours=24)

    user_session = await UserSessionService.create(
        db_session,
        session_id="test-session-1",
        token_hash="hashed-token",
        expires_at=valid_until,
        user_id="user123",
        ip_address="127.0.0.1",
    )

    assert user_session.id is not None
    assert user_session.session_id == "test-session-1"
    assert user_session.is_active is True
@pytest.mark.asyncio
async def test_get_session_by_id(db_session):
    """A committed session can be fetched by its public session_id."""
    expires_at = datetime.now(timezone.utc) + timedelta(hours=24)
    # The created object itself is not needed; only the persisted row is.
    await UserSessionService.create(
        db_session,
        session_id="test-session-2",
        token_hash="hashed-token",
        expires_at=expires_at,
    )
    await db_session.commit()

    retrieved = await UserSessionService.get_by_session_id(
        db_session,
        "test-session-2",
    )

    assert retrieved is not None
    assert retrieved.session_id == "test-session-2"
@pytest.mark.asyncio
async def test_get_active_sessions(db_session):
    """Only sessions whose expiry lies in the future are reported as active."""
    future_expiry = datetime.now(timezone.utc) + timedelta(hours=24)

    # One session valid for another day...
    await UserSessionService.create(
        db_session,
        session_id="active-session",
        token_hash="hashed-token",
        expires_at=future_expiry,
    )
    # ...and one that expired an hour ago.
    await UserSessionService.create(
        db_session,
        session_id="expired-session",
        token_hash="hashed-token",
        expires_at=datetime.now(timezone.utc) - timedelta(hours=1),
    )
    await db_session.commit()

    live = await UserSessionService.get_active_sessions(db_session)

    assert len(live) == 1
    assert live[0].session_id == "active-session"
@pytest.mark.asyncio
async def test_revoke_session(db_session):
    """Revoking a session reports success and leaves the row flagged inactive."""
    expires_at = datetime.now(timezone.utc) + timedelta(hours=24)
    await UserSessionService.create(
        db_session,
        session_id="test-session-3",
        token_hash="hashed-token",
        expires_at=expires_at,
    )
    await db_session.commit()

    revoked = await UserSessionService.revoke(db_session, "test-session-3")
    await db_session.commit()
    assert revoked is True

    retrieved = await UserSessionService.get_by_session_id(
        db_session,
        "test-session-3",
    )
    # Guard first so a missing row fails as an assertion, not an AttributeError.
    assert retrieved is not None
    assert retrieved.is_active is False
@pytest.mark.asyncio
async def test_cleanup_expired_sessions(db_session):
    """cleanup_expired reports how many expired sessions it handled."""
    # (session_id, hours since expiry) for each already-expired session.
    stale = (("expired-1", 1), ("expired-2", 2))
    for sid, hours_ago in stale:
        await UserSessionService.create(
            db_session,
            session_id=sid,
            token_hash="hashed-token",
            expires_at=datetime.now(timezone.utc) - timedelta(hours=hours_ago),
        )
    await db_session.commit()

    cleaned = await UserSessionService.cleanup_expired(db_session)
    await db_session.commit()

    assert cleaned == 2
@pytest.mark.asyncio
async def test_update_session_activity(db_session):
    """update_activity moves a session's last_activity forward in time."""
    expires_at = datetime.now(timezone.utc) + timedelta(hours=24)
    session = await UserSessionService.create(
        db_session,
        session_id="test-session-4",
        token_hash="hashed-token",
        expires_at=expires_at,
    )
    await db_session.commit()

    # Capture the timestamp value now: ``updated`` below may well be the same
    # identity-mapped object as ``session``, so reading it later would not work.
    original_activity = session.last_activity
    # Fail with a clear message instead of a TypeError on the '>' comparison.
    assert original_activity is not None

    # Sleep briefly so the refreshed timestamp is strictly later.
    await asyncio.sleep(0.1)

    updated = await UserSessionService.update_activity(
        db_session,
        "test-session-4",
    )
    await db_session.commit()

    assert updated is not None
    assert updated.last_activity > original_activity