"""Integration tests for WebSocket resilience and stress testing.

This module tests WebSocket connection resilience, concurrent client handling,
server restart recovery, authentication, message ordering, and broadcast
filtering.
"""

import asyncio
import json
import time
from typing import Any, Dict, List, Optional
from unittest.mock import Mock, patch

import pytest
from fastapi import WebSocket
from fastapi.testclient import TestClient

from src.server.services.websocket_service import (
    WebSocketService,
    get_websocket_service,
)


@pytest.fixture
def websocket_service():
    """Create a fresh WebSocketService instance for each test."""
    return WebSocketService()


@pytest.fixture
def mock_auth_token():
    """Create a mock authentication token for testing."""
    return "test_auth_token_12345"


class MockWebSocketClient:
    """In-memory stand-in for a WebSocket client.

    Wraps a ``Mock(spec=WebSocket)`` whose ``send_json`` and ``accept``
    attributes are replaced by coroutines, so every payload the service
    sends to this client is captured in ``received_messages``. Connection
    and room management call the service's internal ``_manager`` directly.
    """

    def __init__(self, client_id: str, service: WebSocketService):
        self.client_id = client_id
        self.service = service
        # Every payload passed to websocket.send_json, in arrival order.
        self.received_messages: List[Dict[str, Any]] = []
        self.is_connected = False
        self.websocket = Mock(spec=WebSocket)
        # Swap the spec'd Mock attributes for real coroutine functions so
        # callers that ``await`` them receive an awaitable.
        self.websocket.send_json = self._mock_send_json
        self.websocket.accept = self._mock_accept

    async def _mock_accept(self):
        """Mock WebSocket accept; marks the client as connected."""
        self.is_connected = True

    async def _mock_send_json(self, data: Dict[str, Any]):
        """Mock WebSocket send_json; records ``data`` instead of sending it."""
        self.received_messages.append(data)

    async def connect(self, metadata: Optional[Dict[str, Any]] = None):
        """Register this client with the service's connection manager.

        Args:
            metadata: Optional per-connection metadata; ``None`` is sent
                to the manager as an empty dict.
        """
        await self.service._manager.connect(
            self.websocket, self.client_id, metadata or {}
        )
        self.is_connected = True

    async def disconnect(self):
        """Unregister this client from the service's connection manager."""
        await self.service._manager.disconnect(self.client_id)
        self.is_connected = False

    async def join_room(self, room: str):
        """Add this client to ``room``."""
        await self.service._manager.join_room(self.client_id, room)

    async def leave_room(self, room: str):
        """Remove this client from ``room``."""
        await self.service._manager.leave_room(self.client_id, room)

    def clear_messages(self):
        """Forget all messages received so far."""
        self.received_messages.clear()


