feat: implement graceful shutdown with SIGINT/SIGTERM support

- Add WebSocket shutdown() with client notification and graceful close
- Enhance download service stop() with pending state persistence
- Expand FastAPI lifespan shutdown with proper cleanup sequence
- Add SQLite WAL checkpoint before database close
- Update stop_server.sh to use SIGTERM with timeout fallback
- Configure uvicorn timeout_graceful_shutdown=30s
- Update ARCHITECTURE.md with shutdown documentation
This commit is contained in:
2025-12-25 18:59:07 +01:00
parent 1ba67357dc
commit d70d70e193
9 changed files with 443 additions and 175 deletions

View File

@@ -150,11 +150,29 @@ async def init_db() -> None:
async def close_db() -> None:
"""Close database connections and cleanup resources.
Performs a WAL checkpoint for SQLite databases to ensure all
pending writes are flushed to the main database file before
closing connections. This prevents database corruption during
shutdown.
Should be called during application shutdown.
"""
global _engine, _sync_engine, _session_factory, _sync_session_factory
try:
# For SQLite: checkpoint WAL to ensure all writes are flushed
if _sync_engine and "sqlite" in str(_sync_engine.url):
logger.info("Running SQLite WAL checkpoint before shutdown...")
try:
from sqlalchemy import text
with _sync_engine.connect() as conn:
# TRUNCATE mode: checkpoint and truncate WAL file
conn.execute(text("PRAGMA wal_checkpoint(TRUNCATE)"))
conn.commit()
logger.info("SQLite WAL checkpoint completed")
except Exception as e:
logger.warning(f"WAL checkpoint failed (non-critical): {e}")
if _engine:
logger.info("Closing async database engine...")
await _engine.dispose()

View File

@@ -155,30 +155,81 @@ async def lifespan(_application: FastAPI):
# Yield control to the application
yield
# Shutdown
logger.info("FastAPI application shutting down")
# Shutdown - execute in proper order with timeout protection
logger.info("FastAPI application shutting down (graceful shutdown initiated)")
# Shutdown download service and its thread pool
# Define shutdown timeout (total time allowed for all shutdown operations)
SHUTDOWN_TIMEOUT = 30.0
import time
shutdown_start = time.monotonic()
def remaining_time() -> float:
"""Calculate remaining shutdown time."""
elapsed = time.monotonic() - shutdown_start
return max(0.0, SHUTDOWN_TIMEOUT - elapsed)
# 1. Broadcast shutdown notification via WebSocket
try:
ws_service = get_websocket_service()
logger.info("Broadcasting shutdown notification to WebSocket clients...")
await asyncio.wait_for(
ws_service.shutdown(timeout=min(5.0, remaining_time())),
timeout=min(5.0, remaining_time())
)
logger.info("WebSocket shutdown complete")
except asyncio.TimeoutError:
logger.warning("WebSocket shutdown timed out")
except Exception as e: # pylint: disable=broad-exception-caught
logger.error("Error during WebSocket shutdown: %s", e, exc_info=True)
# 2. Shutdown download service and persist active downloads
try:
from src.server.services.download_service import ( # noqa: E501
_download_service_instance,
)
if _download_service_instance is not None:
logger.info("Stopping download service...")
await _download_service_instance.stop()
await asyncio.wait_for(
_download_service_instance.stop(timeout=min(10.0, remaining_time())),
timeout=min(15.0, remaining_time())
)
logger.info("Download service stopped successfully")
except asyncio.TimeoutError:
logger.warning("Download service shutdown timed out")
except Exception as e: # pylint: disable=broad-exception-caught
logger.error("Error stopping download service: %s", e, exc_info=True)
# Close database connections
# 3. Cleanup progress service
try:
progress_service = get_progress_service()
logger.info("Cleaning up progress service...")
# Clear any active progress tracking and subscribers
progress_service._subscribers.clear()
progress_service._active_progress.clear()
logger.info("Progress service cleanup complete")
except Exception as e: # pylint: disable=broad-exception-caught
logger.error("Error cleaning up progress service: %s", e, exc_info=True)
# 4. Close database connections with WAL checkpoint
try:
from src.server.database.connection import close_db
await close_db()
logger.info("Closing database connections...")
await asyncio.wait_for(
close_db(),
timeout=min(10.0, remaining_time())
)
logger.info("Database connections closed")
except asyncio.TimeoutError:
logger.warning("Database shutdown timed out")
except Exception as e: # pylint: disable=broad-exception-caught
logger.error("Error closing database: %s", e, exc_info=True)
logger.info("FastAPI application shutdown complete")
elapsed_total = time.monotonic() - shutdown_start
logger.info(
"FastAPI application shutdown complete (took %.2fs)",
elapsed_total
)
# Initialize FastAPI app with lifespan

View File

