# Source: Aniworld/tests/integration/test_websocket_resilience.py
# (662 lines, 26 KiB, Python)
"""Integration tests for WebSocket resilience and stress testing.
This module tests WebSocket connection resilience, concurrent client handling,
server restart recovery, authentication, message ordering, and broadcast filtering.
"""
import asyncio
import json
import time
from typing import Any, Dict, List, Optional
from unittest.mock import Mock, patch

import pytest
from fastapi import WebSocket
from fastapi.testclient import TestClient

from src.server.services.websocket_service import (
    WebSocketService,
    get_websocket_service,
)
@pytest.fixture
def websocket_service():
    """Provide a brand-new WebSocketService instance to the requesting test."""
    service = WebSocketService()
    return service
@pytest.fixture
def mock_auth_token():
    """Provide a fixed, fake authentication token string for metadata tests."""
    token = "test_auth_token_12345"
    return token
class MockWebSocketClient:
    """In-memory stand-in for a WebSocket client.

    Wraps a ``Mock(spec=WebSocket)`` whose ``accept`` and ``send_json`` are
    replaced by coroutines. Every JSON payload sent to this client is
    appended to ``received_messages`` so tests can assert on what was
    delivered and in which order.
    """

    def __init__(self, client_id: str, service: WebSocketService):
        """Create a disconnected client.

        Args:
            client_id: Connection id registered with the service's manager.
            service: Service whose ``_manager`` this client connects to.
        """
        self.client_id = client_id
        self.service = service
        # Payloads passed to send_json, in the order they were sent.
        self.received_messages: List[Dict[str, Any]] = []
        self.is_connected = False
        self.websocket = Mock(spec=WebSocket)
        # Replace the spec'd methods with coroutine functions so that
        # ``await websocket.send_json(...)`` / ``await websocket.accept()``
        # work; a plain Mock return value is not awaitable.
        self.websocket.send_json = self._mock_send_json
        self.websocket.accept = self._mock_accept

    async def _mock_accept(self):
        """Stand-in for ``WebSocket.accept``: mark the client connected."""
        self.is_connected = True

    async def _mock_send_json(self, data: Dict[str, Any]):
        """Stand-in for ``WebSocket.send_json``: record *data*."""
        self.received_messages.append(data)

    async def connect(self, metadata: Optional[Dict[str, Any]] = None):
        """Register this client with the service's connection manager.

        Args:
            metadata: Per-connection metadata to store with the connection.
                ``None`` (or any falsy value) is replaced by a fresh ``{}``.
        """
        await self.service._manager.connect(
            self.websocket,
            self.client_id,
            metadata or {},
        )
        self.is_connected = True

    async def disconnect(self):
        """Unregister this client from the manager (messages are kept)."""
        await self.service._manager.disconnect(self.client_id)
        self.is_connected = False

    async def join_room(self, room: str):
        """Add this client to *room* via the manager."""
        await self.service._manager.join_room(self.client_id, room)

    async def leave_room(self, room: str):
        """Remove this client from *room* via the manager."""
        await self.service._manager.leave_room(self.client_id, room)

    def clear_messages(self):
        """Forget all recorded messages."""
        self.received_messages.clear()