class TestWebSocketConcurrentClients:
    """Test WebSocket handling of multiple concurrent clients."""

    @pytest.mark.asyncio
    async def test_multiple_concurrent_connections(self, websocket_service):
        """Test handling 100+ concurrent WebSocket clients."""
        num_clients = 100
        clients: List[MockWebSocketClient] = []

        # Connect 100 clients
        for i in range(num_clients):
            client = MockWebSocketClient(f"client_{i}", websocket_service)
            await client.connect({"user_id": f"user_{i}"})
            clients.append(client)

        # Verify all clients are connected
        assert len(websocket_service._manager._active_connections) == num_clients

        # Broadcast a message to all clients
        test_message = {
            "type": "test_broadcast",
            "timestamp": time.time(),
            "message": "Test broadcast to all clients",
        }
        await websocket_service.broadcast(test_message)

        # Verify every client received the message exactly once
        for client in clients:
            assert len(client.received_messages) == 1
            assert client.received_messages[0] == test_message

        # Disconnect all clients
        for client in clients:
            await client.disconnect()

        assert len(websocket_service._manager._active_connections) == 0

    @pytest.mark.asyncio
    async def test_concurrent_room_broadcasts(self, websocket_service):
        """Test broadcasting to specific rooms with concurrent clients."""
        room_a_clients: List[MockWebSocketClient] = []
        room_b_clients: List[MockWebSocketClient] = []
        room_both_clients: List[MockWebSocketClient] = []

        # Create clients in different rooms
        for i in range(10):
            client = MockWebSocketClient(f"room_a_{i}", websocket_service)
            await client.connect()
            await client.join_room("room_a")
            room_a_clients.append(client)

        for i in range(10):
            client = MockWebSocketClient(f"room_b_{i}", websocket_service)
            await client.connect()
            await client.join_room("room_b")
            room_b_clients.append(client)

        for i in range(5):
            client = MockWebSocketClient(f"room_both_{i}", websocket_service)
            await client.connect()
            await client.join_room("room_a")
            await client.join_room("room_b")
            room_both_clients.append(client)

        all_clients = room_a_clients + room_b_clients + room_both_clients

        # Broadcast to room_a
        message_a = {"type": "room_a_message", "data": "Message for room A"}
        await websocket_service._manager.broadcast_to_room(message_a, "room_a")

        # room_a and room_both clients receive it; room_b clients do not
        for client in room_a_clients + room_both_clients:
            assert len(client.received_messages) == 1
            assert client.received_messages[0] == message_a
        for client in room_b_clients:
            assert len(client.received_messages) == 0

        # Clear messages
        for client in all_clients:
            client.clear_messages()

        # Broadcast to room_b
        message_b = {"type": "room_b_message", "data": "Message for room B"}
        await websocket_service._manager.broadcast_to_room(message_b, "room_b")

        # room_b and room_both clients receive it; room_a clients do not
        for client in room_b_clients + room_both_clients:
            assert len(client.received_messages) == 1
            assert client.received_messages[0] == message_b
        for client in room_a_clients:
            assert len(client.received_messages) == 0

        # Cleanup
        for client in all_clients:
            await client.disconnect()

    @pytest.mark.asyncio
    async def test_rapid_connect_disconnect(self, websocket_service):
        """Test rapid connection and disconnection cycles."""
        client_id = "rapid_test_client"

        # Perform 50 rapid connect/disconnect cycles
        for i in range(50):
            client = MockWebSocketClient(f"{client_id}_{i}", websocket_service)
            await client.connect()
            assert client.is_connected
            await client.disconnect()
            assert not client.is_connected

        # Verify no stale connections remain
        assert len(websocket_service._manager._active_connections) == 0
        assert len(websocket_service._manager._connection_metadata) == 0

    @pytest.mark.asyncio
    async def test_stress_message_rate(self, websocket_service):
        """Test high-frequency message broadcasting."""
        num_clients = 20
        num_messages = 100
        clients: List[MockWebSocketClient] = []

        # Connect clients
        for i in range(num_clients):
            client = MockWebSocketClient(f"stress_client_{i}", websocket_service)
            await client.connect()
            clients.append(client)

        # Send 100 messages rapidly
        for i in range(num_messages):
            message = {
                "type": "stress_test",
                "sequence": i,
                "timestamp": time.time(),
            }
            await websocket_service.broadcast(message)

        # Verify all clients received all messages, in send order
        for client in clients:
            assert len(client.received_messages) == num_messages
            sequences = [msg["sequence"] for msg in client.received_messages]
            assert sequences == list(range(num_messages))

        # Cleanup
        for client in clients:
            await client.disconnect()