@@ -997,30 +997,76 @@ class DownloadService:
"""
logger.info("Download queue service initialized")
async def stop(self) -> None:
"""Stop the download queue service and cancel active downloads.
async def stop(self, timeout: float = 10.0) -> None:
"""Stop the download queue service gracefully.
Cancels any active download and shuts down the thread pool immediately.
Persists in-progress downloads back to pending state, cancels active
tasks, and shuts down the thread pool with a timeout.
Args:
timeout: Maximum time (seconds) to wait for executor shutdown
"""
logger.info("Stopping download queue service...")
logger.info("Stopping download queue service (timeout=%.1fs)...", timeout)
# Set shutdown flag
# Set shutdown flag first to prevent new downloads
self._is_shutting_down = True
self._is_stopped = True
# Persist active download back to pending state if one exists
if self._active_download:
logger.info(
"Persisting active download to pending: item_id=%s",
self._active_download.id
)
try:
# Reset status to pending so it can be resumed on restart
self._active_download.status = DownloadStatus.PENDING
self._active_download.completed_at = None
await self._save_to_database(self._active_download)
logger.info("Active download persisted to database as pending")
except Exception as e:
logger.error("Failed to persist active download: %s", e)
# Cancel active download task if running
active_task = self._active_download_task
if active_task and not active_task.done():
logger.info("Cancelling active download task...")
active_task.cancel()
try:
await active_task
# Wait briefly for cancellation to complete
await asyncio.wait_for(
asyncio.shield(active_task),
timeout=2.0
)
except asyncio.TimeoutError:
logger.warning("Download task cancellation timed out")
except asyncio.CancelledError:
logger.info("Active download task cancelled")
except Exception as e:
logger.warning("Error during task cancellation: %s", e)
# Shutdown executor immediately, don't wait for tasks
# Shutdown executor with wait and timeout
logger.info("Shutting down thread pool executor...")
self._executor.shutdown(wait=False, cancel_futures=True)
try:
# Run executor shutdown in thread to avoid blocking event loop
loop = asyncio.get_event_loop()
await asyncio.wait_for(
loop.run_in_executor(
None,
lambda: self._executor.shutdown(wait=True, cancel_futures=True)
),
timeout=timeout
)
logger.info("Thread pool executor shutdown complete")
except asyncio.TimeoutError:
logger.warning(
"Executor shutdown timed out after %.1fs, forcing shutdown",
timeout
)
# Force shutdown without waiting
self._executor.shutdown(wait=False, cancel_futures=True)
except Exception as e:
logger.error("Error during executor shutdown: %s", e)
logger.info("Download queue service stopped")

View File

@@ -322,6 +322,85 @@ class ConnectionManager:
connection_id=connection_id,
)
async def shutdown(self, timeout: float = 5.0) -> None:
"""Gracefully shutdown all WebSocket connections.
Broadcasts a shutdown notification to all clients, then closes
each connection with proper close codes.
Args:
timeout: Maximum time (seconds) to wait for all closes to complete
"""
logger.info(
"Initiating WebSocket shutdown, connections=%d",
len(self._active_connections)
)
# Broadcast shutdown notification to all clients
shutdown_message = {
"type": "server_shutdown",
"timestamp": datetime.now(timezone.utc).isoformat(),
"data": {
"message": "Server is shutting down",
"reason": "graceful_shutdown",
},
}
try:
await self.broadcast(shutdown_message)
except Exception as e:
logger.warning("Failed to broadcast shutdown message: %s", e)
# Close all connections gracefully
async with self._lock:
connection_ids = list(self._active_connections.keys())
close_tasks = []
for connection_id in connection_ids:
websocket = self._active_connections.get(connection_id)
if websocket:
close_tasks.append(
self._close_connection_gracefully(connection_id, websocket)
)
if close_tasks:
# Wait for all closes with timeout
try:
await asyncio.wait_for(
asyncio.gather(*close_tasks, return_exceptions=True),
timeout=timeout
)
except asyncio.TimeoutError:
logger.warning(
"WebSocket shutdown timed out after %.1f seconds", timeout
)
# Clear all data structures
async with self._lock:
self._active_connections.clear()
self._rooms.clear()
self._connection_metadata.clear()
logger.info("WebSocket shutdown complete")
async def _close_connection_gracefully(
self, connection_id: str, websocket: WebSocket
) -> None:
"""Close a single WebSocket connection gracefully.
Args:
connection_id: The connection identifier
websocket: The WebSocket connection to close
"""
try:
# Code 1001 = Going Away (server shutdown)
await websocket.close(code=1001, reason="Server shutdown")
logger.debug("Closed WebSocket connection: %s", connection_id)
except Exception as e:
logger.debug(
"Error closing WebSocket %s: %s", connection_id, str(e)
)
class WebSocketService:
"""High-level WebSocket service for application-wide messaging.
@@ -579,6 +658,18 @@ class WebSocketService:
elapsed_seconds=round(elapsed_seconds, 2),
)
async def shutdown(self, timeout: float = 5.0) -> None:
"""Gracefully shutdown the WebSocket service.
Broadcasts shutdown notification and closes all connections.
Args:
timeout: Maximum time (seconds) to wait for shutdown
"""
logger.info("Shutting down WebSocket service...")
await self._manager.shutdown(timeout=timeout)
logger.info("WebSocket service shutdown complete")
# Singleton instance for application-wide access
_websocket_service: Optional[WebSocketService] = None