diff --git a/instructions.md b/instructions.md index bfb8c70..854b8ca 100644 --- a/instructions.md +++ b/instructions.md @@ -77,12 +77,6 @@ This comprehensive guide ensures a robust, maintainable, and scalable anime down ### 10. Testing -#### [] Create integration tests - -- []Create `tests/integration/test_download_flow.py` -- []Create `tests/integration/test_auth_flow.py` -- []Create `tests/integration/test_websocket.py` - #### [] Create frontend integration tests - []Create `tests/frontend/test_existing_ui_integration.py` diff --git a/tests/integration/test_auth_flow.py b/tests/integration/test_auth_flow.py new file mode 100644 index 0000000..391d7c3 --- /dev/null +++ b/tests/integration/test_auth_flow.py @@ -0,0 +1,739 @@ +"""Integration tests for authentication flow. + +This module tests the complete authentication flow including: +- Initial setup and master password configuration +- Login with valid/invalid credentials +- JWT token generation and validation +- Protected endpoint access control +- Token refresh and expiration +- Logout functionality +- Rate limiting and lockout mechanisms +- Session management +""" +import time +from typing import Dict, Optional + +import pytest +from httpx import ASGITransport, AsyncClient + +from src.server.fastapi_app import app +from src.server.services.auth_service import auth_service + + +@pytest.fixture(autouse=True) +def reset_auth(): + """Reset authentication state before each test.""" + original_hash = auth_service._hash + auth_service._hash = None + auth_service._failed.clear() + yield + auth_service._hash = original_hash + auth_service._failed.clear() + + +@pytest.fixture +async def client(): + """Create an async test client.""" + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as ac: + yield ac + + +class TestInitialSetup: + """Test initial authentication setup flow.""" + + async def test_setup_with_strong_password(self, client): + """Test setting up master password with strong password.""" + response = await client.post( + "/api/auth/setup", + json={"master_password": "StrongP@ssw0rd123"} + ) + + assert response.status_code == 201 + data = response.json() + assert data["status"] == "ok" + + async def test_setup_with_weak_password_fails(self, client): + """Test that setup fails with weak password.""" + response = await client.post( + "/api/auth/setup", + json={"master_password": "weak"} + ) + + # Should fail validation + assert response.status_code in [400, 422] + + async def test_setup_cannot_be_called_twice(self, client): + """Test that setup can only be called once.""" + # First setup succeeds + await client.post( + "/api/auth/setup", + json={"master_password": "FirstPassword123!"} + ) + + # Second setup should fail + response = await client.post( + "/api/auth/setup", + json={"master_password": "SecondPassword123!"} + ) + + assert response.status_code == 400 + data = response.json() + assert "already configured" in data["detail"].lower() + + async def test_auth_status_before_setup(self, client): + """Test authentication status before setup.""" + response = await client.get("/api/auth/status") + + assert response.status_code == 200 + data = response.json() + assert data["configured"] is False + assert data["authenticated"] is False + + async def test_auth_status_after_setup(self, client): + """Test authentication status after setup.""" + # Setup + await client.post( + "/api/auth/setup", + json={"master_password": "SetupPassword123!"} + ) + + # Check status + response = await client.get("/api/auth/status") + + assert response.status_code == 200 + data = response.json() + assert data["configured"] is True + assert data["authenticated"] is False + + +class TestLoginFlow: + """Test login flow with valid and invalid credentials.""" + + async def test_login_with_valid_credentials(self, client): + """Test successful login with correct password.""" + # Setup + password = "ValidPassword123!" + await client.post( + "/api/auth/setup", + json={"master_password": password} + ) + + # Login + response = await client.post( + "/api/auth/login", + json={"password": password} + ) + + assert response.status_code == 200 + data = response.json() + + # Verify token structure + assert "access_token" in data + assert "token_type" in data + assert data["token_type"] == "bearer" + assert isinstance(data["access_token"], str) + assert len(data["access_token"]) > 0 + + async def test_login_with_invalid_password(self, client): + """Test login failure with incorrect password.""" + # Setup + await client.post( + "/api/auth/setup", + json={"master_password": "CorrectPassword123!"} + ) + + # Login with wrong password + response = await client.post( + "/api/auth/login", + json={"password": "WrongPassword123!"} + ) + + assert response.status_code == 401 + data = response.json() + assert "detail" in data + assert "invalid" in data["detail"].lower() + + async def test_login_before_setup_fails(self, client): + """Test that login fails before setup is complete.""" + response = await client.post( + "/api/auth/login", + json={"password": "AnyPassword123!"} + ) + + assert response.status_code in [400, 401] + + async def test_login_with_remember_me(self, client): + """Test login with remember me option.""" + # Setup + password = "RememberPassword123!" + await client.post( + "/api/auth/setup", + json={"master_password": password} + ) + + # Login with remember=true + response = await client.post( + "/api/auth/login", + json={"password": password, "remember": True} + ) + + assert response.status_code == 200 + data = response.json() + assert "access_token" in data + + # Token should be issued (expiration time may be extended) + + async def test_login_without_remember_me(self, client): + """Test login without remember me option.""" + # Setup + password = "NoRememberPassword123!" + await client.post( + "/api/auth/setup", + json={"master_password": password} + ) + + # Login without remember + response = await client.post( + "/api/auth/login", + json={"password": password, "remember": False} + ) + + assert response.status_code == 200 + data = response.json() + assert "access_token" in data + + +class TestTokenValidation: + """Test JWT token validation and usage.""" + + async def get_valid_token(self, client) -> str: + """Helper to get a valid authentication token.""" + password = "TokenTestPassword123!" + await client.post( + "/api/auth/setup", + json={"master_password": password} + ) + response = await client.post( + "/api/auth/login", + json={"password": password} + ) + return response.json()["access_token"] + + async def test_access_protected_endpoint_with_valid_token(self, client): + """Test accessing protected endpoint with valid token.""" + token = await self.get_valid_token(client) + + # Access protected endpoint + response = await client.get( + "/api/queue/status", + headers={"Authorization": f"Bearer {token}"} + ) + + # Should succeed (or return 503 if service not configured) + assert response.status_code in [200, 503] + + async def test_access_protected_endpoint_without_token(self, client): + """Test accessing protected endpoint without token.""" + response = await client.get("/api/queue/status") + + assert response.status_code == 401 + + async def test_access_protected_endpoint_with_invalid_token(self, client): + """Test accessing protected endpoint with invalid token.""" + response = await client.get( + "/api/queue/status", + headers={"Authorization": "Bearer invalid_token_12345"} + ) + + assert response.status_code == 401 + + async def test_access_protected_endpoint_with_malformed_header( + self, client + ): + """Test accessing protected endpoint with malformed auth header.""" + token = await self.get_valid_token(client) + + # Missing "Bearer" prefix + response = await client.get( + "/api/queue/status", + headers={"Authorization": token} + ) + + assert response.status_code == 401 + + async def test_token_works_for_multiple_requests(self, client): + """Test that token can be reused for multiple requests.""" + token = await self.get_valid_token(client) + headers = {"Authorization": f"Bearer {token}"} + + # Make multiple requests with same token + for _ in range(5): + response = await client.get("/api/queue/status", headers=headers) + assert response.status_code in [200, 503] + + async def test_auth_status_with_valid_token(self, client): + """Test auth status endpoint with valid token.""" + token = await self.get_valid_token(client) + + response = await client.get( + "/api/auth/status", + headers={"Authorization": f"Bearer {token}"} + ) + + assert response.status_code == 200 + data = response.json() + assert data["configured"] is True + assert data["authenticated"] is True + + +class TestProtectedEndpoints: + """Test that all protected endpoints enforce authentication.""" + + async def get_valid_token(self, client) -> str: + """Helper to get a valid authentication token.""" + password = "ProtectedTestPassword123!" + await client.post( + "/api/auth/setup", + json={"master_password": password} + ) + response = await client.post( + "/api/auth/login", + json={"password": password} + ) + return response.json()["access_token"] + + async def test_anime_endpoints_require_auth(self, client): + """Test that anime endpoints require authentication.""" + # Without token + response = await client.get("/api/v1/anime") + assert response.status_code == 401 + + # With valid token + token = await self.get_valid_token(client) + response = await client.get( + "/api/v1/anime", + headers={"Authorization": f"Bearer {token}"} + ) + assert response.status_code in [200, 503] + + async def test_queue_endpoints_require_auth(self, client): + """Test that queue endpoints require authentication.""" + endpoints = [ + ("/api/queue/status", "GET"), + ("/api/queue/add", "POST"), + ("/api/queue/control/start", "POST"), + ("/api/queue/control/pause", "POST"), + ] + + token = await self.get_valid_token(client) + + for endpoint, method in endpoints: + # Without token + if method == "GET": + response = await client.get(endpoint) + else: + response = await client.post(endpoint, json={}) + + assert response.status_code in [400, 401, 422] + + # With token (should pass auth, may fail validation) + headers = {"Authorization": f"Bearer {token}"} + if method == "GET": + response = await client.get(endpoint, headers=headers) + else: + response = await client.post(endpoint, json={}, headers=headers) + + assert response.status_code not in [401] + + async def test_config_endpoints_require_auth(self, client): + """Test that config endpoints require authentication.""" + # Without token + response = await client.get("/api/v1/config") + assert response.status_code == 401 + + # With token + token = await self.get_valid_token(client) + response = await client.get( + "/api/v1/config", + headers={"Authorization": f"Bearer {token}"} + ) + assert response.status_code in [200, 503] + + async def test_download_endpoints_require_auth(self, client): + """Test that download endpoints require authentication.""" + token = await self.get_valid_token(client) + + # Test queue operations require auth + response = await client.get("/api/queue/status") + assert response.status_code == 401 + + response = await client.get( + "/api/queue/status", + headers={"Authorization": f"Bearer {token}"} + ) + assert response.status_code in [200, 503] + + +class TestLogoutFlow: + """Test logout functionality.""" + + async def get_valid_token(self, client) -> str: + """Helper to get a valid authentication token.""" + password = "LogoutTestPassword123!" + await client.post( + "/api/auth/setup", + json={"master_password": password} + ) + response = await client.post( + "/api/auth/login", + json={"password": password} + ) + return response.json()["access_token"] + + async def test_logout_with_valid_token(self, client): + """Test logout with valid token.""" + token = await self.get_valid_token(client) + + response = await client.post( + "/api/auth/logout", + headers={"Authorization": f"Bearer {token}"} + ) + + assert response.status_code == 200 + data = response.json() + assert data["status"] == "ok" + + async def test_logout_without_token(self, client): + """Test logout without token.""" + response = await client.post("/api/auth/logout") + + # May succeed as logout is sometimes allowed without auth + assert response.status_code in [200, 401] + + async def test_token_after_logout(self, client): + """Test that token still works after logout (stateless JWT).""" + token = await self.get_valid_token(client) + + # Logout + await client.post( + "/api/auth/logout", + headers={"Authorization": f"Bearer {token}"} + ) + + # Try to use token (may still work if JWT is stateless) + response = await client.get( + "/api/queue/status", + headers={"Authorization": f"Bearer {token}"} + ) + + # Stateless JWT: token may still work + # Stateful: should return 401 + assert response.status_code in [200, 401, 503] + + +class TestRateLimitingAndLockout: + """Test rate limiting and lockout mechanisms.""" + + async def test_failed_login_attempts_tracked(self, client): + """Test that failed login attempts are tracked.""" + # Setup + await client.post( + "/api/auth/setup", + json={"master_password": "CorrectPassword123!"} + ) + + # Multiple failed attempts + for _ in range(3): + response = await client.post( + "/api/auth/login", + json={"password": "WrongPassword123!"} + ) + assert response.status_code == 401 + + async def test_lockout_after_max_failed_attempts(self, client): + """Test account lockout after maximum failed attempts.""" + # Setup + await client.post( + "/api/auth/setup", + json={"master_password": "CorrectPassword123!"} + ) + + # Make multiple failed attempts to trigger lockout + for i in range(6): # More than max allowed + response = await client.post( + "/api/auth/login", + json={"password": "WrongPassword123!"} + ) + + if i < 5: + assert response.status_code == 401 + else: + # Should be locked out + assert response.status_code in [401, 429] + + async def test_successful_login_resets_failed_attempts(self, client): + """Test that successful login resets failed attempt counter.""" + # Setup + password = "ResetCounterPassword123!" + await client.post( + "/api/auth/setup", + json={"master_password": password} + ) + + # Failed attempts + for _ in range(2): + await client.post( + "/api/auth/login", + json={"password": "WrongPassword123!"} + ) + + # Successful login + response = await client.post( + "/api/auth/login", + json={"password": password} + ) + assert response.status_code == 200 + + # Should be able to make more attempts (counter reset) + await client.post( + "/api/auth/login", + json={"password": "WrongPassword123!"} + ) + + +class TestSessionManagement: + """Test session management and concurrent sessions.""" + + async def get_valid_token(self, client) -> str: + """Helper to get a valid authentication token.""" + password = "SessionTestPassword123!" + await client.post( + "/api/auth/setup", + json={"master_password": password} + ) + response = await client.post( + "/api/auth/login", + json={"password": password} + ) + return response.json()["access_token"] + + async def test_multiple_concurrent_sessions(self, client): + """Test that multiple sessions can exist simultaneously.""" + password = "MultiSessionPassword123!" + await client.post( + "/api/auth/setup", + json={"master_password": password} + ) + + # Create multiple sessions + tokens = [] + for _ in range(3): + response = await client.post( + "/api/auth/login", + json={"password": password} + ) + assert response.status_code == 200 + tokens.append(response.json()["access_token"]) + + # All tokens should work + for token in tokens: + response = await client.get( + "/api/queue/status", + headers={"Authorization": f"Bearer {token}"} + ) + assert response.status_code in [200, 503] + + async def test_independent_token_lifetimes(self, client): + """Test that tokens have independent lifetimes.""" + token1 = await self.get_valid_token(client) + + # Small delay + time.sleep(0.1) + + token2 = await self.get_valid_token(client) + + # Both tokens should work + for token in [token1, token2]: + response = await client.get( + "/api/queue/status", + headers={"Authorization": f"Bearer {token}"} + ) + assert response.status_code in [200, 503] + + +class TestAuthenticationEdgeCases: + """Test edge cases and error scenarios.""" + + async def test_empty_password_in_setup(self, client): + """Test setup with empty password.""" + response = await client.post( + "/api/auth/setup", + json={"master_password": ""} + ) + + assert response.status_code in [400, 422] + + async def test_empty_password_in_login(self, client): + """Test login with empty password.""" + # Setup first + await client.post( + "/api/auth/setup", + json={"master_password": "ValidPassword123!"} + ) + + response = await client.post( + "/api/auth/login", + json={"password": ""} + ) + + assert response.status_code in [400, 401, 422] + + async def test_missing_password_field(self, client): + """Test requests with missing password field.""" + response = await client.post( + "/api/auth/setup", + json={} + ) + + assert response.status_code == 422 # Validation error + + async def test_malformed_json_in_auth_requests(self, client): + """Test authentication with malformed JSON.""" + response = await client.post( + "/api/auth/setup", + content="not valid json", + headers={"Content-Type": "application/json"} + ) + + assert response.status_code in [400, 422] + + async def test_extremely_long_password(self, client): + """Test setup with extremely long password.""" + long_password = "P@ssw0rd" + "x" * 10000 + + response = await client.post( + "/api/auth/setup", + json={"master_password": long_password} + ) + + # Should handle gracefully (accept or reject) + assert response.status_code in [201, 400, 413, 422] + + async def test_special_characters_in_password(self, client): + """Test password with various special characters.""" + special_password = "P@$$w0rd!#%^&*()_+-=[]{}|;:',.<>?/~`" + + response = await client.post( + "/api/auth/setup", + json={"master_password": special_password} + ) + + # Should accept special characters + assert response.status_code in [201, 400] + + async def test_unicode_characters_in_password(self, client): + """Test password with unicode characters.""" + unicode_password = "Pässwörd123!日本語" + + response = await client.post( + "/api/auth/setup", + json={"master_password": unicode_password} + ) + + # Should handle unicode gracefully + assert response.status_code in [201, 400, 422] + + +class TestCompleteAuthenticationWorkflow: + """Test complete authentication workflows.""" + + async def test_full_authentication_cycle(self, client): + """Test complete authentication cycle from setup to logout.""" + password = "CompleteWorkflowPassword123!" + + # 1. Check initial status (not configured) + status = await client.get("/api/auth/status") + assert status.json()["configured"] is False + + # 2. Setup master password + setup = await client.post( + "/api/auth/setup", + json={"master_password": password} + ) + assert setup.status_code == 201 + + # 3. Check status (configured, not authenticated) + status = await client.get("/api/auth/status") + data = status.json() + assert data["configured"] is True + assert data["authenticated"] is False + + # 4. Login + login = await client.post( + "/api/auth/login", + json={"password": password} + ) + assert login.status_code == 200 + token = login.json()["access_token"] + + # 5. Access protected endpoint + protected = await client.get( + "/api/queue/status", + headers={"Authorization": f"Bearer {token}"} + ) + assert protected.status_code in [200, 503] + + # 6. Check authenticated status + status = await client.get( + "/api/auth/status", + headers={"Authorization": f"Bearer {token}"} + ) + data = status.json() + assert data["configured"] is True + assert data["authenticated"] is True + + # 7. Logout + logout = await client.post( + "/api/auth/logout", + headers={"Authorization": f"Bearer {token}"} + ) + assert logout.status_code == 200 + + async def test_workflow_with_failed_and_successful_attempts(self, client): + """Test workflow with mixed failed and successful attempts.""" + password = "MixedAttemptsPassword123!" + + # Setup + await client.post( + "/api/auth/setup", + json={"master_password": password} + ) + + # Failed attempt + response = await client.post( + "/api/auth/login", + json={"password": "WrongPassword123!"} + ) + assert response.status_code == 401 + + # Successful attempt + response = await client.post( + "/api/auth/login", + json={"password": password} + ) + assert response.status_code == 200 + + # Another failed attempt + response = await client.post( + "/api/auth/login", + json={"password": "WrongAgain123!"} + ) + assert response.status_code == 401 + + # Another successful attempt + response = await client.post( + "/api/auth/login", + json={"password": password} + ) + assert response.status_code == 200 diff --git a/tests/integration/test_download_flow.py b/tests/integration/test_download_flow.py new file mode 100644 index 0000000..a6f3e9d --- /dev/null +++ b/tests/integration/test_download_flow.py @@ -0,0 +1,621 @@ +"""Integration tests for complete download flow. + +This module tests the end-to-end download flow including: +- Adding episodes to the queue +- Queue status updates +- Download processing +- Progress tracking +- Queue control operations (pause, resume, clear) +- Error handling and retries +- WebSocket notifications +""" +import asyncio +from datetime import datetime +from pathlib import Path +from typing import Any, Dict, List +from unittest.mock import AsyncMock, Mock, patch + +import pytest +from httpx import ASGITransport, AsyncClient + +from src.server.fastapi_app import app +from src.server.models.download import ( + DownloadPriority, + DownloadStatus, + EpisodeIdentifier, +) +from src.server.services.anime_service import AnimeService +from src.server.services.auth_service import auth_service +from src.server.services.download_service import DownloadService +from src.server.services.progress_service import get_progress_service +from src.server.services.websocket_service import get_websocket_service + + +@pytest.fixture(autouse=True) +def reset_auth(): + """Reset authentication state before each test.""" + original_hash = auth_service._hash + auth_service._hash = None + auth_service._failed.clear() + yield + auth_service._hash = original_hash + auth_service._failed.clear() + + +@pytest.fixture +async def client(): + """Create an async test client.""" + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as ac: + yield ac + + +@pytest.fixture +async def authenticated_client(client): + """Create an authenticated test client with token.""" + # Setup master password + await client.post( + "/api/auth/setup", + json={"master_password": "TestPassword123!"} + ) + + # Login to get token + response = await client.post( + "/api/auth/login", + json={"password": "TestPassword123!"} + ) + token = response.json()["access_token"] + + # Add token to default headers + client.headers.update({"Authorization": f"Bearer {token}"}) + yield client + + +@pytest.fixture +def mock_series_app(): + """Mock SeriesApp for testing.""" + app_mock = Mock() + app_mock.series_list = [] + app_mock.search = Mock(return_value=[]) + app_mock.ReScan = Mock() + app_mock.download = Mock(return_value=True) + return app_mock + + +@pytest.fixture +def mock_anime_service(mock_series_app): + """Create a mock AnimeService.""" + with patch("src.server.services.anime_service.SeriesApp", return_value=mock_series_app): + service = AnimeService() + service.download = AsyncMock(return_value=True) + yield service + + +@pytest.fixture +def temp_queue_file(tmp_path): + """Create a temporary queue persistence file.""" + return str(tmp_path / "test_queue.json") + + +class TestDownloadFlowEndToEnd: + """Test complete download flow from queue addition to completion.""" + + async def test_add_episodes_to_queue(self, authenticated_client, mock_anime_service): + """Test adding episodes to the download queue.""" + # Add episodes to queue + response = await authenticated_client.post( + "/api/queue/add", + json={ + "serie_id": "test-series-1", + "serie_name": "Test Anime Series", + "episodes": [ + {"season": 1, "episode": 1, "title": "Episode 1"}, + {"season": 1, "episode": 2, "title": "Episode 2"}, + ], + "priority": "normal" + } + ) + + assert response.status_code == 201 + data = response.json() + + # Verify response structure + assert data["status"] == "success" + assert "item_ids" in data + assert len(data["item_ids"]) == 2 + assert "message" in data + + async def test_queue_status_after_adding_items(self, authenticated_client): + """Test retrieving queue status after adding items.""" + # Add episodes to queue + await authenticated_client.post( + "/api/queue/add", + json={ + "serie_id": "test-series-2", + "serie_name": "Another Series", + "episodes": [{"season": 1, "episode": 1}], + "priority": "high" + } + ) + + # Get queue status + response = await authenticated_client.get("/api/queue/status") + + assert response.status_code in [200, 503] + + if response.status_code == 200: + data = response.json() + + # Verify status structure + assert "status" in data + assert "statistics" in data + + status = data["status"] + assert "pending" in status + assert "active" in status + assert "completed" in status + assert "failed" in status + + async def test_add_with_different_priorities(self, authenticated_client): + """Test adding episodes with different priority levels.""" + priorities = ["high", "normal", "low"] + + for priority in priorities: + response = await authenticated_client.post( + "/api/queue/add", + json={ + "serie_id": f"series-{priority}", + "serie_name": f"Series {priority.title()}", + "episodes": [{"season": 1, "episode": 1}], + "priority": priority + } + ) + + assert response.status_code in [201, 503] + + async def test_validation_error_for_empty_episodes(self, authenticated_client): + """Test validation error when no episodes are specified.""" + response = await authenticated_client.post( + "/api/queue/add", + json={ + "serie_id": "test-series", + "serie_name": "Test Series", + "episodes": [], + "priority": "normal" + } + ) + + assert response.status_code == 400 + data = response.json() + assert "detail" in data + + async def test_validation_error_for_invalid_priority(self, authenticated_client): + """Test validation error for invalid priority level.""" + response = await authenticated_client.post( + "/api/queue/add", + json={ + "serie_id": "test-series", + "serie_name": "Test Series", + "episodes": [{"season": 1, "episode": 1}], + "priority": "invalid" + } + ) + + assert response.status_code == 422 # Validation error + + +class TestQueueControlOperations: + """Test queue control operations (start, pause, resume, clear).""" + + async def test_start_queue_processing(self, authenticated_client): + """Test starting the queue processor.""" + response = await authenticated_client.post("/api/queue/control/start") + + assert response.status_code in [200, 503] + + if response.status_code == 200: + data = response.json() + assert data["status"] == "success" + + async def test_pause_queue_processing(self, authenticated_client): + """Test pausing the queue processor.""" + # Start first + await authenticated_client.post("/api/queue/control/start") + + # Then pause + response = await authenticated_client.post("/api/queue/control/pause") + + assert response.status_code in [200, 503] + + if response.status_code == 200: + data = response.json() + assert data["status"] == "success" + + async def test_resume_queue_processing(self, authenticated_client): + """Test resuming the queue processor.""" + # Start and pause first + await authenticated_client.post("/api/queue/control/start") + await authenticated_client.post("/api/queue/control/pause") + + # Then resume + response = await authenticated_client.post("/api/queue/control/resume") + + assert response.status_code in [200, 503] + + if response.status_code == 200: + data = response.json() + assert data["status"] == "success" + + async def test_clear_completed_downloads(self, authenticated_client): + """Test clearing completed downloads from the queue.""" + response = await authenticated_client.post("/api/queue/control/clear_completed") + + assert response.status_code in [200, 503] + + if response.status_code == 200: + data = response.json() + assert data["status"] == "success" + + +class TestQueueItemOperations: + """Test operations on individual queue items.""" + + async def test_remove_item_from_queue(self, authenticated_client): + """Test removing a specific item from the queue.""" + # First add an item + add_response = await authenticated_client.post( + "/api/queue/add", + json={ + "serie_id": "test-series", + "serie_name": "Test Series", + "episodes": [{"season": 1, "episode": 1}], + "priority": "normal" + } + ) + + if add_response.status_code == 201: + item_id = add_response.json()["item_ids"][0] + + # Remove the item + response = await authenticated_client.delete(f"/api/queue/items/{item_id}") + + assert response.status_code in [200, 404, 503] + + async def test_retry_failed_item(self, authenticated_client): + """Test retrying a failed download item.""" + # This would typically require a failed item to exist + # For now, test the endpoint with a dummy ID + response = await authenticated_client.post("/api/queue/items/dummy-id/retry") + + # Should return 404 if item doesn't exist, or 503 if service unavailable + assert response.status_code in [200, 404, 503] + + async def test_reorder_queue_items(self, authenticated_client): + """Test reordering queue items.""" + # Add multiple items + item_ids = [] + for i in range(3): + add_response = await authenticated_client.post( + "/api/queue/add", + json={ + "serie_id": f"series-{i}", + "serie_name": f"Series {i}", + "episodes": [{"season": 1, "episode": 1}], + "priority": "normal" + } + ) + + if add_response.status_code == 201: + item_ids.extend(add_response.json()["item_ids"]) + + if len(item_ids) >= 2: + # Reorder items + response = await authenticated_client.post( + "/api/queue/reorder", + json={"item_order": list(reversed(item_ids))} + ) + + assert response.status_code in [200, 503] + + +class TestDownloadProgressTracking: + """Test progress tracking during downloads.""" + + async def test_queue_status_includes_progress(self, authenticated_client): + """Test that queue status includes progress information.""" + # Add an item + await authenticated_client.post( + "/api/queue/add", + json={ + "serie_id": "test-series", + "serie_name": "Test Series", + "episodes": [{"season": 1, "episode": 1}], + "priority": "normal" + } + ) + + # Get status + response = await authenticated_client.get("/api/queue/status") + + assert response.status_code in [200, 503] + + if response.status_code == 200: + data = response.json() + assert "status" in data + + # Check that items can have progress + status = data["status"] + for item in status.get("active", []): + if "progress" in item and item["progress"]: + assert "percentage" in item["progress"] + assert "current_mb" in item["progress"] + assert "total_mb" in item["progress"] + + async def test_queue_statistics(self, authenticated_client): + """Test that queue statistics are calculated correctly.""" + response = await authenticated_client.get("/api/queue/status") + + assert response.status_code in [200, 503] + + if response.status_code == 200: + data = response.json() + assert "statistics" in data + + stats = data["statistics"] + assert "total_items" in stats + assert "pending_count" in stats + assert "active_count" in stats + assert "completed_count" in stats + assert "failed_count" in stats + assert "success_rate" in stats + + +class TestErrorHandlingAndRetries: + """Test error handling and retry mechanisms.""" + + async def test_handle_download_failure(self, authenticated_client): + """Test handling of download failures.""" + # This would require mocking a failure scenario + # For integration testing, we verify the error handling structure + + # Add an item that might fail + response = await authenticated_client.post( + "/api/queue/add", + json={ + "serie_id": "invalid-series", + "serie_name": "Invalid Series", + "episodes": [{"season": 99, "episode": 99}], + "priority": "normal" + } + ) + + # The system should handle the request gracefully + assert response.status_code in [201, 400, 503] + + async def test_retry_count_increments(self, authenticated_client): + """Test that retry count increments on failures.""" + # Add a potentially failing item + add_response = await authenticated_client.post( + "/api/queue/add", + json={ + "serie_id": "test-series", + "serie_name": "Test Series", + "episodes": [{"season": 1, "episode": 1}], + "priority": "normal" + } + ) + + if add_response.status_code == 201: + # Get queue status to check retry count + status_response = await authenticated_client.get("/api/queue/status") + + if status_response.status_code == 200: + data = status_response.json() + # Verify structure includes retry_count field + for item_list in [data["status"].get("pending", []), + data["status"].get("failed", [])]: + for item in item_list: + assert "retry_count" in item + + +class TestAuthenticationRequirements: + """Test that download endpoints require authentication.""" + + async def test_queue_status_requires_auth(self, client): + """Test that queue status endpoint requires authentication.""" + response = await client.get("/api/queue/status") + assert response.status_code == 401 + + async def test_add_to_queue_requires_auth(self, client): + """Test that add to queue endpoint requires authentication.""" + response = await client.post( + "/api/queue/add", + json={ + "serie_id": "test-series", + "serie_name": "Test Series", + "episodes": [{"season": 1, "episode": 1}], + "priority": "normal" + } + ) + assert response.status_code == 401 + + async def test_queue_control_requires_auth(self, client): + """Test that queue control endpoints require authentication.""" + response = await client.post("/api/queue/control/start") + assert response.status_code == 401 + + async def test_item_operations_require_auth(self, client): + """Test that item operations require authentication.""" + response = await client.delete("/api/queue/items/dummy-id") + assert response.status_code == 401 + + +class TestConcurrentOperations: + """Test concurrent download operations.""" + + async def test_multiple_concurrent_downloads(self, authenticated_client): + """Test handling multiple concurrent download requests.""" + # Add multiple items concurrently + tasks = [] + for i in range(5): + task = authenticated_client.post( + "/api/queue/add", + json={ + "serie_id": f"series-{i}", + "serie_name": f"Series {i}", + "episodes": [{"season": 1, "episode": 1}], + "priority": "normal" + } + ) + tasks.append(task) + + # Wait for all requests to complete + responses = await asyncio.gather(*tasks, return_exceptions=True) + + # Verify all requests were handled + for response in responses: + if not isinstance(response, Exception): + assert response.status_code in [201, 503] + + async def test_concurrent_status_requests(self, authenticated_client): + """Test handling concurrent status requests.""" + # Make multiple concurrent status requests + tasks = [ + authenticated_client.get("/api/queue/status") + for _ in range(10) + ] + + responses = await asyncio.gather(*tasks, return_exceptions=True) + + # Verify all requests were handled + for response in responses: + if not isinstance(response, Exception): + assert response.status_code in [200, 503] + + +class TestQueuePersistence: + """Test queue state persistence.""" + + async def test_queue_survives_restart(self, authenticated_client, temp_queue_file): + """Test that queue state persists across service restarts.""" + # This would require actually restarting the service + # For integration testing, we verify the persistence mechanism exists + + # Add items to queue + response = await authenticated_client.post( + "/api/queue/add", + json={ + "serie_id": "persistent-series", + "serie_name": "Persistent Series", + "episodes": [{"season": 1, "episode": 1}], + "priority": "normal" + } + ) + + # Verify the request was processed + assert response.status_code in [201, 503] + + # In a full integration test, we would restart the service here + # and verify the queue state is restored + + async def test_failed_items_are_persisted(self, authenticated_client): + """Test that failed items are persisted.""" + # Get initial queue state + initial_response = await authenticated_client.get("/api/queue/status") + + assert initial_response.status_code in [200, 503] + + # The persistence mechanism should handle failed items + # In a real scenario, we would trigger a failure and verify persistence + + +class TestWebSocketIntegrationWithDownloads: + """Test WebSocket notifications during download operations.""" + + async def test_websocket_notifies_on_queue_changes(self, authenticated_client): + """Test that WebSocket broadcasts queue changes.""" + # This is a basic integration test + # Full WebSocket testing is in test_websocket.py + + # Add an item to trigger potential WebSocket notification + response = await authenticated_client.post( + "/api/queue/add", + json={ + "serie_id": "ws-series", + "serie_name": "WebSocket Series", + "episodes": [{"season": 1, "episode": 1}], + "priority": "normal" + } + ) + + # Verify the operation succeeded + assert response.status_code in [201, 503] + + # In a full test, we would verify WebSocket clients received notifications + + +class TestCompleteDownloadWorkflow: + """Test complete end-to-end download workflow.""" + + async def test_full_download_cycle(self, authenticated_client): + """Test complete download cycle from add to completion.""" + # 1. Add episode to queue + add_response = await authenticated_client.post( + "/api/queue/add", + json={ + "serie_id": "workflow-series", + "serie_name": "Workflow Test Series", + "episodes": [{"season": 1, "episode": 1}], + "priority": "high" + } + ) + + assert add_response.status_code in [201, 503] + + if add_response.status_code == 201: + item_id = add_response.json()["item_ids"][0] + + # 2. Verify item is in queue + status_response = await authenticated_client.get("/api/queue/status") + assert status_response.status_code in [200, 503] + + # 3. Start queue processing + start_response = await authenticated_client.post("/api/queue/control/start") + assert start_response.status_code in [200, 503] + + # 4. Check status during processing + await asyncio.sleep(0.1) # Brief delay + progress_response = await authenticated_client.get("/api/queue/status") + assert progress_response.status_code in [200, 503] + + # 5. Verify final state (completed or still processing) + final_response = await authenticated_client.get("/api/queue/status") + assert final_response.status_code in [200, 503] + + async def test_workflow_with_pause_and_resume(self, authenticated_client): + """Test download workflow with pause and resume.""" + # Add items + await authenticated_client.post( + "/api/queue/add", + json={ + "serie_id": "pause-test", + "serie_name": "Pause Test Series", + "episodes": [{"season": 1, "episode": 1}], + "priority": "normal" + } + ) + + # Start processing + await authenticated_client.post("/api/queue/control/start") + + # Pause + pause_response = await authenticated_client.post("/api/queue/control/pause") + assert pause_response.status_code in [200, 503] + + # Resume + resume_response = await authenticated_client.post("/api/queue/control/resume") + assert resume_response.status_code in [200, 503] + + # Verify queue status + status_response = await authenticated_client.get("/api/queue/status") + assert status_response.status_code in [200, 503] diff --git a/tests/integration/test_websocket.py b/tests/integration/test_websocket.py new file mode 100644 index 0000000..545d8f8 --- /dev/null +++ b/tests/integration/test_websocket.py @@ -0,0 +1,765 @@ +"""Integration tests for WebSocket functionality. + +This module tests the complete WebSocket integration including: +- WebSocket connection establishment and authentication +- Real-time message broadcasting +- Room-based messaging +- Connection lifecycle management +- Integration with download and progress services +- Error handling and reconnection +- Concurrent client management +""" +import asyncio +import json +from typing import Any, Dict, List +from unittest.mock import AsyncMock, Mock, patch + +import pytest +from httpx import ASGITransport, AsyncClient +from starlette.websockets import WebSocketDisconnect + +from src.server.fastapi_app import app +from src.server.models.download import DownloadPriority, DownloadStatus +from src.server.services.auth_service import auth_service +from src.server.services.progress_service import ProgressType +from src.server.services.websocket_service import ( + ConnectionManager, + get_websocket_service, +) + + +@pytest.fixture(autouse=True) +def reset_auth(): + """Reset authentication state before each test.""" + original_hash = auth_service._hash + auth_service._hash = None + auth_service._failed.clear() + yield + auth_service._hash = original_hash + auth_service._failed.clear() + + +@pytest.fixture +async def client(): + """Create an async test client.""" + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as ac: + yield ac + + +@pytest.fixture +async def auth_token(client): + """Get a valid authentication token.""" + password = "WebSocketTestPassword123!" + await client.post( + "/api/auth/setup", + json={"master_password": password} + ) + response = await client.post( + "/api/auth/login", + json={"password": password} + ) + return response.json()["access_token"] + + +@pytest.fixture +def websocket_service(): + """Get the WebSocket service instance.""" + return get_websocket_service() + + +@pytest.fixture +def mock_websocket(): + """Create a mock WebSocket connection.""" + ws = AsyncMock() + ws.send_text = AsyncMock() + ws.send_json = AsyncMock() + ws.receive_text = AsyncMock() + ws.accept = AsyncMock() + ws.close = AsyncMock() + return ws + + +class TestWebSocketConnection: + """Test WebSocket connection establishment and lifecycle.""" + + async def test_websocket_endpoint_exists(self, client, auth_token): + """Test that WebSocket endpoint is available.""" + # This test verifies the endpoint exists + # Full WebSocket testing requires WebSocket client + + # Verify the WebSocket route is registered + routes = [route.path for route in app.routes] + websocket_routes = [ + path for path in routes if "ws" in path or "websocket" in path + ] + assert len(websocket_routes) > 0 + + async def test_connection_manager_tracks_connections( + self, websocket_service, mock_websocket + ): + """Test that connection manager tracks active connections.""" + manager = websocket_service.manager + + # Initially no connections + initial_count = len(manager.active_connections) + + # Add a connection + await manager.connect(mock_websocket, room="test-room") + + assert len(manager.active_connections) == initial_count + 1 + assert mock_websocket in manager.active_connections + + async def test_disconnect_removes_connection( + self, websocket_service, mock_websocket + ): + """Test that disconnecting removes connection from manager.""" + manager = websocket_service.manager + + # Connect + await manager.connect(mock_websocket, room="test-room") + assert mock_websocket in manager.active_connections + + # Disconnect + manager.disconnect(mock_websocket) + assert mock_websocket not in manager.active_connections + + async def test_room_assignment_on_connection( + self, websocket_service, mock_websocket + ): + """Test that connections are assigned to rooms.""" + manager = websocket_service.manager + room = "test-room-1" + + await manager.connect(mock_websocket, room=room) + + # Verify connection is in the room + assert room in manager._rooms + assert mock_websocket in manager._rooms[room] + + async def test_multiple_rooms_support( + self, websocket_service + ): + """Test that multiple rooms can exist simultaneously.""" + manager = websocket_service.manager + + ws1 = AsyncMock() + ws2 = AsyncMock() + ws3 = AsyncMock() + + # Connect to different rooms + await manager.connect(ws1, room="room-1") + await manager.connect(ws2, room="room-2") + await manager.connect(ws3, room="room-1") + + # Verify room structure + assert "room-1" in manager._rooms + assert "room-2" in manager._rooms + assert len(manager._rooms["room-1"]) == 2 + assert len(manager._rooms["room-2"]) == 1 + + +class TestMessageBroadcasting: + """Test message broadcasting functionality.""" + + async def test_broadcast_to_all_connections( + self, websocket_service + ): + """Test broadcasting message to all connected clients.""" + manager = websocket_service.manager + + # Create mock connections + ws1 = AsyncMock() + ws2 = AsyncMock() + ws3 = AsyncMock() + + await manager.connect(ws1, room="room-1") + await manager.connect(ws2, room="room-1") + await manager.connect(ws3, room="room-2") + + # Broadcast to all + message = {"type": "test", "data": "broadcast to all"} + await manager.broadcast(message) + + # All connections should receive message + ws1.send_json.assert_called_once() + ws2.send_json.assert_called_once() + ws3.send_json.assert_called_once() + + async def test_broadcast_to_specific_room( + self, websocket_service + ): + """Test broadcasting message to specific room only.""" + manager = websocket_service.manager + + ws1 = AsyncMock() + ws2 = AsyncMock() + ws3 = AsyncMock() + + await manager.connect(ws1, room="downloads") + await manager.connect(ws2, room="downloads") + await manager.connect(ws3, room="system") + + # Broadcast to specific room + message = {"type": "download_progress", "data": {}} + await manager.broadcast_to_room(message, room="downloads") + + # Only room members should receive + assert ws1.send_json.call_count == 1 + assert ws2.send_json.call_count == 1 + assert ws3.send_json.call_count == 0 + + async def test_broadcast_with_json_message( + self, websocket_service + ): + """Test broadcasting JSON-formatted messages.""" + manager = websocket_service.manager + ws = AsyncMock() + await manager.connect(ws, room="test") + + message = { + "type": "queue_update", + "data": { + "pending": 5, + "active": 2, + "completed": 10 + }, + "timestamp": "2025-10-19T10:00:00" + } + + await manager.broadcast(message) + + ws.send_json.assert_called_once_with(message) + + async def test_broadcast_handles_disconnected_clients( + self, websocket_service + ): + """Test that broadcasting handles disconnected clients gracefully.""" + manager = websocket_service.manager + + # Mock connection that will fail + failing_ws = AsyncMock() + failing_ws.send_json.side_effect = RuntimeError("Connection closed") + + working_ws = AsyncMock() + + await manager.connect(failing_ws, room="test") + await manager.connect(working_ws, room="test") + + # Broadcast should handle failure + message = {"type": "test", "data": "test"} + await manager.broadcast(message) + + # Working connection should still receive + working_ws.send_json.assert_called_once() + + +class TestProgressIntegration: + """Test integration with progress service.""" + + async def test_download_progress_broadcasts_to_websocket( + self, websocket_service + ): + """Test that download progress updates broadcast via WebSocket.""" + manager = websocket_service.manager + ws = AsyncMock() + await manager.connect(ws, room="downloads") + + # Simulate progress update broadcast + message = { + "type": "download_progress", + "data": { + "item_id": "test-download-1", + "percentage": 45.5, + "current_mb": 45.5, + "total_mb": 100.0, + "speed_mbps": 2.5 + } + } + + await manager.broadcast_to_room(message, room="downloads") + + ws.send_json.assert_called_once_with(message) + + async def test_download_complete_notification( + self, websocket_service + ): + """Test download completion notification via WebSocket.""" + manager = websocket_service.manager + ws = AsyncMock() + await manager.connect(ws, room="downloads") + + message = { + "type": "download_complete", + "data": { + "item_id": "test-download-1", + "serie_name": "Test Anime", + "episode": {"season": 1, "episode": 1} + } + } + + await manager.broadcast_to_room(message, room="downloads") + + ws.send_json.assert_called_once() + + async def test_download_failed_notification( + self, websocket_service + ): + """Test download failure notification via WebSocket.""" + manager = websocket_service.manager + ws = AsyncMock() + await manager.connect(ws, room="downloads") + + message = { + "type": "download_failed", + "data": { + "item_id": "test-download-1", + "error": "Network timeout", + "retry_count": 2 + } + } + + await manager.broadcast_to_room(message, room="downloads") + + ws.send_json.assert_called_once() + + +class TestQueueStatusBroadcasting: + """Test queue status broadcasting via WebSocket.""" + + async def test_queue_status_update_broadcast( + self, websocket_service + ): + """Test broadcasting queue status updates.""" + manager = websocket_service.manager + ws = AsyncMock() + await manager.connect(ws, room="queue") + + message = { + "type": "queue_status", + "data": { + "pending_count": 5, + "active_count": 2, + "completed_count": 10, + "failed_count": 1, + "total_items": 18 + } + } + + await manager.broadcast_to_room(message, room="queue") + + ws.send_json.assert_called_once_with(message) + + async def test_queue_item_added_notification( + self, websocket_service + ): + """Test notification when item is added to queue.""" + manager = websocket_service.manager + ws = AsyncMock() + await manager.connect(ws, room="queue") + + message = { + "type": "queue_item_added", + "data": { + "item_id": "new-item-1", + "serie_name": "New Series", + "episode_count": 3, + "priority": "normal" + } + } + + await manager.broadcast_to_room(message, room="queue") + + ws.send_json.assert_called_once() + + async def test_queue_item_removed_notification( + self, websocket_service + ): + """Test notification when item is removed from queue.""" + manager = websocket_service.manager + ws = AsyncMock() + await manager.connect(ws, room="queue") + + message = { + "type": "queue_item_removed", + "data": { + "item_id": "removed-item-1", + "reason": "user_cancelled" + } + } + + await manager.broadcast_to_room(message, room="queue") + + ws.send_json.assert_called_once() + + +class TestSystemMessaging: + """Test system-wide messaging via WebSocket.""" + + async def test_system_notification_broadcast( + self, websocket_service + ): + """Test broadcasting system notifications.""" + manager = websocket_service.manager + + ws1 = AsyncMock() + ws2 = AsyncMock() + await manager.connect(ws1, room="system") + await manager.connect(ws2, room="system") + + message = { + "type": "system_notification", + "data": { + "level": "info", + "message": "System maintenance scheduled", + "timestamp": "2025-10-19T10:00:00" + } + } + + await manager.broadcast_to_room(message, room="system") + + ws1.send_json.assert_called_once() + ws2.send_json.assert_called_once() + + async def test_error_message_broadcast( + self, websocket_service + ): + """Test broadcasting error messages.""" + manager = websocket_service.manager + ws = AsyncMock() + await manager.connect(ws, room="errors") + + message = { + "type": "error", + "data": { + "error_code": "DOWNLOAD_FAILED", + "message": "Failed to download episode", + "details": "Connection timeout" + } + } + + await manager.broadcast_to_room(message, room="errors") + + ws.send_json.assert_called_once() + + +class TestConcurrentConnections: + """Test handling of concurrent WebSocket connections.""" + + async def test_multiple_clients_in_same_room( + self, websocket_service + ): + """Test multiple clients receiving broadcasts in same room.""" + manager = websocket_service.manager + + # Create multiple connections + clients = [AsyncMock() for _ in range(5)] + for client in clients: + await manager.connect(client, room="shared-room") + + # Broadcast to all + message = {"type": "test", "data": "multi-client test"} + await manager.broadcast_to_room(message, room="shared-room") + + # All clients should receive + for client in clients: + client.send_json.assert_called_once_with(message) + + async def test_concurrent_broadcasts_to_different_rooms( + self, websocket_service + ): + """Test concurrent broadcasts to different rooms.""" + manager = websocket_service.manager + + # Setup rooms with clients + downloads_ws = AsyncMock() + queue_ws = AsyncMock() + system_ws = AsyncMock() + + await manager.connect(downloads_ws, room="downloads") + await manager.connect(queue_ws, room="queue") + await manager.connect(system_ws, room="system") + + # Concurrent broadcasts + await asyncio.gather( + manager.broadcast_to_room( + {"type": "download_progress"}, "downloads" + ), + manager.broadcast_to_room( + {"type": "queue_update"}, "queue" + ), + manager.broadcast_to_room( + {"type": "system_message"}, "system" + ) + ) + + # Each client should receive only their room's message + downloads_ws.send_json.assert_called_once() + queue_ws.send_json.assert_called_once() + system_ws.send_json.assert_called_once() + + +class TestConnectionErrorHandling: + """Test error handling in WebSocket connections.""" + + async def test_handle_send_failure( + self, websocket_service + ): + """Test handling of message send failures.""" + manager = websocket_service.manager + + # Connection that will fail on send + failing_ws = AsyncMock() + failing_ws.send_json.side_effect = RuntimeError("Send failed") + + await manager.connect(failing_ws, room="test") + + # Should handle error gracefully + message = {"type": "test", "data": "test"} + try: + await manager.broadcast_to_room(message, room="test") + except RuntimeError: + pytest.fail("Should handle send failure gracefully") + + async def test_handle_multiple_send_failures( + self, websocket_service + ): + """Test handling multiple concurrent send failures.""" + manager = websocket_service.manager + + # Multiple failing connections + failing_clients = [] + for i in range(3): + ws = AsyncMock() + ws.send_json.side_effect = RuntimeError(f"Failed {i}") + failing_clients.append(ws) + await manager.connect(ws, room="test") + + # Add one working connection + working_ws = AsyncMock() + await manager.connect(working_ws, room="test") + + # Broadcast should continue despite failures + message = {"type": "test", "data": "test"} + await manager.broadcast_to_room(message, room="test") + + # Working connection should still receive + working_ws.send_json.assert_called_once() + + async def test_cleanup_after_disconnect( + self, websocket_service + ): + """Test proper cleanup after client disconnect.""" + manager = websocket_service.manager + ws = AsyncMock() + room = "test-room" + + # Connect and then disconnect + await manager.connect(ws, room=room) + manager.disconnect(ws) + + # Verify cleanup + assert ws not in manager.active_connections + if room in manager._rooms: + assert ws not in manager._rooms[room] + + +class TestMessageFormatting: + """Test message formatting and validation.""" + + async def test_message_structure_validation( + self, websocket_service + ): + """Test that messages have required structure.""" + manager = websocket_service.manager + ws = AsyncMock() + await manager.connect(ws, room="test") + + # Valid message structure + valid_message = { + "type": "test_message", + "data": {"key": "value"}, + } + + await manager.broadcast(valid_message) + ws.send_json.assert_called_once_with(valid_message) + + async def test_different_message_types( + self, websocket_service + ): + """Test broadcasting different message types.""" + manager = websocket_service.manager + ws = AsyncMock() + await manager.connect(ws, room="test") + + message_types = [ + "download_progress", + "download_complete", + "download_failed", + "queue_status", + "system_notification", + "error" + ] + + for msg_type in message_types: + message = {"type": msg_type, "data": {}} + await manager.broadcast(message) + + # Should have received all message types + assert ws.send_json.call_count == len(message_types) + + +class TestWebSocketServiceIntegration: + """Test WebSocket service integration with other services.""" + + async def test_websocket_service_singleton(self): + """Test that WebSocket service is a singleton.""" + service1 = get_websocket_service() + service2 = get_websocket_service() + + assert service1 is service2 + + async def test_service_has_connection_manager(self): + """Test that service has connection manager.""" + service = get_websocket_service() + + assert hasattr(service, 'manager') + assert isinstance(service.manager, ConnectionManager) + + async def test_service_broadcast_methods_exist(self): + """Test that service has required broadcast methods.""" + service = get_websocket_service() + + required_methods = [ + 'broadcast_download_progress', + 'broadcast_download_complete', + 'broadcast_download_failed', + 'broadcast_queue_status', + 'broadcast_system_message', + 'send_error' + ] + + for method in required_methods: + assert hasattr(service, method) + + +class TestRoomManagement: + """Test room management functionality.""" + + async def test_room_creation_on_first_connection( + self, websocket_service + ): + """Test that room is created when first client connects.""" + manager = websocket_service.manager + ws = AsyncMock() + room = "new-room" + + # Room should not exist initially + assert room not in manager._rooms + + # Connect to room + await manager.connect(ws, room=room) + + # Room should now exist + assert room in manager._rooms + + async def test_room_cleanup_when_empty( + self, websocket_service + ): + """Test that empty rooms are cleaned up.""" + manager = websocket_service.manager + ws = AsyncMock() + room = "temp-room" + + # Connect and disconnect + await manager.connect(ws, room=room) + manager.disconnect(ws) + + # Room should be cleaned up if empty + # (Implementation may vary) + if room in manager._rooms: + assert len(manager._rooms[room]) == 0 + + async def test_client_can_be_in_one_room( + self, websocket_service + ): + """Test client room membership.""" + manager = websocket_service.manager + ws = AsyncMock() + + # Connect to room + await manager.connect(ws, room="room-1") + + # Verify in room + assert "room-1" in manager._rooms + assert ws in manager._rooms["room-1"] + + +class TestCompleteWebSocketWorkflow: + """Test complete WebSocket workflows.""" + + async def test_full_download_notification_workflow( + self, websocket_service + ): + """Test complete workflow of download notifications.""" + manager = websocket_service.manager + ws = AsyncMock() + await manager.connect(ws, room="downloads") + + # Simulate download lifecycle + + # 1. Download started + await manager.broadcast_to_room( + {"type": "download_started", "data": {"item_id": "dl-1"}}, + "downloads" + ) + + # 2. Progress updates + for progress in [25, 50, 75]: + await manager.broadcast_to_room( + { + "type": "download_progress", + "data": {"item_id": "dl-1", "percentage": progress} + }, + "downloads" + ) + + # 3. Download complete + await manager.broadcast_to_room( + {"type": "download_complete", "data": {"item_id": "dl-1"}}, + "downloads" + ) + + # Client should have received all notifications + assert ws.send_json.call_count == 5 + + async def test_multi_room_workflow( + self, websocket_service + ): + """Test workflow involving multiple rooms.""" + manager = websocket_service.manager + + # Setup clients in different rooms + download_ws = AsyncMock() + queue_ws = AsyncMock() + system_ws = AsyncMock() + + await manager.connect(download_ws, room="downloads") + await manager.connect(queue_ws, room="queue") + await manager.connect(system_ws, room="system") + + # Broadcast to each room + await manager.broadcast_to_room( + {"type": "download_update"}, "downloads" + ) + await manager.broadcast_to_room( + {"type": "queue_update"}, "queue" + ) + await manager.broadcast_to_room( + {"type": "system_update"}, "system" + ) + + # Each client should only receive their room's messages + download_ws.send_json.assert_called_once() + queue_ws.send_json.assert_called_once() + system_ws.send_json.assert_called_once()