class TestWebSocketConnectionRecovery:
    """Test WebSocket connection recovery after failures."""

    @pytest.mark.asyncio
    async def test_connection_recovery_after_disconnect(self, websocket_service):
        """Test client can reconnect after unexpected disconnect."""
        client_id = "recovery_test_client"

        # Initial connection
        client1 = MockWebSocketClient(client_id, websocket_service)
        await client1.connect({"user_id": "test_user"})
        await client1.join_room("downloads")

        # Drop the connection. A normal manager-side disconnect stands in for
        # an unexpected one here.
        await client1.disconnect()
        assert not client1.is_connected

        # Reconnect with same client_id
        client2 = MockWebSocketClient(client_id, websocket_service)
        await client2.connect({"user_id": "test_user"})
        await client2.join_room("downloads")

        # Verify new connection works
        message = {"type": "test", "data": "recovery test"}
        await websocket_service._manager.broadcast_to_room(message, "downloads")

        assert len(client2.received_messages) == 1
        assert client2.received_messages[0] == message

        await client2.disconnect()

    @pytest.mark.asyncio
    async def test_room_rejoin_after_reconnection(self, websocket_service):
        """Test client can rejoin rooms after reconnection."""
        client_id = "rejoin_test_client"
        rooms = ["downloads", "progress", "updates"]

        # Connect and join multiple rooms
        client1 = MockWebSocketClient(client_id, websocket_service)
        await client1.connect()
        for room in rooms:
            await client1.join_room(room)

        # Verify client is in all rooms
        for room in rooms:
            assert client_id in websocket_service._manager._rooms[room]

        # Disconnect
        await client1.disconnect()

        # Rooms should no longer contain the client after disconnect
        for room in rooms:
            assert client_id not in websocket_service._manager._rooms.get(room, set())

        # Reconnect and rejoin rooms
        client2 = MockWebSocketClient(client_id, websocket_service)
        await client2.connect()
        for room in rooms:
            await client2.join_room(room)

        # Verify client is in all rooms again
        for room in rooms:
            assert client_id in websocket_service._manager._rooms[room]

        await client2.disconnect()

    @pytest.mark.asyncio
    async def test_message_delivery_after_reconnection(self, websocket_service):
        """Test messages are delivered correctly after reconnection."""
        client_id = "delivery_test_client"

        # Connect, receive a message, disconnect
        client1 = MockWebSocketClient(client_id, websocket_service)
        await client1.connect()

        message1 = {"type": "test", "sequence": 1}
        await websocket_service.broadcast(message1)
        assert len(client1.received_messages) == 1

        await client1.disconnect()

        # Reconnect and verify new messages are received
        client2 = MockWebSocketClient(client_id, websocket_service)
        await client2.connect()

        message2 = {"type": "test", "sequence": 2}
        await websocket_service.broadcast(message2)

        # Should only receive message2 (not message1 from before disconnect)
        assert len(client2.received_messages) == 1
        assert client2.received_messages[0] == message2

        await client2.disconnect()


class TestWebSocketAuthentication:
    """Test WebSocket authentication and token handling."""

    @pytest.mark.asyncio
    async def test_connection_with_authentication_metadata(
        self, websocket_service, mock_auth_token
    ):
        """Test WebSocket connection with authentication token in metadata."""
        client = MockWebSocketClient("auth_client", websocket_service)
        metadata = {
            "user_id": "test_user",
            "auth_token": mock_auth_token,
            "session_id": "session_123",
        }

        await client.connect(metadata)

        # Verify metadata is stored
        stored_metadata = websocket_service._manager._connection_metadata["auth_client"]
        assert stored_metadata["user_id"] == "test_user"
        assert stored_metadata["auth_token"] == mock_auth_token
        assert stored_metadata["session_id"] == "session_123"

        await client.disconnect()

    @pytest.mark.asyncio
    async def test_broadcast_to_specific_user(self, websocket_service):
        """Test broadcasting to specific user using metadata filtering."""
        manager = websocket_service._manager

        # Connect multiple clients with different user IDs
        client1 = MockWebSocketClient("client1", websocket_service)
        await client1.connect({"user_id": "user_1"})

        client2 = MockWebSocketClient("client2", websocket_service)
        await client2.connect({"user_id": "user_2"})

        client3 = MockWebSocketClient("client3", websocket_service)
        await client3.connect({"user_id": "user_1"})  # Same user, different connection

        # Broadcast to specific user
        message = {"type": "user_specific", "data": "Message for user_1"}

        # Filter connections by user_id and send. Iterate a snapshot because
        # we await inside the loop while holding the manager's live dict.
        for conn_id, metadata in list(manager._connection_metadata.items()):
            if metadata.get("user_id") == "user_1":
                await manager._active_connections[conn_id].send_json(message)

        # Verify only user_1 clients received the message
        assert len(client1.received_messages) == 1
        assert client1.received_messages[0] == message
        assert len(client3.received_messages) == 1
        assert client3.received_messages[0] == message
        assert len(client2.received_messages) == 0

        # Cleanup
        await client1.disconnect()
        await client2.disconnect()
        await client3.disconnect()

    @pytest.mark.asyncio
    async def test_token_refresh_in_metadata(self, websocket_service):
        """Test updating authentication token in connection metadata."""
        client = MockWebSocketClient("token_refresh_client", websocket_service)
        old_token = "old_token_12345"
        new_token = "new_token_67890"

        # Connect with old token
        await client.connect({"user_id": "test_user", "auth_token": old_token})

        # Verify old token is stored
        metadata = websocket_service._manager._connection_metadata["token_refresh_client"]
        assert metadata["auth_token"] == old_token

        # Update token (simulating token refresh). NOTE(review): this mutates
        # the stored metadata dict in place; no service-level refresh API is
        # exercised here.
        metadata["auth_token"] = new_token

        # Verify token is updated
        updated_metadata = websocket_service._manager._connection_metadata["token_refresh_client"]
        assert updated_metadata["auth_token"] == new_token

        await client.disconnect()


