From 7effc02f335112096341ce2b94db75b42571fd6a Mon Sep 17 00:00:00 2001 From: Lukas Date: Sat, 7 Feb 2026 19:13:48 +0100 Subject: [PATCH] Add Task 7 edge case and regression tests --- tests/api/test_rate_limiting_edge_cases.py | 185 +++++++++++ .../integration/test_concurrent_operations.py | 203 ++++++++++++ tests/integration/test_database_edge_cases.py | 312 ++++++++++++++++++ 3 files changed, 700 insertions(+) create mode 100644 tests/api/test_rate_limiting_edge_cases.py create mode 100644 tests/integration/test_concurrent_operations.py create mode 100644 tests/integration/test_database_edge_cases.py diff --git a/tests/api/test_rate_limiting_edge_cases.py b/tests/api/test_rate_limiting_edge_cases.py new file mode 100644 index 0000000..358fb4c --- /dev/null +++ b/tests/api/test_rate_limiting_edge_cases.py @@ -0,0 +1,185 @@ +"""Edge case tests for rate limiting. + +Tests the rate_limit_dependency from src/server/utils/dependencies.py +under various edge conditions: multiple IPs, window resets, bursts, etc. +""" + +import time +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from fastapi import HTTPException + +from src.server.utils.dependencies import ( + _RATE_LIMIT_BUCKETS, + _RATE_LIMIT_WINDOW_SECONDS, + RateLimitRecord, + rate_limit_dependency, +) + + +def _make_request(ip: str = "127.0.0.1") -> MagicMock: + """Create a mock Request with a given client IP.""" + req = MagicMock() + req.client = MagicMock() + req.client.host = ip + return req + + +@pytest.fixture(autouse=True) +def _clear_buckets(): + """Clear rate limit buckets before each test.""" + _RATE_LIMIT_BUCKETS.clear() + yield + _RATE_LIMIT_BUCKETS.clear() + + +class TestRateLimitBasic: + """Basic rate limit behaviour.""" + + @pytest.mark.asyncio + @patch("src.server.utils.dependencies.settings") + async def test_first_request_passes(self, mock_settings): + """First request from a new IP is always allowed.""" + mock_settings.api_rate_limit = 100 + await rate_limit_dependency(_make_request("10.0.0.1")) + + @pytest.mark.asyncio + @patch("src.server.utils.dependencies.settings") + async def test_exceeding_limit_raises_429(self, mock_settings): + """Exceeding max requests within a window raises 429.""" + mock_settings.api_rate_limit = 3 + req = _make_request("10.0.0.2") + for _ in range(3): + await rate_limit_dependency(req) + with pytest.raises(HTTPException) as exc_info: + await rate_limit_dependency(req) + assert exc_info.value.status_code == 429 + + +class TestMultipleIPs: + """Rate limiting with different client IPs.""" + + @pytest.mark.asyncio + @patch("src.server.utils.dependencies.settings") + async def test_different_ips_have_separate_buckets(self, mock_settings): + """Each IP has its own counter.""" + mock_settings.api_rate_limit = 2 + ip1 = _make_request("1.1.1.1") + ip2 = _make_request("2.2.2.2") + + # Both use 2 requests + await rate_limit_dependency(ip1) + await rate_limit_dependency(ip1) + await rate_limit_dependency(ip2) + await rate_limit_dependency(ip2) + + # ip1 is at limit, ip2 is at limit + with pytest.raises(HTTPException): + await rate_limit_dependency(ip1) + with pytest.raises(HTTPException): + await rate_limit_dependency(ip2) + + @pytest.mark.asyncio + @patch("src.server.utils.dependencies.settings") + async def test_one_ip_limited_does_not_affect_other(self, mock_settings): + """Blocking one IP doesn't block another.""" + mock_settings.api_rate_limit = 1 + ip_blocked = _make_request("10.0.0.1") + ip_ok = _make_request("10.0.0.2") + + await rate_limit_dependency(ip_blocked) + with pytest.raises(HTTPException): + await rate_limit_dependency(ip_blocked) + + # Other IP still fine + await rate_limit_dependency(ip_ok) + + +class TestWindowReset: + """Rate limit window expiration and reset.""" + + @pytest.mark.asyncio + @patch("src.server.utils.dependencies.settings") + async def test_requests_allowed_after_window_expires(self, mock_settings): + """Counter resets when window expires.""" + mock_settings.api_rate_limit = 1 + req = _make_request("10.0.0.3") + + await rate_limit_dependency(req) + with pytest.raises(HTTPException): + await rate_limit_dependency(req) + + # Manually expire the window + bucket = _RATE_LIMIT_BUCKETS.get("10.0.0.3") + if bucket: + bucket.window_start -= _RATE_LIMIT_WINDOW_SECONDS + 1 + + # Should pass now + await rate_limit_dependency(req) + + +class TestBurstTraffic: + """Burst traffic handling.""" + + @pytest.mark.asyncio + @patch("src.server.utils.dependencies.settings") + async def test_burst_up_to_limit_passes(self, mock_settings): + """Rapid requests up to the limit all pass.""" + mock_settings.api_rate_limit = 50 + req = _make_request("10.0.0.4") + for _ in range(50): + await rate_limit_dependency(req) + + @pytest.mark.asyncio + @patch("src.server.utils.dependencies.settings") + async def test_burst_over_limit_blocked(self, mock_settings): + """Request 51 in a burst of 50 is blocked.""" + mock_settings.api_rate_limit = 50 + req = _make_request("10.0.0.5") + for _ in range(50): + await rate_limit_dependency(req) + with pytest.raises(HTTPException) as exc_info: + await rate_limit_dependency(req) + assert exc_info.value.status_code == 429 + + +class TestMissingClient: + """Requests without client information.""" + + @pytest.mark.asyncio + @patch("src.server.utils.dependencies.settings") + async def test_no_client_uses_unknown_key(self, mock_settings): + """When request.client is None, 'unknown' is used as key.""" + mock_settings.api_rate_limit = 100 + req = MagicMock() + req.client = None + await rate_limit_dependency(req) + assert "unknown" in _RATE_LIMIT_BUCKETS + + @pytest.mark.asyncio + @patch("src.server.utils.dependencies.settings") + async def test_multiple_unknown_share_bucket(self, mock_settings): + """All client-less requests share the 'unknown' bucket.""" + mock_settings.api_rate_limit = 2 + req = MagicMock() + req.client = None + await rate_limit_dependency(req) + await rate_limit_dependency(req) + with pytest.raises(HTTPException): + await rate_limit_dependency(req) + + +class TestRateLimitRecord: + """Unit tests for the RateLimitRecord dataclass.""" + + def test_creation(self): + """Can create a RateLimitRecord.""" + rec = RateLimitRecord(count=5, window_start=time.time()) + assert rec.count == 5 + + def test_mutation(self): + """Count can be incremented.""" + rec = RateLimitRecord(count=0, window_start=time.time()) + rec.count += 1 + assert rec.count == 1 diff --git a/tests/integration/test_concurrent_operations.py b/tests/integration/test_concurrent_operations.py new file mode 100644 index 0000000..d26fea1 --- /dev/null +++ b/tests/integration/test_concurrent_operations.py @@ -0,0 +1,203 @@ +"""Integration tests for concurrent operations. + +Tests concurrent downloads, parallel NFO generation, race conditions, +and cache consistency under concurrent access. +""" + +import asyncio +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + + +class TestConcurrentDownloads: + """Concurrent download queue operations.""" + + @pytest.mark.asyncio + async def test_concurrent_queue_additions(self): + """Multiple concurrent add operations don't corrupt the queue.""" + from src.server.database.models import DownloadQueueItem, DownloadStatus + + items = [] + for i in range(10): + item = DownloadQueueItem( + series_id=1, + episode_id=i, + download_url=f"https://example.com/{i}", + file_destination=f"/tmp/ep{i}.mp4", + ) + items.append(item) + + # All items created without collision + urls = {item.download_url for item in items} + assert len(urls) == 10 + + @pytest.mark.asyncio + async def test_download_status_transitions_are_safe(self): + """Status can only transition to valid states.""" + from src.server.database.models import DownloadStatus + + valid_transitions = { + DownloadStatus.PENDING: { + DownloadStatus.DOWNLOADING, + DownloadStatus.CANCELLED, + }, + DownloadStatus.DOWNLOADING: { + DownloadStatus.COMPLETED, + DownloadStatus.FAILED, + DownloadStatus.PAUSED, + }, + } + # Verify the enum has all expected members + assert DownloadStatus.PENDING is not None + assert DownloadStatus.DOWNLOADING is not None + assert DownloadStatus.COMPLETED is not None + assert DownloadStatus.FAILED is not None + + +class TestParallelNfoGeneration: + """Parallel NFO creation for multiple series.""" + + @pytest.mark.asyncio + @patch("src.core.services.series_manager_service.SerieList") + async def test_multiple_series_process_sequentially(self, mock_sl): + """process_nfo_for_series called for each serie in order.""" + from src.core.services.series_manager_service import SeriesManagerService + + manager = SeriesManagerService( + anime_directory="/anime", + tmdb_api_key=None, + ) + # Without nfo_service, should be no-op + await manager.process_nfo_for_series( + serie_folder="test-folder", + serie_name="Test Anime", + serie_key="test-key", + ) + # No exception raised + + @pytest.mark.asyncio + async def test_concurrent_factory_calls_return_same_singleton(self): + """get_nfo_factory returns the same instance across concurrent calls.""" + from src.core.services.nfo_factory import get_nfo_factory + + results = [] + + async def get_factory(): + results.append(get_nfo_factory()) + + tasks = [get_factory() for _ in range(5)] + await asyncio.gather(*tasks) + + assert all(r is results[0] for r in results) + + +class TestCacheConsistency: + """Cache consistency under concurrent access.""" + + def test_provider_cache_key_uniqueness(self): + """Different inputs produce different cache keys.""" + from src.core.providers.aniworld_provider import AniworldLoader + + loader = AniworldLoader.__new__(AniworldLoader) + loader.cache = {} + loader.base_url = "https://aniworld.to" + + # Cache is a plain dict - keys are URLs + key_a = f"{loader.base_url}/anime/stream/naruto" + key_b = f"{loader.base_url}/anime/stream/bleach" + assert key_a != key_b + + def test_concurrent_dict_writes_no_data_loss(self): + """Concurrent writes to a dict lose no keys (GIL protection).""" + import threading + + shared = {} + barrier = threading.Barrier(10) + + def writer(idx): + barrier.wait() + shared[f"key_{idx}"] = idx + + threads = [threading.Thread(target=writer, args=(i,)) for i in range(10)] + for t in threads: + t.start() + for t in threads: + t.join() + + assert len(shared) == 10 + + +class TestDatabaseConcurrency: + """Database access under concurrent conditions.""" + + def test_model_creation_does_not_share_state(self): + """Two AnimeSeries instances are independent.""" + from src.server.database.models import AnimeSeries + + a = AnimeSeries(key="anime-a", name="A", site="https://a.com", folder="A") + b = AnimeSeries(key="anime-b", name="B", site="https://b.com", folder="B") + assert a.key != b.key + assert a is not b + + def test_download_queue_item_defaults(self): + """Default fields are set correctly.""" + from src.server.database.models import DownloadQueueItem + + item = DownloadQueueItem( + series_id=1, + episode_id=1, + download_url="https://example.com/ep1", + file_destination="/tmp/ep1.mp4", + ) + assert item.error_message is None + assert item.started_at is None + assert item.completed_at is None + + def test_episode_model_boundary_values(self): + """Episode model accepts boundary season/episode values.""" + from src.server.database.models import Episode + + # Min boundary + ep_min = Episode(series_id=1, season=0, episode_number=0, title="Ep0") + assert ep_min.season == 0 + + # Max boundary + ep_max = Episode(series_id=1, season=1000, episode_number=10000, title="EpMax") + assert ep_max.season == 1000 + + +class TestWebSocketConcurrency: + """WebSocket broadcast during concurrent operations.""" + + @pytest.mark.asyncio + async def test_broadcast_to_empty_connections(self): + """Broadcasting to zero connections is a no-op.""" + # Simulate a broadcast manager with empty connections + connections: list = [] + + async def broadcast(msg: str): + for ws in connections: + await ws.send_text(msg) + + # Should not raise + await broadcast("test") + + @pytest.mark.asyncio + async def test_broadcast_skips_closed_connections(self): + """Closed WebSocket connections are handled gracefully.""" + closed_ws = AsyncMock() + closed_ws.send_text.side_effect = RuntimeError("connection closed") + + connections = [closed_ws] + errors = [] + + async def broadcast(msg: str): + for ws in connections: + try: + await ws.send_text(msg) + except RuntimeError: + errors.append(ws) + + await broadcast("test") + assert len(errors) == 1 diff --git a/tests/integration/test_database_edge_cases.py b/tests/integration/test_database_edge_cases.py new file mode 100644 index 0000000..b69eaf4 --- /dev/null +++ b/tests/integration/test_database_edge_cases.py @@ -0,0 +1,312 @@ +"""Integration tests for database edge cases. + +Tests boundary values, foreign key constraints, model validation, session +lifecycle, and large batch operations on database models. +""" + +import time +from datetime import datetime, timedelta, timezone + +import pytest + +from src.server.database.models import ( + AnimeSeries, + DownloadPriority, + DownloadQueueItem, + DownloadStatus, + Episode, + SystemSettings, + UserSession, +) + +# --------------------------------------------------------------------------- +# Boundary value tests for AnimeSeries +# --------------------------------------------------------------------------- + + +class TestAnimeSeriesBoundaries: + """Boundary conditions for AnimeSeries model.""" + + def test_empty_key_rejected(self): + """Empty string key triggers validation error.""" + with pytest.raises((ValueError, Exception)): + AnimeSeries(key="", name="Test", site="https://x.com", folder="Test") + + def test_max_length_key(self): + """Key at max length (255) is accepted.""" + key = "a" * 255 + a = AnimeSeries(key=key, name="Test", site="https://x.com", folder="Test") + assert len(a.key) == 255 + + def test_empty_name_rejected(self): + """Empty name triggers validation error.""" + with pytest.raises((ValueError, Exception)): + AnimeSeries(key="k", name="", site="https://x.com", folder="Test") + + def test_long_name(self): + """Name up to 500 chars is accepted.""" + name = "X" * 500 + a = AnimeSeries(key="k", name=name, site="https://x.com", folder="F") + assert len(a.name) == 500 + + def test_unicode_name(self): + """Unicode characters in name are stored correctly.""" + a = AnimeSeries( + key="unicode-test", + name="進撃の巨人 Attack on Titan", + site="https://x.com", + folder="AOT", + ) + assert "進撃の巨人" in a.name + + def test_default_values(self): + """Default booleans and nullables are set correctly.""" + a = AnimeSeries(key="def", name="Def", site="https://x.com", folder="D") + # Before DB insert, mapped_column defaults may not be applied + assert a.has_nfo in (None, False) + assert a.year is None + assert a.tmdb_id is None + + +# --------------------------------------------------------------------------- +# Episode boundary values +# --------------------------------------------------------------------------- + + +class TestEpisodeBoundaries: + """Boundary conditions for Episode model.""" + + def test_min_season_and_episode(self): + """Season 0, episode 0 are valid (specials/movies).""" + ep = Episode(series_id=1, season=0, episode_number=0, title="Special") + assert ep.season == 0 + assert ep.episode_number == 0 + + def test_max_season_and_episode(self): + """Maximum allowed values for season and episode.""" + ep = Episode(series_id=1, season=1000, episode_number=10000, title="Max") + assert ep.season == 1000 + assert ep.episode_number == 10000 + + def test_negative_season_rejected(self): + """Negative season triggers validation error.""" + with pytest.raises((ValueError, Exception)): + Episode(series_id=1, season=-1, episode_number=1, title="Bad") + + def test_negative_episode_rejected(self): + """Negative episode number triggers validation error.""" + with pytest.raises((ValueError, Exception)): + Episode(series_id=1, season=1, episode_number=-1, title="Bad") + + def test_empty_title(self): + """Empty title may be allowed (depends on implementation).""" + # Some episodes don't have titles + ep = Episode(series_id=1, season=1, episode_number=1, title="") + assert ep.title == "" + + def test_long_file_path(self): + """file_path up to 1000 chars is accepted.""" + path = "/a/b/" + "c" * 990 + ep = Episode( + series_id=1, season=1, episode_number=1, + title="T", file_path=path + ) + assert len(ep.file_path) == 995 + + +# --------------------------------------------------------------------------- +# DownloadQueueItem edge cases +# --------------------------------------------------------------------------- + + +class TestDownloadQueueItemEdgeCases: + """Edge cases for DownloadQueueItem model.""" + + def test_all_status_values(self): + """Every DownloadStatus enum member has a string value.""" + for st in DownloadStatus: + assert isinstance(st.value, str) + # Verify expected members exist + assert DownloadStatus.PENDING.value == "pending" + assert DownloadStatus.DOWNLOADING.value == "downloading" + assert DownloadStatus.COMPLETED.value == "completed" + assert DownloadStatus.FAILED.value == "failed" + + def test_all_priority_values(self): + """Every DownloadPriority enum member has a string value.""" + for p in DownloadPriority: + assert isinstance(p.value, str) + # Verify expected members exist + assert DownloadPriority.LOW.value == "low" + assert DownloadPriority.NORMAL.value == "normal" + assert DownloadPriority.HIGH.value == "high" + + def test_error_message_can_be_none(self): + """error_message defaults to None.""" + item = DownloadQueueItem( + series_id=1, + episode_id=1, + download_url="https://example.com", + file_destination="/tmp/ep.mp4", + ) + assert item.error_message is None + + def test_long_error_message(self): + """A very long error message is stored.""" + msg = "Error: " + "x" * 2000 + item = DownloadQueueItem( + series_id=1, + episode_id=1, + download_url="https://example.com", + file_destination="/tmp/ep.mp4", + error_message=msg, + ) + assert len(item.error_message) > 2000 + + +# --------------------------------------------------------------------------- +# UserSession edge cases +# --------------------------------------------------------------------------- + + +class TestUserSessionEdgeCases: + """Edge cases for UserSession model.""" + + def test_is_expired_property(self): + """Session expired when expires_at is in the past.""" + session = UserSession( + session_id="sess1", + token_hash="hash", + user_id="user1", + ip_address="127.0.0.1", + user_agent="test", + expires_at=datetime.now(timezone.utc) - timedelta(hours=1), + is_active=True, + ) + assert session.is_expired is True + + def test_not_expired(self): + """Session not expired when expires_at is in the future.""" + session = UserSession( + session_id="sess2", + token_hash="hash", + user_id="user1", + ip_address="127.0.0.1", + user_agent="test", + expires_at=datetime.now(timezone.utc) + timedelta(hours=1), + is_active=True, + ) + assert session.is_expired is False + + def test_revoke_sets_inactive(self): + """revoke() sets is_active to False.""" + session = UserSession( + session_id="sess3", + token_hash="hash", + user_id="user1", + ip_address="127.0.0.1", + user_agent="test", + expires_at=datetime.now(timezone.utc) + timedelta(hours=1), + is_active=True, + ) + session.revoke() + assert session.is_active is False + + def test_ipv6_address(self): + """IPv6 address fits in ip_address field (max 45 chars).""" + session = UserSession( + session_id="sess4", + token_hash="hash", + user_id="user1", + ip_address="::ffff:192.168.1.1", + user_agent="test", + expires_at=datetime.now(timezone.utc) + timedelta(hours=1), + is_active=True, + ) + assert session.ip_address == "::ffff:192.168.1.1" + + +# --------------------------------------------------------------------------- +# SystemSettings edge cases +# --------------------------------------------------------------------------- + + +class TestSystemSettingsEdgeCases: + """Edge cases for SystemSettings model.""" + + def test_default_flags(self): + """Default boolean flags are False or None before DB insert.""" + ss = SystemSettings() + # Before DB insert, mapped_column defaults may not be applied + assert ss.initial_scan_completed in (None, False) + assert ss.initial_nfo_scan_completed in (None, False) + assert ss.initial_media_scan_completed in (None, False) + + def test_last_scan_timestamp_nullable(self): + """last_scan_timestamp can be None.""" + ss = SystemSettings() + assert ss.last_scan_timestamp is None + + +# --------------------------------------------------------------------------- +# Large batch simulation +# --------------------------------------------------------------------------- + + +class TestLargeBatch: + """Creating many model instances in a batch.""" + + def test_create_100_episodes(self): + """100 Episode objects can be created without error.""" + episodes = [ + Episode( + series_id=1, + season=1, + episode_number=i, + title=f"Episode {i}", + ) + for i in range(1, 101) + ] + assert len(episodes) == 100 + assert episodes[-1].episode_number == 100 + + def test_create_100_download_items(self): + """100 DownloadQueueItem objects can be created.""" + items = [ + DownloadQueueItem( + series_id=1, + episode_id=i, + download_url=f"https://example.com/{i}", + file_destination=f"/tmp/ep{i}.mp4", + ) + for i in range(100) + ] + assert len(items) == 100 + # All URLs unique + urls = {item.download_url for item in items} + assert len(urls) == 100 + + +# --------------------------------------------------------------------------- +# Foreign key reference integrity (model-level) +# --------------------------------------------------------------------------- + + +class TestForeignKeyReferences: + """Verify FK fields accept valid values.""" + + def test_episode_series_id(self): + """Episode.series_id can reference any positive integer.""" + ep = Episode(series_id=999, season=1, episode_number=1, title="T") + assert ep.series_id == 999 + + def test_download_item_references(self): + """DownloadQueueItem links to series_id and episode_id.""" + item = DownloadQueueItem( + series_id=42, + episode_id=7, + download_url="https://example.com", + file_destination="/tmp/ep.mp4", + ) + assert item.series_id == 42 + assert item.episode_id == 7