204 lines
6.6 KiB
Python
204 lines
6.6 KiB
Python
"""Integration tests for concurrent operations.

Covers concurrent download-queue operations, parallel NFO generation,
race conditions, cache consistency under concurrent access, database
model isolation, and WebSocket broadcasting.
"""
|
|
|
|
import asyncio
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import pytest
|
|
|
|
|
|
class TestConcurrentDownloads:
    """Concurrent download queue operations."""

    @pytest.mark.asyncio
    async def test_concurrent_queue_additions(self):
        """Queue items built by concurrently scheduled tasks stay distinct.

        Each task yields to the event loop before constructing its item, so
        the constructions interleave instead of running back to back. Only
        in-memory model instances are created; nothing is persisted.
        """
        from src.server.database.models import DownloadQueueItem

        async def build_item(i):
            # Yield once so the scheduler interleaves the ten tasks.
            await asyncio.sleep(0)
            return DownloadQueueItem(
                series_id=1,
                episode_id=i,
                download_url=f"https://example.com/{i}",
                file_destination=f"/tmp/ep{i}.mp4",
            )

        items = await asyncio.gather(*(build_item(i) for i in range(10)))

        # Every item keeps its own URL and episode id. gather() returns
        # results in submission order, so the ids line up with the range.
        assert len({item.download_url for item in items}) == 10
        assert [item.episode_id for item in items] == list(range(10))

    def test_download_status_transitions_are_safe(self):
        """Every status named in the expected transition table exists."""
        from src.server.database.models import DownloadStatus

        # Expected transitions, keyed by member *name* so that a missing
        # member produces a readable assertion failure instead of an
        # AttributeError raised while building the table.
        valid_transitions = {
            "PENDING": {"DOWNLOADING", "CANCELLED"},
            "DOWNLOADING": {"COMPLETED", "FAILED", "PAUSED"},
        }
        referenced = set(valid_transitions).union(*valid_transitions.values())

        missing = sorted(n for n in referenced if not hasattr(DownloadStatus, n))
        assert not missing, f"DownloadStatus lacks members: {missing}"
|
|
|
|
|
|
class TestParallelNfoGeneration:
    """Parallel NFO creation for multiple series."""

    @pytest.mark.asyncio
    @patch("src.core.services.series_manager_service.SerieList")
    async def test_multiple_series_process_sequentially(self, mock_sl):
        """process_nfo_for_series can be awaited for several series in turn.

        ``SerieList`` is patched out (presumably so the manager never scans
        the real anime directory). Per the original test, a manager without
        an nfo_service treats each call as a no-op; the test passes as long
        as none of the calls raises.
        """
        from src.core.services.series_manager_service import SeriesManagerService

        manager = SeriesManagerService(
            anime_directory="/anime",
            tmdb_api_key=None,
        )

        series = [
            ("test-folder", "Test Anime", "test-key"),
            ("other-folder", "Other Anime", "other-key"),
            ("third-folder", "Third Anime", "third-key"),
        ]
        # Without nfo_service each call should be a no-op.
        for folder, name, key in series:
            await manager.process_nfo_for_series(
                serie_folder=folder,
                serie_name=name,
                serie_key=key,
            )

    @pytest.mark.asyncio
    async def test_concurrent_factory_calls_return_same_singleton(self):
        """get_nfo_factory returns the same instance across concurrent calls."""
        from src.core.services.nfo_factory import get_nfo_factory

        async def get_factory():
            # Yield first so all tasks are started before any of them calls
            # the factory, instead of each finishing before the next begins.
            await asyncio.sleep(0)
            return get_nfo_factory()

        results = await asyncio.gather(*(get_factory() for _ in range(5)))

        assert len(results) == 5
        assert all(r is results[0] for r in results)