class TestWebSocketConcurrentClients:
    """Test WebSocket handling of multiple concurrent clients."""

    @staticmethod
    async def _connect_clients_in_rooms(
        service, prefix: str, count: int, rooms: List[str]
    ) -> "List[MockWebSocketClient]":
        """Connect ``count`` clients named ``{prefix}_{i}``, each joining ``rooms``.

        Clients are connected one at a time, and each one joins its rooms
        (in the given order) before the next client is created.
        """
        clients = []
        for i in range(count):
            client = MockWebSocketClient(f"{prefix}_{i}", service)
            await client.connect()
            for room in rooms:
                await client.join_room(room)
            clients.append(client)
        return clients

    @pytest.mark.asyncio
    async def test_multiple_concurrent_connections(self, websocket_service):
        """Test handling 100+ concurrent WebSocket clients."""
        num_clients = 100
        clients: List[MockWebSocketClient] = [
            MockWebSocketClient(f"client_{i}", websocket_service)
            for i in range(num_clients)
        ]

        # Schedule every connect() at once with asyncio.gather instead of
        # awaiting them one by one, so this really exercises concurrent
        # connection handling as the docstring claims.
        await asyncio.gather(
            *(
                client.connect({"user_id": f"user_{i}"})
                for i, client in enumerate(clients)
            )
        )

        # Verify all clients are connected
        assert len(websocket_service._manager._active_connections) == num_clients

        # Broadcast a message to all clients
        test_message = {
            "type": "test_broadcast",
            "timestamp": time.time(),
            "message": "Test broadcast to all clients",
        }
        await websocket_service.broadcast(test_message)

        # Verify every client received exactly that one message
        for client in clients:
            assert client.received_messages == [test_message]

        # Disconnect all clients concurrently as well
        await asyncio.gather(*(client.disconnect() for client in clients))
        assert len(websocket_service._manager._active_connections) == 0

    @pytest.mark.asyncio
    async def test_concurrent_room_broadcasts(self, websocket_service):
        """Test broadcasting to specific rooms with concurrent clients."""
        room_a_clients = await self._connect_clients_in_rooms(
            websocket_service, "room_a", 10, ["room_a"]
        )
        room_b_clients = await self._connect_clients_in_rooms(
            websocket_service, "room_b", 10, ["room_b"]
        )
        room_both_clients = await self._connect_clients_in_rooms(
            websocket_service, "room_both", 5, ["room_a", "room_b"]
        )
        all_clients = room_a_clients + room_b_clients + room_both_clients

        # Broadcast to room_a: room_a and room_both receive, room_b does not
        message_a = {"type": "room_a_message", "data": "Message for room A"}
        await websocket_service._manager.broadcast_to_room(message_a, "room_a")
        for client in room_a_clients + room_both_clients:
            assert client.received_messages == [message_a]
        for client in room_b_clients:
            assert len(client.received_messages) == 0

        for client in all_clients:
            client.clear_messages()

        # Broadcast to room_b: room_b and room_both receive, room_a does not
        message_b = {"type": "room_b_message", "data": "Message for room B"}
        await websocket_service._manager.broadcast_to_room(message_b, "room_b")
        for client in room_b_clients + room_both_clients:
            assert client.received_messages == [message_b]
        for client in room_a_clients:
            assert len(client.received_messages) == 0

        # Cleanup
        for client in all_clients:
            await client.disconnect()

    @pytest.mark.asyncio
    async def test_rapid_connect_disconnect(self, websocket_service):
        """Test rapid connection and disconnection cycles."""
        client_id = "rapid_test_client"

        # Perform 50 rapid connect/disconnect cycles
        for i in range(50):
            client = MockWebSocketClient(f"{client_id}_{i}", websocket_service)
            await client.connect()
            assert client.is_connected
            await client.disconnect()
            assert not client.is_connected

        # Verify no stale connections remain
        assert len(websocket_service._manager._active_connections) == 0
        assert len(websocket_service._manager._connection_metadata) == 0

    @pytest.mark.asyncio
    async def test_stress_message_rate(self, websocket_service):
        """Test high-frequency message broadcasting."""
        num_clients = 20
        num_messages = 100
        clients: List[MockWebSocketClient] = []

        for i in range(num_clients):
            client = MockWebSocketClient(f"stress_client_{i}", websocket_service)
            await client.connect()
            clients.append(client)

        # Send the messages back-to-back
        for i in range(num_messages):
            message = {
                "type": "stress_test",
                "sequence": i,
                "timestamp": time.time(),
            }
            await websocket_service.broadcast(message)

        # Every client got every message, in send order
        expected_sequence = list(range(num_messages))
        for client in clients:
            assert len(client.received_messages) == num_messages
            received = [msg["sequence"] for msg in client.received_messages]
            assert received == expected_sequence

        # Cleanup
        for client in clients:
            await client.disconnect()