class TestWebSocketMessageOrdering:
    """Test WebSocket message ordering guarantees."""

    @pytest.mark.asyncio
    async def test_message_order_preservation(self, websocket_service):
        """Test messages are received in the order they are sent."""
        client = MockWebSocketClient("order_test_client", websocket_service)
        await client.connect()

        # Send 50 messages in sequence
        num_messages = 50
        for i in range(num_messages):
            message = {
                "type": "sequence_test",
                "sequence": i,
                "timestamp": time.time(),
            }
            await websocket_service.broadcast(message)

        # Verify all messages received in order
        assert len(client.received_messages) == num_messages
        sequences = [msg["sequence"] for msg in client.received_messages]
        assert sequences == list(range(num_messages))

        await client.disconnect()

    @pytest.mark.asyncio
    async def test_concurrent_broadcast_order(self, websocket_service):
        """Test message ordering with concurrent broadcasts to different rooms."""
        # Create clients in two rooms
        room1_client = MockWebSocketClient("room1_client", websocket_service)
        await room1_client.connect()
        await room1_client.join_room("room1")

        room2_client = MockWebSocketClient("room2_client", websocket_service)
        await room2_client.connect()
        await room2_client.join_room("room2")

        both_rooms_client = MockWebSocketClient("both_client", websocket_service)
        await both_rooms_client.connect()
        await both_rooms_client.join_room("room1")
        await both_rooms_client.join_room("room2")

        # Send interleaved messages to both rooms
        for i in range(10):
            message1 = {"type": "room1_msg", "sequence": i}
            await websocket_service._manager.broadcast_to_room(message1, "room1")

            message2 = {"type": "room2_msg", "sequence": i}
            await websocket_service._manager.broadcast_to_room(message2, "room2")

        # Verify room1_client received only room1 messages in order
        assert len(room1_client.received_messages) == 10
        for i, msg in enumerate(room1_client.received_messages):
            assert msg["type"] == "room1_msg"
            assert msg["sequence"] == i

        # Verify room2_client received only room2 messages in order
        assert len(room2_client.received_messages) == 10
        for i, msg in enumerate(room2_client.received_messages):
            assert msg["type"] == "room2_msg"
            assert msg["sequence"] == i

        # Verify both_rooms_client received all messages. The two rooms'
        # streams may interleave, but each room's own order must be kept.
        assert len(both_rooms_client.received_messages) == 20
        room1_msgs = [
            msg for msg in both_rooms_client.received_messages
            if msg["type"] == "room1_msg"
        ]
        room2_msgs = [
            msg for msg in both_rooms_client.received_messages
            if msg["type"] == "room2_msg"
        ]
        assert len(room1_msgs) == 10
        assert len(room2_msgs) == 10
        assert [msg["sequence"] for msg in room1_msgs] == list(range(10))
        assert [msg["sequence"] for msg in room2_msgs] == list(range(10))

        # Cleanup
        await room1_client.disconnect()
        await room2_client.disconnect()
        await both_rooms_client.disconnect()


