feat: implement graceful shutdown with SIGINT/SIGTERM support
- Add WebSocket shutdown() with client notification and graceful close
- Enhance download service stop() with pending state persistence
- Expand FastAPI lifespan shutdown with proper cleanup sequence
- Add SQLite WAL checkpoint before database close
- Update stop_server.sh to use SIGTERM with timeout fallback
- Configure uvicorn timeout_graceful_shutdown=30s
- Update ARCHITECTURE.md with shutdown documentation
This commit is contained in:
@@ -150,11 +150,29 @@ async def init_db() -> None:
|
||||
async def close_db() -> None:
|
||||
"""Close database connections and cleanup resources.
|
||||
|
||||
Performs a WAL checkpoint for SQLite databases to ensure all
|
||||
pending writes are flushed to the main database file before
|
||||
closing connections. This prevents database corruption during
|
||||
shutdown.
|
||||
|
||||
Should be called during application shutdown.
|
||||
"""
|
||||
global _engine, _sync_engine, _session_factory, _sync_session_factory
|
||||
|
||||
try:
|
||||
# For SQLite: checkpoint WAL to ensure all writes are flushed
|
||||
if _sync_engine and "sqlite" in str(_sync_engine.url):
|
||||
logger.info("Running SQLite WAL checkpoint before shutdown...")
|
||||
try:
|
||||
from sqlalchemy import text
|
||||
with _sync_engine.connect() as conn:
|
||||
# TRUNCATE mode: checkpoint and truncate WAL file
|
||||
conn.execute(text("PRAGMA wal_checkpoint(TRUNCATE)"))
|
||||
conn.commit()
|
||||
logger.info("SQLite WAL checkpoint completed")
|
||||
except Exception as e:
|
||||
logger.warning(f"WAL checkpoint failed (non-critical): {e}")
|
||||
|
||||
if _engine:
|
||||
logger.info("Closing async database engine...")
|
||||
await _engine.dispose()
|
||||
|
||||
@@ -155,30 +155,81 @@ async def lifespan(_application: FastAPI):
|
||||
# Yield control to the application
|
||||
yield
|
||||
|
||||
# Shutdown
|
||||
logger.info("FastAPI application shutting down")
|
||||
# Shutdown - execute in proper order with timeout protection
|
||||
logger.info("FastAPI application shutting down (graceful shutdown initiated)")
|
||||
|
||||
# Shutdown download service and its thread pool
|
||||
# Define shutdown timeout (total time allowed for all shutdown operations)
|
||||
SHUTDOWN_TIMEOUT = 30.0
|
||||
|
||||
import time
|
||||
shutdown_start = time.monotonic()
|
||||
|
||||
def remaining_time() -> float:
|
||||
"""Calculate remaining shutdown time."""
|
||||
elapsed = time.monotonic() - shutdown_start
|
||||
return max(0.0, SHUTDOWN_TIMEOUT - elapsed)
|
||||
|
||||
# 1. Broadcast shutdown notification via WebSocket
|
||||
try:
|
||||
ws_service = get_websocket_service()
|
||||
logger.info("Broadcasting shutdown notification to WebSocket clients...")
|
||||
await asyncio.wait_for(
|
||||
ws_service.shutdown(timeout=min(5.0, remaining_time())),
|
||||
timeout=min(5.0, remaining_time())
|
||||
)
|
||||
logger.info("WebSocket shutdown complete")
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning("WebSocket shutdown timed out")
|
||||
except Exception as e: # pylint: disable=broad-exception-caught
|
||||
logger.error("Error during WebSocket shutdown: %s", e, exc_info=True)
|
||||
|
||||
# 2. Shutdown download service and persist active downloads
|
||||
try:
|
||||
from src.server.services.download_service import ( # noqa: E501
|
||||
_download_service_instance,
|
||||
)
|
||||
if _download_service_instance is not None:
|
||||
logger.info("Stopping download service...")
|
||||
await _download_service_instance.stop()
|
||||
await asyncio.wait_for(
|
||||
_download_service_instance.stop(timeout=min(10.0, remaining_time())),
|
||||
timeout=min(15.0, remaining_time())
|
||||
)
|
||||
logger.info("Download service stopped successfully")
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning("Download service shutdown timed out")
|
||||
except Exception as e: # pylint: disable=broad-exception-caught
|
||||
logger.error("Error stopping download service: %s", e, exc_info=True)
|
||||
|
||||
# Close database connections
|
||||
# 3. Cleanup progress service
|
||||
try:
|
||||
progress_service = get_progress_service()
|
||||
logger.info("Cleaning up progress service...")
|
||||
# Clear any active progress tracking and subscribers
|
||||
progress_service._subscribers.clear()
|
||||
progress_service._active_progress.clear()
|
||||
logger.info("Progress service cleanup complete")
|
||||
except Exception as e: # pylint: disable=broad-exception-caught
|
||||
logger.error("Error cleaning up progress service: %s", e, exc_info=True)
|
||||
|
||||
# 4. Close database connections with WAL checkpoint
|
||||
try:
|
||||
from src.server.database.connection import close_db
|
||||
await close_db()
|
||||
logger.info("Closing database connections...")
|
||||
await asyncio.wait_for(
|
||||
close_db(),
|
||||
timeout=min(10.0, remaining_time())
|
||||
)
|
||||
logger.info("Database connections closed")
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning("Database shutdown timed out")
|
||||
except Exception as e: # pylint: disable=broad-exception-caught
|
||||
logger.error("Error closing database: %s", e, exc_info=True)
|
||||
|
||||
logger.info("FastAPI application shutdown complete")
|
||||
elapsed_total = time.monotonic() - shutdown_start
|
||||
logger.info(
|
||||
"FastAPI application shutdown complete (took %.2fs)",
|
||||
elapsed_total
|
||||
)
|
||||
|
||||
|
||||
# Initialize FastAPI app with lifespan
|
||||
|
||||
@@ -997,30 +997,76 @@ class DownloadService:
|
||||
"""
|
||||
logger.info("Download queue service initialized")
|
||||
|
||||
async def stop(self, timeout: float = 10.0) -> None:
    """Stop the download queue service gracefully.

    Persists the in-progress download back to pending state, cancels the
    active download task, and shuts down the thread pool with a timeout.

    Args:
        timeout: Maximum time (seconds) to wait for executor shutdown
    """
    logger.info("Stopping download queue service (timeout=%.1fs)...", timeout)

    # Set shutdown flags first to prevent new downloads
    self._is_shutting_down = True
    self._is_stopped = True

    # Persist active download back to pending state if one exists
    if self._active_download:
        logger.info(
            "Persisting active download to pending: item_id=%s",
            self._active_download.id
        )
        try:
            # Reset status to pending so it can be resumed on restart
            self._active_download.status = DownloadStatus.PENDING
            self._active_download.completed_at = None
            await self._save_to_database(self._active_download)
            logger.info("Active download persisted to database as pending")
        except Exception as e:
            # Best effort: a persistence failure must not abort shutdown
            logger.error("Failed to persist active download: %s", e)

    # Cancel active download task if running
    active_task = self._active_download_task
    if active_task and not active_task.done():
        logger.info("Cancelling active download task...")
        active_task.cancel()
        try:
            # Wait briefly for cancellation to complete. The shield keeps
            # the wait_for timeout from issuing a second cancel on the task.
            await asyncio.wait_for(
                asyncio.shield(active_task),
                timeout=2.0
            )
        except asyncio.TimeoutError:
            logger.warning("Download task cancellation timed out")
        except asyncio.CancelledError:
            logger.info("Active download task cancelled")
        except Exception as e:
            logger.warning("Error during task cancellation: %s", e)

    # Shutdown executor with wait and timeout
    logger.info("Shutting down thread pool executor...")
    try:
        # Run the blocking shutdown on the loop's default executor so the
        # event loop is not blocked. get_running_loop() is the supported
        # way to obtain the loop from inside a coroutine.
        loop = asyncio.get_running_loop()
        await asyncio.wait_for(
            loop.run_in_executor(
                None,
                lambda: self._executor.shutdown(wait=True, cancel_futures=True)
            ),
            timeout=timeout
        )
        logger.info("Thread pool executor shutdown complete")
    except asyncio.TimeoutError:
        logger.warning(
            "Executor shutdown timed out after %.1fs, forcing shutdown",
            timeout
        )
        # Force shutdown without waiting
        self._executor.shutdown(wait=False, cancel_futures=True)
    except Exception as e:
        logger.error("Error during executor shutdown: %s", e)

    logger.info("Download queue service stopped")
|
||||
|
||||
|
||||
@@ -322,6 +322,85 @@ class ConnectionManager:
|
||||
connection_id=connection_id,
|
||||
)
|
||||
|
||||
async def shutdown(self, timeout: float = 5.0) -> None:
    """Gracefully shut down all WebSocket connections.

    Broadcasts a ``server_shutdown`` notification to all clients, closes
    each active connection via ``_close_connection_gracefully``, and then
    clears the connection, room and metadata registries.

    The broadcast and the close phase share a single ``timeout`` budget,
    so a client that stalls during the broadcast cannot delay shutdown
    beyond ``timeout`` seconds in total.

    Args:
        timeout: Maximum time (seconds) for the notification and the
            connection closes combined.
    """
    logger.info(
        "Initiating WebSocket shutdown, connections=%d",
        len(self._active_connections)
    )

    # One deadline covers both phases below.
    loop = asyncio.get_running_loop()
    deadline = loop.time() + timeout

    # Broadcast shutdown notification to all clients
    shutdown_message = {
        "type": "server_shutdown",
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "data": {
            "message": "Server is shutting down",
            "reason": "graceful_shutdown",
        },
    }

    try:
        await asyncio.wait_for(
            self.broadcast(shutdown_message),
            timeout=timeout
        )
    except asyncio.TimeoutError:
        logger.warning(
            "Shutdown broadcast timed out after %.1f seconds", timeout
        )
    except Exception as e:
        # Notification is best effort; connections are still closed below
        logger.warning("Failed to broadcast shutdown message: %s", e)

    # Snapshot ids and sockets together while holding the lock so the
    # pairs are consistent; the closes themselves run outside the lock.
    async with self._lock:
        connections = list(self._active_connections.items())

    close_tasks = [
        self._close_connection_gracefully(connection_id, websocket)
        for connection_id, websocket in connections
        if websocket is not None
    ]

    if close_tasks:
        # Give the closes whatever is left of the shared budget
        remaining = max(0.0, deadline - loop.time())
        try:
            await asyncio.wait_for(
                asyncio.gather(*close_tasks, return_exceptions=True),
                timeout=remaining
            )
        except asyncio.TimeoutError:
            logger.warning(
                "WebSocket shutdown timed out after %.1f seconds", timeout
            )

    # Clear all data structures
    async with self._lock:
        self._active_connections.clear()
        self._rooms.clear()
        self._connection_metadata.clear()

    logger.info("WebSocket shutdown complete")
|
||||
|
||||
async def _close_connection_gracefully(
    self, connection_id: str, websocket: WebSocket
) -> None:
    """Close one WebSocket connection, logging instead of raising on failure.

    Args:
        connection_id: Identifier used in the log messages
        websocket: The WebSocket connection to close
    """
    going_away = 1001  # close code "Going Away", used here for server shutdown
    try:
        await websocket.close(code=going_away, reason="Server shutdown")
    except Exception as exc:
        # A failed close is only reported at debug level
        detail = str(exc)
        logger.debug("Error closing WebSocket %s: %s", connection_id, detail)
        return
    logger.debug("Closed WebSocket connection: %s", connection_id)
|
||||
|
||||
|
||||
class WebSocketService:
|
||||
"""High-level WebSocket service for application-wide messaging.
|
||||
@@ -579,6 +658,18 @@ class WebSocketService:
|
||||
elapsed_seconds=round(elapsed_seconds, 2),
|
||||
)
|
||||
|
||||
async def shutdown(self, timeout: float = 5.0) -> None:
    """Shut down the WebSocket service by delegating to its manager.

    Logs the start and end of the shutdown and forwards ``timeout`` to the
    underlying manager's ``shutdown``.

    Args:
        timeout: Maximum time (seconds) passed on to the manager shutdown
    """
    logger.info("Shutting down WebSocket service...")
    manager = self._manager
    await manager.shutdown(timeout=timeout)
    logger.info("WebSocket service shutdown complete")
|
||||
|
||||
|
||||
# Singleton instance for application-wide access
|
||||
_websocket_service: Optional[WebSocketService] = None
|
||||
|
||||
Reference in New Issue
Block a user