class TestWebSocketConnectionRecovery:
    """Test WebSocket connection recovery after failures."""

    @pytest.mark.asyncio
    async def test_connection_recovery_after_disconnect(self, websocket_service):
        """Test client can reconnect after unexpected disconnect."""
        client_id = "recovery_test_client"

        # Initial connection
        client1 = MockWebSocketClient(client_id, websocket_service)
        await client1.connect({"user_id": "test_user"})
        await client1.join_room("downloads")

        # Drop the first connection (through the manager's disconnect path)
        await client1.disconnect()
        assert not client1.is_connected

        # Reconnect with the same client_id on a new socket
        client2 = MockWebSocketClient(client_id, websocket_service)
        await client2.connect({"user_id": "test_user"})
        await client2.join_room("downloads")

        # Verify the new connection works
        message = {"type": "test", "data": "recovery test"}
        await websocket_service._manager.broadcast_to_room(message, "downloads")
        assert client2.received_messages == [message]
        # The stale socket from the first connection must not be written to.
        assert client1.received_messages == []

        await client2.disconnect()

    @pytest.mark.asyncio
    async def test_room_rejoin_after_reconnection(self, websocket_service):
        """Test client can rejoin rooms after reconnection."""
        client_id = "rejoin_test_client"
        rooms = ["downloads", "progress", "updates"]

        # Connect and join multiple rooms
        client1 = MockWebSocketClient(client_id, websocket_service)
        await client1.connect()
        for room in rooms:
            await client1.join_room(room)

        for room in rooms:
            assert client_id in websocket_service._manager._rooms[room]

        # Disconnect: the client must be removed from every room
        await client1.disconnect()
        for room in rooms:
            assert client_id not in websocket_service._manager._rooms.get(room, set())

        # Reconnect and rejoin rooms
        client2 = MockWebSocketClient(client_id, websocket_service)
        await client2.connect()
        for room in rooms:
            await client2.join_room(room)

        for room in rooms:
            assert client_id in websocket_service._manager._rooms[room]

        await client2.disconnect()

    @pytest.mark.asyncio
    async def test_message_delivery_after_reconnection(self, websocket_service):
        """Test messages are delivered correctly after reconnection."""
        client_id = "delivery_test_client"

        # Connect, receive a message, disconnect
        client1 = MockWebSocketClient(client_id, websocket_service)
        await client1.connect()
        message1 = {"type": "test", "sequence": 1}
        await websocket_service.broadcast(message1)
        assert len(client1.received_messages) == 1
        await client1.disconnect()

        # Reconnect and verify new messages are received
        client2 = MockWebSocketClient(client_id, websocket_service)
        await client2.connect()
        message2 = {"type": "test", "sequence": 2}
        await websocket_service.broadcast(message2)

        # The new socket sees only message2...
        assert client2.received_messages == [message2]
        # ...and the old, disconnected socket did not receive it. client2 is a
        # fresh mock, so this is the assertion that can catch a stale entry.
        assert client1.received_messages == [message1]

        await client2.disconnect()
class TestWebSocketAuthentication:
    """Test WebSocket authentication and token handling."""

    @pytest.mark.asyncio
    async def test_connection_with_authentication_metadata(
        self, websocket_service, mock_auth_token
    ):
        """Test WebSocket connection with authentication token in metadata."""
        auth_client = MockWebSocketClient("auth_client", websocket_service)
        await auth_client.connect(
            {
                "user_id": "test_user",
                "auth_token": mock_auth_token,
                "session_id": "session_123",
            }
        )

        # The manager keeps the metadata keyed by connection id
        stored = websocket_service._manager._connection_metadata["auth_client"]
        expected = {
            "user_id": "test_user",
            "auth_token": mock_auth_token,
            "session_id": "session_123",
        }
        for key, value in expected.items():
            assert stored[key] == value

        await auth_client.disconnect()

    @pytest.mark.asyncio
    async def test_broadcast_to_specific_user(self, websocket_service):
        """Test broadcasting to specific user using metadata filtering."""
        first = MockWebSocketClient("client1", websocket_service)
        await first.connect({"user_id": "user_1"})
        second = MockWebSocketClient("client2", websocket_service)
        await second.connect({"user_id": "user_2"})
        # Same user as ``first`` but a separate connection
        third = MockWebSocketClient("client3", websocket_service)
        await third.connect({"user_id": "user_1"})

        payload = {"type": "user_specific", "data": "Message for user_1"}

        # Pick the connections belonging to user_1, then send to each one
        manager = websocket_service._manager
        targets = [
            conn_id
            for conn_id, meta in manager._connection_metadata.items()
            if meta.get("user_id") == "user_1"
        ]
        for conn_id in targets:
            await manager._active_connections[conn_id].send_json(payload)

        # Only the user_1 connections got it
        for recipient in (first, third):
            assert len(recipient.received_messages) == 1
            assert recipient.received_messages[0] == payload
        assert len(second.received_messages) == 0

        for each in (first, second, third):
            await each.disconnect()

    @pytest.mark.asyncio
    async def test_token_refresh_in_metadata(self, websocket_service):
        """Test updating authentication token in connection metadata."""
        refresh_client = MockWebSocketClient("token_refresh_client", websocket_service)
        stale_token = "old_token_12345"
        fresh_token = "new_token_67890"

        await refresh_client.connect({"user_id": "test_user", "auth_token": stale_token})

        registry = websocket_service._manager._connection_metadata
        assert registry["token_refresh_client"]["auth_token"] == stale_token

        # Simulate a token refresh by mutating the stored metadata dict in
        # place; the re-lookup below then observes the new value.
        registry["token_refresh_client"]["auth_token"] = fresh_token
        assert registry["token_refresh_client"]["auth_token"] == fresh_token

        await refresh_client.disconnect()