class TestWebSocketBroadcastFiltering:
    """Test WebSocket broadcast filtering to specific clients."""

    @pytest.mark.asyncio
    async def test_broadcast_to_all_except_sender(self, websocket_service):
        """Test broadcasting to all clients except the sender."""
        # Connect multiple clients
        sender = MockWebSocketClient("sender", websocket_service)
        await sender.connect()

        clients: List[MockWebSocketClient] = []
        for i in range(5):
            client = MockWebSocketClient(f"client_{i}", websocket_service)
            await client.connect()
            clients.append(client)

        # Broadcast to all except sender. Iterate a snapshot because we
        # await inside the loop while holding the manager's live dict.
        message = {"type": "broadcast", "data": "Message to all except sender"}
        active = websocket_service._manager._active_connections
        for conn_id, ws in list(active.items()):
            if conn_id != "sender":
                await ws.send_json(message)

        # Verify sender did not receive message
        assert len(sender.received_messages) == 0

        # Verify all other clients received message
        for client in clients:
            assert len(client.received_messages) == 1
            assert client.received_messages[0] == message

        # Cleanup
        await sender.disconnect()
        for client in clients:
            await client.disconnect()

    @pytest.mark.asyncio
    async def test_broadcast_filtered_by_metadata(self, websocket_service):
        """Test broadcasting filtered by connection metadata."""
        manager = websocket_service._manager

        # Connect clients with different roles
        admin_clients: List[MockWebSocketClient] = []
        for i in range(3):
            client = MockWebSocketClient(f"admin_{i}", websocket_service)
            await client.connect({"role": "admin", "user_id": f"admin_{i}"})
            admin_clients.append(client)

        user_clients: List[MockWebSocketClient] = []
        for i in range(3):
            client = MockWebSocketClient(f"user_{i}", websocket_service)
            await client.connect({"role": "user", "user_id": f"user_{i}"})
            user_clients.append(client)

        # Broadcast only to admins (snapshot: we await while iterating)
        admin_message = {"type": "admin_only", "data": "Admin notification"}
        for conn_id, metadata in list(manager._connection_metadata.items()):
            if metadata.get("role") == "admin":
                await manager._active_connections[conn_id].send_json(admin_message)

        # Verify only admin clients received message
        for client in admin_clients:
            assert len(client.received_messages) == 1
            assert client.received_messages[0] == admin_message

        for client in user_clients:
            assert len(client.received_messages) == 0

        # Cleanup
        for client in admin_clients + user_clients:
            await client.disconnect()

    @pytest.mark.asyncio
    async def test_room_based_filtering(self, websocket_service):
        """Test combining room membership and metadata filtering."""
        manager = websocket_service._manager

        # Create clients with different metadata in the same room
        premium_client = MockWebSocketClient("premium", websocket_service)
        await premium_client.connect({"subscription": "premium"})
        await premium_client.join_room("downloads")

        free_client = MockWebSocketClient("free", websocket_service)
        await free_client.connect({"subscription": "free"})
        await free_client.join_room("downloads")

        # Send premium-only message to downloads room
        premium_message = {"type": "premium_feature", "data": "Premium notification"}

        # Get clients in downloads room with premium subscription. Copy the
        # live membership set because we await inside the loop.
        room_members = list(manager._rooms.get("downloads", set()))
        for conn_id in room_members:
            metadata = manager._connection_metadata.get(conn_id, {})
            if metadata.get("subscription") == "premium":
                await manager._active_connections[conn_id].send_json(premium_message)

        # Verify only premium client received message
        assert len(premium_client.received_messages) == 1
        assert premium_client.received_messages[0] == premium_message
        assert len(free_client.received_messages) == 0

        # Cleanup
        await premium_client.disconnect()
        await free_client.disconnect()


class TestWebSocketEdgeCases:
    """Test WebSocket edge cases and error conditions."""

    @pytest.mark.asyncio
    async def test_duplicate_connection_ids(self, websocket_service):
        """Test handling duplicate connection IDs (should replace old connection)."""
        client_id = "duplicate_id"

        # First connection
        client1 = MockWebSocketClient(client_id, websocket_service)
        await client1.connect()

        # Send message to first connection
        message1 = {"type": "test", "sequence": 1}
        await websocket_service.broadcast(message1)
        assert len(client1.received_messages) == 1

        # Second connection with same ID (should replace first)
        client2 = MockWebSocketClient(client_id, websocket_service)
        await client2.connect()

        # Only one connection should exist
        assert len(websocket_service._manager._active_connections) == 1

        # Send message to second connection
        message2 = {"type": "test", "sequence": 2}
        await websocket_service.broadcast(message2)

        # Second client should receive message
        assert len(client2.received_messages) == 1
        assert client2.received_messages[0] == message2

        # The replaced connection must not receive anything new
        assert client1.received_messages == [message1]

        await client2.disconnect()

    @pytest.mark.asyncio
    async def test_leave_nonexistent_room(self, websocket_service):
        """Test leaving a room that doesn't exist or client isn't in."""
        client = MockWebSocketClient("test_client", websocket_service)
        await client.connect()

        # Should not raise error
        await client.leave_room("nonexistent_room")

        await client.disconnect()

    @pytest.mark.asyncio
    async def test_send_to_disconnected_client(self, websocket_service):
        """Test sending message to a client that has disconnected."""
        client = MockWebSocketClient("disconnect_test", websocket_service)
        await client.connect()

        # Disconnect
        await client.disconnect()

        # Attempt to broadcast (should not raise error)
        message = {"type": "test", "data": "test"}
        await websocket_service.broadcast(message)

        # Client should not receive message (already disconnected)
        assert len(client.received_messages) == 0