|
|
|
|
|
|
class TestCacheConsistency:
    """Cache consistency under concurrent access."""

    def test_provider_cache_key_uniqueness(self):
        """URLs for different series occupy separate loader cache entries."""
        from src.core.providers.aniworld_provider import AniworldLoader

        # Bypass __init__ (and whatever setup it performs); only the
        # attributes this test touches are set by hand.
        loader = AniworldLoader.__new__(AniworldLoader)
        loader.cache = {}
        loader.base_url = "https://aniworld.to"

        # Cache is a plain dict - keys are URLs
        key_a = f"{loader.base_url}/anime/stream/naruto"
        key_b = f"{loader.base_url}/anime/stream/bleach"
        assert key_a != key_b

        # Writing both entries must not let one overwrite the other.
        loader.cache[key_a] = "naruto-page"
        loader.cache[key_b] = "bleach-page"
        assert len(loader.cache) == 2
        assert loader.cache[key_a] == "naruto-page"
        assert loader.cache[key_b] == "bleach-page"

    def test_concurrent_dict_writes_no_data_loss(self):
        """Concurrent writes to a dict from many threads lose no entries.

        Single ``d[key] = value`` assignments are atomic in CPython, so no
        key or value should go missing even when all writers start together.
        """
        import threading

        n_threads = 10
        shared = {}
        # Release every writer at once to maximise overlap. The timeout turns
        # a stuck thread into BrokenBarrierError rather than a hung suite.
        barrier = threading.Barrier(n_threads, timeout=5)

        def writer(idx):
            barrier.wait()
            shared[f"key_{idx}"] = idx

        threads = [
            threading.Thread(target=writer, args=(i,)) for i in range(n_threads)
        ]
        for t in threads:
            t.start()
        for t in threads:
            t.join()

        # Compare full contents, not just the size, so wrong values fail too.
        assert shared == {f"key_{i}": i for i in range(n_threads)}
|
|
|
|
|
|
class TestDatabaseConcurrency:
    """Database access under concurrent conditions.

    These tests construct in-memory model instances only; no session or
    database connection is opened here.
    """

    def test_model_creation_does_not_share_state(self):
        """Two AnimeSeries instances are independent."""
        from src.server.database.models import AnimeSeries

        a = AnimeSeries(key="anime-a", name="A", site="https://a.com", folder="A")
        b = AnimeSeries(key="anime-b", name="B", site="https://b.com", folder="B")
        assert a is not b

        # Each instance keeps exactly the values it was constructed with.
        assert (a.key, a.name, a.site, a.folder) == (
            "anime-a", "A", "https://a.com", "A"
        )
        assert (b.key, b.name, b.site, b.folder) == (
            "anime-b", "B", "https://b.com", "B"
        )

        # Mutating one instance must not leak into the other (e.g. through
        # accidentally shared class-level state).
        a.name = "A-renamed"
        assert b.name == "B"

    def test_download_queue_item_defaults(self):
        """Optional bookkeeping fields start out unset (None)."""
        from src.server.database.models import DownloadQueueItem

        item = DownloadQueueItem(
            series_id=1,
            episode_id=1,
            download_url="https://example.com/ep1",
            file_destination="/tmp/ep1.mp4",
        )
        assert item.error_message is None
        assert item.started_at is None
        assert item.completed_at is None

    def test_episode_model_boundary_values(self):
        """Episode model accepts boundary season/episode values."""
        from src.server.database.models import Episode

        # Min boundary
        ep_min = Episode(series_id=1, season=0, episode_number=0, title="Ep0")
        assert ep_min.season == 0
        assert ep_min.episode_number == 0

        # Max boundary
        ep_max = Episode(series_id=1, season=1000, episode_number=10000, title="EpMax")
        assert ep_max.season == 1000
        assert ep_max.episode_number == 10000
|
|
|
|
|
|
class TestWebSocketConcurrency:
    """WebSocket broadcast during concurrent operations.

    Both tests exercise a local stand-in ``broadcast`` helper, not the
    application's WebSocket manager.
    """

    @pytest.mark.asyncio
    async def test_broadcast_to_empty_connections(self):
        """Broadcasting to zero connections is a no-op."""
        # Simulate a broadcast manager with empty connections
        connections: list = []

        async def broadcast(msg: str):
            for ws in connections:
                await ws.send_text(msg)

        # Should not raise, and must not register anything.
        await broadcast("test")
        assert connections == []

    @pytest.mark.asyncio
    async def test_broadcast_skips_closed_connections(self):
        """A closed connection is recorded and skipped; later ones still receive."""
        closed_ws = AsyncMock()
        closed_ws.send_text.side_effect = RuntimeError("connection closed")
        open_ws = AsyncMock()

        # The closed socket comes first, so the open one is reached only if
        # the loop survives the failure.
        connections = [closed_ws, open_ws]
        errors = []

        async def broadcast(msg: str):
            for ws in connections:
                try:
                    await ws.send_text(msg)
                except RuntimeError:
                    errors.append(ws)

        await broadcast("test")

        assert len(errors) == 1
        assert errors[0] is closed_ws
        closed_ws.send_text.assert_awaited_once_with("test")
        open_ws.send_text.assert_awaited_once_with("test")
|