class TestWebSocketMessageOrdering:
    """Test WebSocket message ordering guarantees."""

    @pytest.mark.asyncio
    async def test_message_order_preservation(self, websocket_service):
        """Test messages are received in the order they are sent."""
        client = MockWebSocketClient("order_test_client", websocket_service)
        await client.connect()

        num_messages = 50
        for i in range(num_messages):
            message = {
                "type": "sequence_test",
                "sequence": i,
                "timestamp": time.time(),
            }
            await websocket_service.broadcast(message)

        # Verify all messages received in order
        assert len(client.received_messages) == num_messages
        received = [msg["sequence"] for msg in client.received_messages]
        assert received == list(range(num_messages))

        await client.disconnect()

    @pytest.mark.asyncio
    async def test_concurrent_broadcast_order(self, websocket_service):
        """Test message ordering with concurrent broadcasts to different rooms."""
        num_rounds = 10
        expected_sequence = list(range(num_rounds))

        room1_client = MockWebSocketClient("room1_client", websocket_service)
        await room1_client.connect()
        await room1_client.join_room("room1")

        room2_client = MockWebSocketClient("room2_client", websocket_service)
        await room2_client.connect()
        await room2_client.join_room("room2")

        both_rooms_client = MockWebSocketClient("both_client", websocket_service)
        await both_rooms_client.connect()
        await both_rooms_client.join_room("room1")
        await both_rooms_client.join_room("room2")

        # Send interleaved messages to both rooms
        for i in range(num_rounds):
            message1 = {"type": "room1_msg", "sequence": i}
            await websocket_service._manager.broadcast_to_room(message1, "room1")
            message2 = {"type": "room2_msg", "sequence": i}
            await websocket_service._manager.broadcast_to_room(message2, "room2")

        # Single-room clients: only their room's messages, in order
        for client, msg_type in (
            (room1_client, "room1_msg"),
            (room2_client, "room2_msg"),
        ):
            assert len(client.received_messages) == num_rounds
            assert all(m["type"] == msg_type for m in client.received_messages)
            received = [m["sequence"] for m in client.received_messages]
            assert received == expected_sequence

        # The two-room client may see the rooms interleaved, but each room's
        # own stream must still arrive complete and in send order.
        both_msgs = both_rooms_client.received_messages
        assert len(both_msgs) == 2 * num_rounds
        room1_seq = [m["sequence"] for m in both_msgs if m["type"] == "room1_msg"]
        room2_seq = [m["sequence"] for m in both_msgs if m["type"] == "room2_msg"]
        assert room1_seq == expected_sequence
        assert room2_seq == expected_sequence

        # Cleanup
        await room1_client.disconnect()
        await room2_client.disconnect()
        await both_rooms_client.disconnect()
