Add WebSocket reconnection tests (68 unit + 18 integration)
This commit is contained in:
658
tests/integration/test_websocket_resilience.py
Normal file
658
tests/integration/test_websocket_resilience.py
Normal file
@@ -0,0 +1,658 @@
|
||||
"""Integration tests for WebSocket resilience and stress testing.
|
||||
|
||||
This module tests WebSocket connection resilience, concurrent client handling,
|
||||
server restart recovery, authentication, message ordering, and broadcast filtering.
|
||||
"""
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
from typing import List, Dict, Any
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi import WebSocket
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from src.server.services.websocket_service import WebSocketService, get_websocket_service
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def websocket_service():
|
||||
"""Create a WebSocketService instance for testing."""
|
||||
return WebSocketService()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_auth_token():
|
||||
"""Create a mock authentication token for testing."""
|
||||
return "test_auth_token_12345"
|
||||
|
||||
|
||||
class MockWebSocketClient:
|
||||
"""Mock WebSocket client for testing."""
|
||||
|
||||
def __init__(self, client_id: str, service: WebSocketService):
|
||||
self.client_id = client_id
|
||||
self.service = service
|
||||
self.received_messages: List[Dict[str, Any]] = []
|
||||
self.is_connected = False
|
||||
self.websocket = Mock(spec=WebSocket)
|
||||
self.websocket.send_json = self._mock_send_json
|
||||
self.websocket.accept = self._mock_accept
|
||||
|
||||
async def _mock_accept(self):
|
||||
"""Mock WebSocket accept."""
|
||||
self.is_connected = True
|
||||
|
||||
async def _mock_send_json(self, data: Dict[str, Any]):
|
||||
"""Mock WebSocket send_json to capture messages."""
|
||||
self.received_messages.append(data)
|
||||
|
||||
async def connect(self, metadata: Dict[str, Any] = None):
|
||||
"""Connect the mock client to the service."""
|
||||
await self.service._manager.connect(
|
||||
self.websocket,
|
||||
self.client_id,
|
||||
metadata or {}
|
||||
)
|
||||
self.is_connected = True
|
||||
|
||||
async def disconnect(self):
|
||||
"""Disconnect the mock client from the service."""
|
||||
await self.service._manager.disconnect(self.client_id)
|
||||
self.is_connected = False
|
||||
|
||||
async def join_room(self, room: str):
|
||||
"""Join a room."""
|
||||
await self.service._manager.join_room(self.client_id, room)
|
||||
|
||||
async def leave_room(self, room: str):
|
||||
"""Leave a room."""
|
||||
await self.service._manager.leave_room(self.client_id, room)
|
||||
|
||||
def clear_messages(self):
|
||||
"""Clear received messages."""
|
||||
self.received_messages.clear()
|
||||
|
||||
|
||||
class TestWebSocketConcurrentClients:
|
||||
"""Test WebSocket handling of multiple concurrent clients."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_concurrent_connections(self, websocket_service):
|
||||
"""Test handling 100+ concurrent WebSocket clients."""
|
||||
num_clients = 100
|
||||
clients: List[MockWebSocketClient] = []
|
||||
|
||||
# Connect 100 clients
|
||||
for i in range(num_clients):
|
||||
client = MockWebSocketClient(f"client_{i}", websocket_service)
|
||||
await client.connect({"user_id": f"user_{i}"})
|
||||
clients.append(client)
|
||||
|
||||
# Verify all clients are connected
|
||||
assert len(websocket_service._manager._active_connections) == num_clients
|
||||
|
||||
# Broadcast a message to all clients
|
||||
test_message = {
|
||||
"type": "test_broadcast",
|
||||
"timestamp": time.time(),
|
||||
"message": "Test broadcast to all clients"
|
||||
}
|
||||
await websocket_service.broadcast(test_message)
|
||||
|
||||
# Verify all clients received the message
|
||||
for client in clients:
|
||||
assert len(client.received_messages) == 1
|
||||
assert client.received_messages[0] == test_message
|
||||
|
||||
# Disconnect all clients
|
||||
for client in clients:
|
||||
await client.disconnect()
|
||||
|
||||
assert len(websocket_service._manager._active_connections) == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_room_broadcasts(self, websocket_service):
|
||||
"""Test broadcasting to specific rooms with concurrent clients."""
|
||||
# Create clients in different rooms
|
||||
room_a_clients = []
|
||||
room_b_clients = []
|
||||
room_both_clients = []
|
||||
|
||||
for i in range(10):
|
||||
client = MockWebSocketClient(f"room_a_{i}", websocket_service)
|
||||
await client.connect()
|
||||
await client.join_room("room_a")
|
||||
room_a_clients.append(client)
|
||||
|
||||
for i in range(10):
|
||||
client = MockWebSocketClient(f"room_b_{i}", websocket_service)
|
||||
await client.connect()
|
||||
await client.join_room("room_b")
|
||||
room_b_clients.append(client)
|
||||
|
||||
for i in range(5):
|
||||
client = MockWebSocketClient(f"room_both_{i}", websocket_service)
|
||||
await client.connect()
|
||||
await client.join_room("room_a")
|
||||
await client.join_room("room_b")
|
||||
room_both_clients.append(client)
|
||||
|
||||
# Broadcast to room_a
|
||||
message_a = {"type": "room_a_message", "data": "Message for room A"}
|
||||
await websocket_service._manager.broadcast_to_room(message_a, "room_a")
|
||||
|
||||
# Verify room_a and room_both clients received, room_b did not
|
||||
for client in room_a_clients:
|
||||
assert len(client.received_messages) == 1
|
||||
assert client.received_messages[0] == message_a
|
||||
|
||||
for client in room_both_clients:
|
||||
assert len(client.received_messages) == 1
|
||||
assert client.received_messages[0] == message_a
|
||||
|
||||
for client in room_b_clients:
|
||||
assert len(client.received_messages) == 0
|
||||
|
||||
# Clear messages
|
||||
for client in room_a_clients + room_b_clients + room_both_clients:
|
||||
client.clear_messages()
|
||||
|
||||
# Broadcast to room_b
|
||||
message_b = {"type": "room_b_message", "data": "Message for room B"}
|
||||
await websocket_service._manager.broadcast_to_room(message_b, "room_b")
|
||||
|
||||
# Verify room_b and room_both clients received, room_a did not
|
||||
for client in room_b_clients:
|
||||
assert len(client.received_messages) == 1
|
||||
assert client.received_messages[0] == message_b
|
||||
|
||||
for client in room_both_clients:
|
||||
assert len(client.received_messages) == 1
|
||||
assert client.received_messages[0] == message_b
|
||||
|
||||
for client in room_a_clients:
|
||||
assert len(client.received_messages) == 0
|
||||
|
||||
# Cleanup
|
||||
for client in room_a_clients + room_b_clients + room_both_clients:
|
||||
await client.disconnect()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rapid_connect_disconnect(self, websocket_service):
|
||||
"""Test rapid connection and disconnection cycles."""
|
||||
client_id = "rapid_test_client"
|
||||
|
||||
# Perform 50 rapid connect/disconnect cycles
|
||||
for i in range(50):
|
||||
client = MockWebSocketClient(f"{client_id}_{i}", websocket_service)
|
||||
await client.connect()
|
||||
assert client.is_connected
|
||||
await client.disconnect()
|
||||
assert not client.is_connected
|
||||
|
||||
# Verify no stale connections remain
|
||||
assert len(websocket_service._manager._active_connections) == 0
|
||||
assert len(websocket_service._manager._connection_metadata) == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stress_message_rate(self, websocket_service):
|
||||
"""Test high-frequency message broadcasting."""
|
||||
num_clients = 20
|
||||
num_messages = 100
|
||||
clients: List[MockWebSocketClient] = []
|
||||
|
||||
# Connect clients
|
||||
for i in range(num_clients):
|
||||
client = MockWebSocketClient(f"stress_client_{i}", websocket_service)
|
||||
await client.connect()
|
||||
clients.append(client)
|
||||
|
||||
# Send 100 messages rapidly
|
||||
for i in range(num_messages):
|
||||
message = {
|
||||
"type": "stress_test",
|
||||
"sequence": i,
|
||||
"timestamp": time.time()
|
||||
}
|
||||
await websocket_service.broadcast(message)
|
||||
|
||||
# Verify all clients received all messages
|
||||
for client in clients:
|
||||
assert len(client.received_messages) == num_messages
|
||||
# Verify messages are in order
|
||||
for i in range(num_messages):
|
||||
assert client.received_messages[i]["sequence"] == i
|
||||
|
||||
# Cleanup
|
||||
for client in clients:
|
||||
await client.disconnect()
|
||||
|
||||
|
||||
class TestWebSocketConnectionRecovery:
|
||||
"""Test WebSocket connection recovery after failures."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_connection_recovery_after_disconnect(self, websocket_service):
|
||||
"""Test client can reconnect after unexpected disconnect."""
|
||||
client_id = "recovery_test_client"
|
||||
|
||||
# Initial connection
|
||||
client1 = MockWebSocketClient(client_id, websocket_service)
|
||||
await client1.connect({"user_id": "test_user"})
|
||||
await client1.join_room("downloads")
|
||||
|
||||
# Simulate unexpected disconnect
|
||||
await client1.disconnect()
|
||||
assert not client1.is_connected
|
||||
|
||||
# Reconnect with same client_id
|
||||
client2 = MockWebSocketClient(client_id, websocket_service)
|
||||
await client2.connect({"user_id": "test_user"})
|
||||
await client2.join_room("downloads")
|
||||
|
||||
# Verify new connection works
|
||||
message = {"type": "test", "data": "recovery test"}
|
||||
await websocket_service._manager.broadcast_to_room(message, "downloads")
|
||||
|
||||
assert len(client2.received_messages) == 1
|
||||
assert client2.received_messages[0] == message
|
||||
|
||||
await client2.disconnect()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_room_rejoin_after_reconnection(self, websocket_service):
|
||||
"""Test client can rejoin rooms after reconnection."""
|
||||
client_id = "rejoin_test_client"
|
||||
|
||||
# Connect and join multiple rooms
|
||||
client1 = MockWebSocketClient(client_id, websocket_service)
|
||||
await client1.connect()
|
||||
await client1.join_room("downloads")
|
||||
await client1.join_room("progress")
|
||||
await client1.join_room("updates")
|
||||
|
||||
# Verify client is in all rooms
|
||||
assert client_id in websocket_service._manager._rooms["downloads"]
|
||||
assert client_id in websocket_service._manager._rooms["progress"]
|
||||
assert client_id in websocket_service._manager._rooms["updates"]
|
||||
|
||||
# Disconnect
|
||||
await client1.disconnect()
|
||||
|
||||
# Rooms should be empty after disconnect
|
||||
for room in ["downloads", "progress", "updates"]:
|
||||
assert client_id not in websocket_service._manager._rooms.get(room, set())
|
||||
|
||||
# Reconnect and rejoin rooms
|
||||
client2 = MockWebSocketClient(client_id, websocket_service)
|
||||
await client2.connect()
|
||||
await client2.join_room("downloads")
|
||||
await client2.join_room("progress")
|
||||
await client2.join_room("updates")
|
||||
|
||||
# Verify client is in all rooms again
|
||||
assert client_id in websocket_service._manager._rooms["downloads"]
|
||||
assert client_id in websocket_service._manager._rooms["progress"]
|
||||
assert client_id in websocket_service._manager._rooms["updates"]
|
||||
|
||||
await client2.disconnect()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_message_delivery_after_reconnection(self, websocket_service):
|
||||
"""Test messages are delivered correctly after reconnection."""
|
||||
client_id = "delivery_test_client"
|
||||
|
||||
# Connect, receive a message, disconnect
|
||||
client1 = MockWebSocketClient(client_id, websocket_service)
|
||||
await client1.connect()
|
||||
|
||||
message1 = {"type": "test", "sequence": 1}
|
||||
await websocket_service.broadcast(message1)
|
||||
assert len(client1.received_messages) == 1
|
||||
|
||||
await client1.disconnect()
|
||||
|
||||
# Reconnect and verify new messages are received
|
||||
client2 = MockWebSocketClient(client_id, websocket_service)
|
||||
await client2.connect()
|
||||
|
||||
message2 = {"type": "test", "sequence": 2}
|
||||
await websocket_service.broadcast(message2)
|
||||
|
||||
# Should only receive message2 (not message1 from before disconnect)
|
||||
assert len(client2.received_messages) == 1
|
||||
assert client2.received_messages[0] == message2
|
||||
|
||||
await client2.disconnect()
|
||||
|
||||
|
||||
class TestWebSocketAuthentication:
|
||||
"""Test WebSocket authentication and token handling."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_connection_with_authentication_metadata(
|
||||
self, websocket_service, mock_auth_token
|
||||
):
|
||||
"""Test WebSocket connection with authentication token in metadata."""
|
||||
client = MockWebSocketClient("auth_client", websocket_service)
|
||||
metadata = {
|
||||
"user_id": "test_user",
|
||||
"auth_token": mock_auth_token,
|
||||
"session_id": "session_123"
|
||||
}
|
||||
|
||||
await client.connect(metadata)
|
||||
|
||||
# Verify metadata is stored
|
||||
stored_metadata = websocket_service._manager._connection_metadata["auth_client"]
|
||||
assert stored_metadata["user_id"] == "test_user"
|
||||
assert stored_metadata["auth_token"] == mock_auth_token
|
||||
assert stored_metadata["session_id"] == "session_123"
|
||||
|
||||
await client.disconnect()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_broadcast_to_specific_user(self, websocket_service):
|
||||
"""Test broadcasting to specific user using metadata filtering."""
|
||||
# Connect multiple clients with different user IDs
|
||||
client1 = MockWebSocketClient("client1", websocket_service)
|
||||
await client1.connect({"user_id": "user_1"})
|
||||
|
||||
client2 = MockWebSocketClient("client2", websocket_service)
|
||||
await client2.connect({"user_id": "user_2"})
|
||||
|
||||
client3 = MockWebSocketClient("client3", websocket_service)
|
||||
await client3.connect({"user_id": "user_1"}) # Same user, different connection
|
||||
|
||||
# Broadcast to specific user
|
||||
message = {"type": "user_specific", "data": "Message for user_1"}
|
||||
|
||||
# Filter connections by user_id and send
|
||||
for conn_id, metadata in websocket_service._manager._connection_metadata.items():
|
||||
if metadata.get("user_id") == "user_1":
|
||||
ws = websocket_service._manager._active_connections[conn_id]
|
||||
await ws.send_json(message)
|
||||
|
||||
# Verify only user_1 clients received the message
|
||||
assert len(client1.received_messages) == 1
|
||||
assert client1.received_messages[0] == message
|
||||
|
||||
assert len(client3.received_messages) == 1
|
||||
assert client3.received_messages[0] == message
|
||||
|
||||
assert len(client2.received_messages) == 0
|
||||
|
||||
# Cleanup
|
||||
await client1.disconnect()
|
||||
await client2.disconnect()
|
||||
await client3.disconnect()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_token_refresh_in_metadata(self, websocket_service):
|
||||
"""Test updating authentication token in connection metadata."""
|
||||
client = MockWebSocketClient("token_refresh_client", websocket_service)
|
||||
old_token = "old_token_12345"
|
||||
new_token = "new_token_67890"
|
||||
|
||||
# Connect with old token
|
||||
await client.connect({"user_id": "test_user", "auth_token": old_token})
|
||||
|
||||
# Verify old token is stored
|
||||
metadata = websocket_service._manager._connection_metadata["token_refresh_client"]
|
||||
assert metadata["auth_token"] == old_token
|
||||
|
||||
# Update token (simulating token refresh)
|
||||
metadata["auth_token"] = new_token
|
||||
|
||||
# Verify token is updated
|
||||
updated_metadata = websocket_service._manager._connection_metadata["token_refresh_client"]
|
||||
assert updated_metadata["auth_token"] == new_token
|
||||
|
||||
await client.disconnect()
|
||||
|
||||
|
||||
class TestWebSocketMessageOrdering:
|
||||
"""Test WebSocket message ordering guarantees."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_message_order_preservation(self, websocket_service):
|
||||
"""Test messages are received in the order they are sent."""
|
||||
client = MockWebSocketClient("order_test_client", websocket_service)
|
||||
await client.connect()
|
||||
|
||||
# Send 50 messages in sequence
|
||||
num_messages = 50
|
||||
for i in range(num_messages):
|
||||
message = {
|
||||
"type": "sequence_test",
|
||||
"sequence": i,
|
||||
"timestamp": time.time()
|
||||
}
|
||||
await websocket_service.broadcast(message)
|
||||
|
||||
# Verify all messages received in order
|
||||
assert len(client.received_messages) == num_messages
|
||||
for i in range(num_messages):
|
||||
assert client.received_messages[i]["sequence"] == i
|
||||
|
||||
await client.disconnect()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_broadcast_order(self, websocket_service):
|
||||
"""Test message ordering with concurrent broadcasts to different rooms."""
|
||||
# Create clients in two rooms
|
||||
room1_client = MockWebSocketClient("room1_client", websocket_service)
|
||||
await room1_client.connect()
|
||||
await room1_client.join_room("room1")
|
||||
|
||||
room2_client = MockWebSocketClient("room2_client", websocket_service)
|
||||
await room2_client.connect()
|
||||
await room2_client.join_room("room2")
|
||||
|
||||
both_rooms_client = MockWebSocketClient("both_client", websocket_service)
|
||||
await both_rooms_client.connect()
|
||||
await both_rooms_client.join_room("room1")
|
||||
await both_rooms_client.join_room("room2")
|
||||
|
||||
# Send interleaved messages to both rooms
|
||||
for i in range(10):
|
||||
message1 = {"type": "room1_msg", "sequence": i}
|
||||
await websocket_service._manager.broadcast_to_room(message1, "room1")
|
||||
|
||||
message2 = {"type": "room2_msg", "sequence": i}
|
||||
await websocket_service._manager.broadcast_to_room(message2, "room2")
|
||||
|
||||
# Verify room1_client received only room1 messages in order
|
||||
assert len(room1_client.received_messages) == 10
|
||||
for i in range(10):
|
||||
assert room1_client.received_messages[i]["type"] == "room1_msg"
|
||||
assert room1_client.received_messages[i]["sequence"] == i
|
||||
|
||||
# Verify room2_client received only room2 messages in order
|
||||
assert len(room2_client.received_messages) == 10
|
||||
for i in range(10):
|
||||
assert room2_client.received_messages[i]["type"] == "room2_msg"
|
||||
assert room2_client.received_messages[i]["sequence"] == i
|
||||
|
||||
# Verify both_rooms_client received all messages (may be interleaved)
|
||||
assert len(both_rooms_client.received_messages) == 20
|
||||
room1_msgs = [msg for msg in both_rooms_client.received_messages if msg["type"] == "room1_msg"]
|
||||
room2_msgs = [msg for msg in both_rooms_client.received_messages if msg["type"] == "room2_msg"]
|
||||
assert len(room1_msgs) == 10
|
||||
assert len(room2_msgs) == 10
|
||||
|
||||
# Cleanup
|
||||
await room1_client.disconnect()
|
||||
await room2_client.disconnect()
|
||||
await both_rooms_client.disconnect()
|
||||
|
||||
|
||||
class TestWebSocketBroadcastFiltering:
|
||||
"""Test WebSocket broadcast filtering to specific clients."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_broadcast_to_all_except_sender(self, websocket_service):
|
||||
"""Test broadcasting to all clients except the sender."""
|
||||
# Connect multiple clients
|
||||
sender = MockWebSocketClient("sender", websocket_service)
|
||||
await sender.connect()
|
||||
|
||||
clients = []
|
||||
for i in range(5):
|
||||
client = MockWebSocketClient(f"client_{i}", websocket_service)
|
||||
await client.connect()
|
||||
clients.append(client)
|
||||
|
||||
# Broadcast to all except sender
|
||||
message = {"type": "broadcast", "data": "Message to all except sender"}
|
||||
|
||||
for conn_id in websocket_service._manager._active_connections:
|
||||
if conn_id != "sender":
|
||||
ws = websocket_service._manager._active_connections[conn_id]
|
||||
await ws.send_json(message)
|
||||
|
||||
# Verify sender did not receive message
|
||||
assert len(sender.received_messages) == 0
|
||||
|
||||
# Verify all other clients received message
|
||||
for client in clients:
|
||||
assert len(client.received_messages) == 1
|
||||
assert client.received_messages[0] == message
|
||||
|
||||
# Cleanup
|
||||
await sender.disconnect()
|
||||
for client in clients:
|
||||
await client.disconnect()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_broadcast_filtered_by_metadata(self, websocket_service):
|
||||
"""Test broadcasting filtered by connection metadata."""
|
||||
# Connect clients with different roles
|
||||
admin_clients = []
|
||||
for i in range(3):
|
||||
client = MockWebSocketClient(f"admin_{i}", websocket_service)
|
||||
await client.connect({"role": "admin", "user_id": f"admin_{i}"})
|
||||
admin_clients.append(client)
|
||||
|
||||
user_clients = []
|
||||
for i in range(3):
|
||||
client = MockWebSocketClient(f"user_{i}", websocket_service)
|
||||
await client.connect({"role": "user", "user_id": f"user_{i}"})
|
||||
user_clients.append(client)
|
||||
|
||||
# Broadcast only to admins
|
||||
admin_message = {"type": "admin_only", "data": "Admin notification"}
|
||||
|
||||
for conn_id, metadata in websocket_service._manager._connection_metadata.items():
|
||||
if metadata.get("role") == "admin":
|
||||
ws = websocket_service._manager._active_connections[conn_id]
|
||||
await ws.send_json(admin_message)
|
||||
|
||||
# Verify only admin clients received message
|
||||
for client in admin_clients:
|
||||
assert len(client.received_messages) == 1
|
||||
assert client.received_messages[0] == admin_message
|
||||
|
||||
for client in user_clients:
|
||||
assert len(client.received_messages) == 0
|
||||
|
||||
# Cleanup
|
||||
for client in admin_clients + user_clients:
|
||||
await client.disconnect()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_room_based_filtering(self, websocket_service):
|
||||
"""Test combining room membership and metadata filtering."""
|
||||
# Create clients with different metadata in the same room
|
||||
premium_client = MockWebSocketClient("premium", websocket_service)
|
||||
await premium_client.connect({"subscription": "premium"})
|
||||
await premium_client.join_room("downloads")
|
||||
|
||||
free_client = MockWebSocketClient("free", websocket_service)
|
||||
await free_client.connect({"subscription": "free"})
|
||||
await free_client.join_room("downloads")
|
||||
|
||||
# Send premium-only message to downloads room
|
||||
premium_message = {"type": "premium_feature", "data": "Premium notification"}
|
||||
|
||||
# Get clients in downloads room with premium subscription
|
||||
room_members = websocket_service._manager._rooms.get("downloads", set())
|
||||
for conn_id in room_members:
|
||||
metadata = websocket_service._manager._connection_metadata.get(conn_id, {})
|
||||
if metadata.get("subscription") == "premium":
|
||||
ws = websocket_service._manager._active_connections[conn_id]
|
||||
await ws.send_json(premium_message)
|
||||
|
||||
# Verify only premium client received message
|
||||
assert len(premium_client.received_messages) == 1
|
||||
assert premium_client.received_messages[0] == premium_message
|
||||
|
||||
assert len(free_client.received_messages) == 0
|
||||
|
||||
# Cleanup
|
||||
await premium_client.disconnect()
|
||||
await free_client.disconnect()
|
||||
|
||||
|
||||
class TestWebSocketEdgeCases:
|
||||
"""Test WebSocket edge cases and error conditions."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_duplicate_connection_ids(self, websocket_service):
|
||||
"""Test handling duplicate connection IDs (should replace old connection)."""
|
||||
client_id = "duplicate_id"
|
||||
|
||||
# First connection
|
||||
client1 = MockWebSocketClient(client_id, websocket_service)
|
||||
await client1.connect()
|
||||
|
||||
# Send message to first connection
|
||||
message1 = {"type": "test", "sequence": 1}
|
||||
await websocket_service.broadcast(message1)
|
||||
assert len(client1.received_messages) == 1
|
||||
|
||||
# Second connection with same ID (should replace first)
|
||||
client2 = MockWebSocketClient(client_id, websocket_service)
|
||||
await client2.connect()
|
||||
|
||||
# Only one connection should exist
|
||||
assert len(websocket_service._manager._active_connections) == 1
|
||||
|
||||
# Send message to second connection
|
||||
message2 = {"type": "test", "sequence": 2}
|
||||
await websocket_service.broadcast(message2)
|
||||
|
||||
# Second client should receive message
|
||||
assert len(client2.received_messages) == 1
|
||||
assert client2.received_messages[0] == message2
|
||||
|
||||
await client2.disconnect()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_leave_nonexistent_room(self, websocket_service):
|
||||
"""Test leaving a room that doesn't exist or client isn't in."""
|
||||
client = MockWebSocketClient("test_client", websocket_service)
|
||||
await client.connect()
|
||||
|
||||
# Should not raise error
|
||||
await client.leave_room("nonexistent_room")
|
||||
|
||||
await client.disconnect()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_to_disconnected_client(self, websocket_service):
|
||||
"""Test sending message to a client that has disconnected."""
|
||||
client = MockWebSocketClient("disconnect_test", websocket_service)
|
||||
await client.connect()
|
||||
|
||||
# Disconnect
|
||||
await client.disconnect()
|
||||
|
||||
# Attempt to broadcast (should not raise error)
|
||||
message = {"type": "test", "data": "test"}
|
||||
await websocket_service.broadcast(message)
|
||||
|
||||
# Client should not receive message (already disconnected)
|
||||
assert len(client.received_messages) == 0
|
||||
Reference in New Issue
Block a user