diff --git a/instructions.md b/instructions.md index 5b86d70..87f52d8 100644 --- a/instructions.md +++ b/instructions.md @@ -77,13 +77,6 @@ This comprehensive guide ensures a robust, maintainable, and scalable anime down ### 10. Testing -#### [] Create unit tests for services - -- []Create `tests/unit/test_auth_service.py` -- []Create `tests/unit/test_anime_service.py` -- []Create `tests/unit/test_download_service.py` -- []Create `tests/unit/test_config_service.py` - #### [] Create API endpoint tests - []Create `tests/api/test_auth_endpoints.py` diff --git a/tests/unit/test_anime_service.py b/tests/unit/test_anime_service.py index b840aa2..94c031b 100644 --- a/tests/unit/test_anime_service.py +++ b/tests/unit/test_anime_service.py @@ -1,27 +1,332 @@ +"""Unit tests for AnimeService. + +Tests cover service initialization, async operations, caching, +error handling, and progress reporting integration. +""" +from __future__ import annotations + import asyncio +from unittest.mock import AsyncMock, MagicMock, patch import pytest from src.server.services.anime_service import AnimeService, AnimeServiceError +from src.server.services.progress_service import ProgressService -@pytest.mark.asyncio -async def test_list_missing_empty(tmp_path): - svc = AnimeService(directory=str(tmp_path)) - # SeriesApp may return empty list depending on filesystem; ensure it returns a list - result = await svc.list_missing() - assert isinstance(result, list) +@pytest.fixture +def mock_series_app(): + """Create a mock SeriesApp instance.""" + with patch("src.server.services.anime_service.SeriesApp") as mock_class: + mock_instance = MagicMock() + mock_instance.series_list = [] + mock_instance.search = MagicMock(return_value=[]) + mock_instance.ReScan = MagicMock() + mock_instance.download = MagicMock(return_value=True) + mock_class.return_value = mock_instance + yield mock_instance -@pytest.mark.asyncio -async def test_search_empty_query(tmp_path): - svc = AnimeService(directory=str(tmp_path)) - res = await svc.search("") - assert res == [] +@pytest.fixture +def mock_progress_service(): + """Create a mock ProgressService instance.""" + service = MagicMock(spec=ProgressService) + service.start_progress = AsyncMock() + service.update_progress = AsyncMock() + service.complete_progress = AsyncMock() + service.fail_progress = AsyncMock() + return service -@pytest.mark.asyncio -async def test_rescan_and_cache_clear(tmp_path): - svc = AnimeService(directory=str(tmp_path)) - # calling rescan should not raise - await svc.rescan() +@pytest.fixture +def anime_service(tmp_path, mock_series_app, mock_progress_service): + """Create an AnimeService instance for testing.""" + return AnimeService( + directory=str(tmp_path), + max_workers=2, + progress_service=mock_progress_service, + ) + + +class TestAnimeServiceInitialization: + """Test AnimeService initialization.""" + + def test_initialization_success(self, tmp_path, mock_progress_service): + """Test successful service initialization.""" + with patch("src.server.services.anime_service.SeriesApp"): + service = AnimeService( + directory=str(tmp_path), + max_workers=2, + progress_service=mock_progress_service, + ) + + assert service._directory == str(tmp_path) + assert service._executor is not None + assert service._progress_service is mock_progress_service + + def test_initialization_failure_raises_error( + self, tmp_path, mock_progress_service + ): + """Test SeriesApp initialization failure raises error.""" + with patch( + "src.server.services.anime_service.SeriesApp" + ) as mock_class: + mock_class.side_effect = Exception("Initialization failed") + + with pytest.raises( + AnimeServiceError, match="Initialization failed" + ): + AnimeService( + directory=str(tmp_path), + progress_service=mock_progress_service, + ) + + +class TestListMissing: + """Test list_missing operation.""" + + @pytest.mark.asyncio + async def test_list_missing_empty(self, anime_service, mock_series_app): + """Test listing missing episodes when list is empty.""" + mock_series_app.series_list = [] + + result = await anime_service.list_missing() + + assert isinstance(result, list) + assert len(result) == 0 + + @pytest.mark.asyncio + async def test_list_missing_with_series( + self, anime_service, mock_series_app + ): + """Test listing missing episodes with series data.""" + mock_series_app.series_list = [ + {"name": "Test Series 1", "missing": [1, 2]}, + {"name": "Test Series 2", "missing": [3]}, + ] + + result = await anime_service.list_missing() + + assert len(result) == 2 + assert result[0]["name"] == "Test Series 1" + assert result[1]["name"] == "Test Series 2" + + @pytest.mark.asyncio + async def test_list_missing_caching(self, anime_service, mock_series_app): + """Test that list_missing uses caching.""" + mock_series_app.series_list = [{"name": "Test Series"}] + + # First call + result1 = await anime_service.list_missing() + + # Second call (should use cache) + result2 = await anime_service.list_missing() + + assert result1 == result2 + + @pytest.mark.asyncio + async def test_list_missing_error_handling( + self, anime_service, mock_series_app + ): + """Test error handling in list_missing.""" + mock_series_app.series_list = None # Cause an error + + # Error message will be about NoneType not being iterable + with pytest.raises(AnimeServiceError): + await anime_service.list_missing() + + +class TestSearch: + """Test search operation.""" + + @pytest.mark.asyncio + async def test_search_empty_query(self, anime_service): + """Test search with empty query returns empty list.""" + result = await anime_service.search("") + + assert result == [] + + @pytest.mark.asyncio + async def test_search_success(self, anime_service, mock_series_app): + """Test successful search operation.""" + mock_series_app.search.return_value = [ + {"name": "Test Anime", "url": "http://example.com"} + ] + + result = await anime_service.search("test") + + assert len(result) == 1 + assert result[0]["name"] == "Test Anime" + mock_series_app.search.assert_called_once_with("test") + + @pytest.mark.asyncio + async def test_search_error_handling( + self, anime_service, mock_series_app + ): + """Test error handling during search.""" + mock_series_app.search.side_effect = Exception("Search failed") + + with pytest.raises(AnimeServiceError, match="Search failed"): + await anime_service.search("test query") + + +class TestRescan: + """Test rescan operation.""" + + @pytest.mark.asyncio + async def test_rescan_success( + self, anime_service, mock_series_app, mock_progress_service + ): + """Test successful rescan operation.""" + await anime_service.rescan() + + # Verify SeriesApp.ReScan was called + mock_series_app.ReScan.assert_called_once() + + # Verify progress tracking + mock_progress_service.start_progress.assert_called_once() + mock_progress_service.complete_progress.assert_called_once() + + @pytest.mark.asyncio + async def test_rescan_with_callback(self, anime_service, mock_series_app): + """Test rescan with progress callback.""" + callback_called = False + callback_data = None + + def callback(data): + nonlocal callback_called, callback_data + callback_called = True + callback_data = data + + # Mock ReScan to call the callback + def mock_rescan(cb): + if cb: + cb({"current": 5, "total": 10, "message": "Scanning..."}) + + mock_series_app.ReScan.side_effect = mock_rescan + + await anime_service.rescan(callback=callback) + + assert callback_called + assert callback_data is not None + + @pytest.mark.asyncio + async def test_rescan_clears_cache(self, anime_service, mock_series_app): + """Test that rescan clears the list cache.""" + # Populate cache + mock_series_app.series_list = [{"name": "Test"}] + await anime_service.list_missing() + + # Update series list + mock_series_app.series_list = [{"name": "Test"}, {"name": "New"}] + + # Rescan should clear cache + await anime_service.rescan() + + # Next list_missing should return updated data + result = await anime_service.list_missing() + assert len(result) == 2 + + @pytest.mark.asyncio + async def test_rescan_error_handling( + self, anime_service, mock_series_app, mock_progress_service + ): + """Test error handling during rescan.""" + mock_series_app.ReScan.side_effect = Exception("Rescan failed") + + with pytest.raises(AnimeServiceError, match="Rescan failed"): + await anime_service.rescan() + + # Verify progress failure was recorded + mock_progress_service.fail_progress.assert_called_once() + + +class TestDownload: + """Test download operation.""" + + @pytest.mark.asyncio + async def test_download_success(self, anime_service, mock_series_app): + """Test successful download operation.""" + mock_series_app.download.return_value = True + + result = await anime_service.download( + serie_folder="test_series", + season=1, + episode=1, + key="test_key", + ) + + assert result is True + mock_series_app.download.assert_called_once_with( + "test_series", 1, 1, "test_key", None + ) + + @pytest.mark.asyncio + async def test_download_with_callback(self, anime_service, mock_series_app): + """Test download with progress callback.""" + callback = MagicMock() + mock_series_app.download.return_value = True + + result = await anime_service.download( + serie_folder="test_series", + season=1, + episode=1, + key="test_key", + callback=callback, + ) + + assert result is True + # Verify callback was passed to SeriesApp + mock_series_app.download.assert_called_once_with( + "test_series", 1, 1, "test_key", callback + ) + + @pytest.mark.asyncio + async def test_download_error_handling(self, anime_service, mock_series_app): + """Test error handling during download.""" + mock_series_app.download.side_effect = Exception("Download failed") + + with pytest.raises(AnimeServiceError, match="Download failed"): + await anime_service.download( + serie_folder="test_series", + season=1, + episode=1, + key="test_key", + ) + + +class TestConcurrency: + """Test concurrent operations.""" + + @pytest.mark.asyncio + async def test_multiple_concurrent_operations( + self, anime_service, mock_series_app + ): + """Test that multiple operations can run concurrently.""" + mock_series_app.search.return_value = [{"name": "Test"}] + + # Run multiple searches concurrently + tasks = [ + anime_service.search("query1"), + anime_service.search("query2"), + anime_service.search("query3"), + ] + + results = await asyncio.gather(*tasks) + + assert len(results) == 3 + assert all(len(r) == 1 for r in results) + + +class TestFactoryFunction: + """Test factory function.""" + + def test_get_anime_service(self, tmp_path): + """Test get_anime_service factory function.""" + from src.server.services.anime_service import get_anime_service + + with patch("src.server.services.anime_service.SeriesApp"): + service = get_anime_service(directory=str(tmp_path)) + + assert isinstance(service, AnimeService) + assert service._directory == str(tmp_path) diff --git a/tests/unit/test_auth_service.py b/tests/unit/test_auth_service.py index 34c193e..746fd8f 100644 --- a/tests/unit/test_auth_service.py +++ b/tests/unit/test_auth_service.py @@ -1,59 +1,303 @@ +"""Unit tests for AuthService. + +Tests cover password setup and validation, JWT token operations, +session management, lockout mechanism, and error handling. +""" +from datetime import datetime, timedelta + import pytest from src.server.services.auth_service import AuthError, AuthService, LockedOutError -def test_setup_and_validate_success(): - svc = AuthService() - password = "Str0ng!Pass" - svc.setup_master_password(password) - assert svc.is_configured() +class TestPasswordSetup: + """Test password setup and validation.""" - assert svc.validate_master_password(password) is True + def test_setup_and_validate_success(self): + """Test successful password setup and validation.""" + svc = AuthService() + password = "Str0ng!Pass" + svc.setup_master_password(password) + + assert svc.is_configured() + assert svc.validate_master_password(password) is True - resp = svc.create_access_token(subject="tester", remember=False) - assert resp.token_type == "bearer" - assert resp.access_token + @pytest.mark.parametrize( + "bad", + [ + "short", + "lowercaseonly", + "UPPERCASEONLY", + "NoSpecial1", + ], + ) + def test_setup_weak_passwords(self, bad): + """Test that weak passwords are rejected.""" + svc = AuthService() + with pytest.raises(ValueError): + svc.setup_master_password(bad) - sess = svc.create_session_model(resp.access_token) - assert sess.expires_at is not None + def test_password_length_validation(self): + """Test minimum password length validation.""" + svc = AuthService() + with pytest.raises(ValueError, match="at least 8 characters"): + svc.setup_master_password("Short1!") + + def test_password_case_validation(self): + """Test mixed case requirement.""" + svc = AuthService() + with pytest.raises(ValueError, match="mixed case"): + svc.setup_master_password("alllowercase1!") + + with pytest.raises(ValueError, match="mixed case"): + svc.setup_master_password("ALLUPPERCASE1!") + + def test_password_special_char_validation(self): + """Test special character requirement.""" + svc = AuthService() + with pytest.raises( + ValueError, match="symbol or punctuation" + ): + svc.setup_master_password("NoSpecial123") + + def test_validate_without_setup_raises_error(self): + """Test validation without password setup raises error.""" + svc = AuthService() + # Clear any hash that might come from settings + svc._hash = None + + with pytest.raises(AuthError, match="not configured"): + svc.validate_master_password("anypassword") + + def test_validate_wrong_password(self): + """Test validation with wrong password.""" + svc = AuthService() + svc.setup_master_password("Correct!Pass123") + + assert svc.validate_master_password("Wrong!Pass123") is False -@pytest.mark.parametrize( - "bad", - [ - "short", - "lowercaseonly", - "UPPERCASEONLY", - "NoSpecial1", - ], -) -def test_setup_weak_passwords(bad): - svc = AuthService() - with pytest.raises(ValueError): - svc.setup_master_password(bad) +class TestFailedAttemptsAndLockout: + """Test failed login attempts and lockout mechanism.""" + def test_failed_attempts_and_lockout(self): + """Test lockout after max failed attempts.""" + svc = AuthService() + password = "An0ther$Good1" + svc.setup_master_password(password) -def test_failed_attempts_and_lockout(): - svc = AuthService() - password = "An0ther$Good1" - svc.setup_master_password(password) + identifier = "test-ip" + # fail max_attempts times + for _ in range(svc.max_attempts): + assert ( + svc.validate_master_password( + "wrongpassword", identifier=identifier + ) + is False + ) - identifier = "test-ip" - # fail max_attempts times - for _ in range(svc.max_attempts): + # Next attempt must raise LockedOutError + with pytest.raises(LockedOutError): + svc.validate_master_password(password, identifier=identifier) + + def test_lockout_different_identifiers(self): + """Test that lockout is per identifier.""" + svc = AuthService() + password = "Valid!Pass123" + svc.setup_master_password(password) + + # Fail attempts for identifier1 + for _ in range(svc.max_attempts): + svc.validate_master_password("wrong", identifier="id1") + + # identifier1 should be locked + with pytest.raises(LockedOutError): + svc.validate_master_password(password, identifier="id1") + + # identifier2 should still work assert ( - svc.validate_master_password("wrongpassword", identifier=identifier) - is False + svc.validate_master_password(password, identifier="id2") + is True ) - # Next attempt must raise LockedOutError - with pytest.raises(LockedOutError): - svc.validate_master_password(password, identifier=identifier) + def test_successful_login_clears_failures(self): + """Test that successful login clears failure count.""" + svc = AuthService() + password = "Valid!Pass123" + svc.setup_master_password(password) + + identifier = "test-ip" + # Fail a few times (but not enough to lock) + for _ in range(svc.max_attempts - 1): + svc.validate_master_password("wrong", identifier=identifier) + + # Successful login should clear failures + assert ( + svc.validate_master_password(password, identifier=identifier) + is True + ) + + # Should be able to fail again without lockout + for _ in range(svc.max_attempts - 1): + svc.validate_master_password("wrong", identifier=identifier) + + # Should still not be locked + assert ( + svc.validate_master_password(password, identifier=identifier) + is True + ) -def test_token_decode_invalid(): - svc = AuthService() - # invalid token should raise AuthError - with pytest.raises(AuthError): - svc.decode_token("not-a-jwt") +class TestJWTTokens: + """Test JWT token creation and validation.""" + + def test_create_access_token(self): + """Test JWT token creation.""" + svc = AuthService() + password = "Str0ng!Pass" + svc.setup_master_password(password) + + resp = svc.create_access_token(subject="tester", remember=False) + + assert resp.token_type == "bearer" + assert resp.access_token + assert resp.expires_at is not None + + def test_create_token_with_remember(self): + """Test JWT token with remember=True has longer expiry.""" + svc = AuthService() + password = "Str0ng!Pass" + svc.setup_master_password(password) + + resp_normal = svc.create_access_token( + subject="tester", remember=False + ) + resp_remember = svc.create_access_token( + subject="tester", remember=True + ) + + # Remember token should expire later + assert resp_remember.expires_at > resp_normal.expires_at + + def test_decode_valid_token(self): + """Test decoding valid JWT token.""" + svc = AuthService() + password = "Str0ng!Pass" + svc.setup_master_password(password) + + resp = svc.create_access_token(subject="tester", remember=False) + decoded = svc.decode_token(resp.access_token) + + assert decoded["sub"] == "tester" + assert "exp" in decoded + assert "iat" in decoded + + def test_token_decode_invalid(self): + """Test that invalid token raises AuthError.""" + svc = AuthService() + + with pytest.raises(AuthError): + svc.decode_token("not-a-jwt") + + def test_decode_malformed_token(self): + """Test decoding malformed JWT token.""" + svc = AuthService() + + with pytest.raises(AuthError): + svc.decode_token("header.payload.signature") + + def test_decode_expired_token(self): + """Test decoding expired token.""" + svc = AuthService() + password = "Str0ng!Pass" + svc.setup_master_password(password) + + # Create a token with past expiry + from jose import jwt + + expired_payload = { + "sub": "tester", + "exp": int((datetime.utcnow() - timedelta(hours=1)).timestamp()), + "iat": int(datetime.utcnow().timestamp()), + } + expired_token = jwt.encode( + expired_payload, svc.secret, algorithm="HS256" + ) + + with pytest.raises(AuthError): + svc.decode_token(expired_token) + + +class TestSessionManagement: + """Test session model creation and management.""" + + def test_create_session_model(self): + """Test session model creation from token.""" + svc = AuthService() + password = "Str0ng!Pass" + svc.setup_master_password(password) + + resp = svc.create_access_token(subject="tester", remember=False) + sess = svc.create_session_model(resp.access_token) + + assert sess.session_id + assert sess.user == "tester" + assert sess.expires_at is not None + + def test_session_id_deterministic(self): + """Test that same token produces same session ID.""" + svc = AuthService() + password = "Str0ng!Pass" + svc.setup_master_password(password) + + resp = svc.create_access_token(subject="tester", remember=False) + sess1 = svc.create_session_model(resp.access_token) + sess2 = svc.create_session_model(resp.access_token) + + assert sess1.session_id == sess2.session_id + + def test_revoke_token(self): + """Test token revocation (placeholder).""" + svc = AuthService() + password = "Str0ng!Pass" + svc.setup_master_password(password) + + resp = svc.create_access_token(subject="tester", remember=False) + + # Currently a no-op, should not raise + result = svc.revoke_token(resp.access_token) + assert result is None + + +class TestServiceConfiguration: + """Test service configuration and initialization.""" + + def test_is_configured_initial_state(self): + """Test initial unconfigured state.""" + svc = AuthService() + # Clear any hash that might come from settings + svc._hash = None + + assert svc.is_configured() is False + + def test_is_configured_after_setup(self): + """Test configured state after setup.""" + svc = AuthService() + svc.setup_master_password("Valid!Pass123") + assert svc.is_configured() is True + + def test_custom_lockout_settings(self): + """Test custom lockout configuration.""" + svc = AuthService() + + # Verify default values + assert svc.max_attempts == 5 + assert svc.lockout_seconds == 300 + assert svc.token_expiry_hours == 24 + + # Custom settings should be modifiable + svc.max_attempts = 3 + svc.lockout_seconds = 600 + + assert svc.max_attempts == 3 + assert svc.lockout_seconds == 600