class TestWebSocketBroadcastFiltering:
    """Test WebSocket broadcast filtering to specific clients."""

    @pytest.mark.asyncio
    async def test_broadcast_to_all_except_sender(self, websocket_service):
        """Test broadcasting to all clients except the sender."""
        origin = MockWebSocketClient("sender", websocket_service)
        await origin.connect()

        peers = []
        for idx in range(5):
            peer = MockWebSocketClient(f"client_{idx}", websocket_service)
            await peer.connect()
            peers.append(peer)

        payload = {"type": "broadcast", "data": "Message to all except sender"}
        for conn_id, socket in websocket_service._manager._active_connections.items():
            if conn_id == "sender":
                continue  # skip the originator
            await socket.send_json(payload)

        # The sender is skipped, everyone else gets exactly one copy
        assert len(origin.received_messages) == 0
        for peer in peers:
            assert len(peer.received_messages) == 1
            assert peer.received_messages[0] == payload

        await origin.disconnect()
        for peer in peers:
            await peer.disconnect()

    @pytest.mark.asyncio
    async def test_broadcast_filtered_by_metadata(self, websocket_service):
        """Test broadcasting filtered by connection metadata."""
        admins = []
        for idx in range(3):
            admin = MockWebSocketClient(f"admin_{idx}", websocket_service)
            await admin.connect({"role": "admin", "user_id": f"admin_{idx}"})
            admins.append(admin)

        regulars = []
        for idx in range(3):
            regular = MockWebSocketClient(f"user_{idx}", websocket_service)
            await regular.connect({"role": "user", "user_id": f"user_{idx}"})
            regulars.append(regular)

        payload = {"type": "admin_only", "data": "Admin notification"}
        manager = websocket_service._manager
        for conn_id, meta in manager._connection_metadata.items():
            if meta.get("role") != "admin":
                continue
            await manager._active_connections[conn_id].send_json(payload)

        # Admins get the notification; regular users get nothing
        for admin in admins:
            assert len(admin.received_messages) == 1
            assert admin.received_messages[0] == payload
        for regular in regulars:
            assert len(regular.received_messages) == 0

        for participant in [*admins, *regulars]:
            await participant.disconnect()

    @pytest.mark.asyncio
    async def test_room_based_filtering(self, websocket_service):
        """Test combining room membership and metadata filtering."""
        premium = MockWebSocketClient("premium", websocket_service)
        await premium.connect({"subscription": "premium"})
        await premium.join_room("downloads")

        free = MockWebSocketClient("free", websocket_service)
        await free.connect({"subscription": "free"})
        await free.join_room("downloads")

        payload = {"type": "premium_feature", "data": "Premium notification"}

        # Walk the downloads room and send only to premium subscribers
        manager = websocket_service._manager
        for member_id in manager._rooms.get("downloads", set()):
            member_meta = manager._connection_metadata.get(member_id, {})
            if member_meta.get("subscription") != "premium":
                continue
            await manager._active_connections[member_id].send_json(payload)

        assert len(premium.received_messages) == 1
        assert premium.received_messages[0] == payload
        assert len(free.received_messages) == 0

        await premium.disconnect()
        await free.disconnect()
class TestWebSocketEdgeCases:
    """Test WebSocket edge cases and error conditions."""

    @pytest.mark.asyncio
    async def test_duplicate_connection_ids(self, websocket_service):
        """Test handling duplicate connection IDs (should replace old connection)."""
        client_id = "duplicate_id"

        # First connection
        client1 = MockWebSocketClient(client_id, websocket_service)
        await client1.connect()

        message1 = {"type": "test", "sequence": 1}
        await websocket_service.broadcast(message1)
        assert len(client1.received_messages) == 1

        # Second connection with same ID (should replace first)
        client2 = MockWebSocketClient(client_id, websocket_service)
        await client2.connect()

        # Only one connection should exist
        assert len(websocket_service._manager._active_connections) == 1

        message2 = {"type": "test", "sequence": 2}
        await websocket_service.broadcast(message2)

        # The replacement socket receives the new message...
        assert client2.received_messages == [message2]
        # ...and the replaced socket does not: it still has only message1.
        assert client1.received_messages == [message1]

        await client2.disconnect()

    @pytest.mark.asyncio
    async def test_leave_nonexistent_room(self, websocket_service):
        """Test leaving a room that doesn't exist or client isn't in."""
        client = MockWebSocketClient("test_client", websocket_service)
        await client.connect()

        # Should not raise error
        await client.leave_room("nonexistent_room")

        await client.disconnect()

    @pytest.mark.asyncio
    async def test_send_to_disconnected_client(self, websocket_service):
        """Test sending message to a client that has disconnected."""
        client = MockWebSocketClient("disconnect_test", websocket_service)
        await client.connect()
        await client.disconnect()

        # Attempt to broadcast (should not raise error)
        message = {"type": "test", "data": "test"}
        await websocket_service.broadcast(message)

        # Client should not receive message (already disconnected)
        assert len(client.received_messages) == 0