From 95f72018f7d95eef131e1ce8df38df1ea0066caa Mon Sep 17 00:00:00 2001 From: Lukas Date: Mon, 6 Apr 2026 20:56:57 +0200 Subject: [PATCH] Add backend lifecycle regression tests and fix lifespan cleanup --- Docs/Tasks.md | 4 +- backend/app/main.py | 13 +++-- backend/tests/test_main.py | 107 ++++++++++++++++++++++++++++++++++++- 3 files changed, 114 insertions(+), 10 deletions(-) diff --git a/Docs/Tasks.md b/Docs/Tasks.md index fa5033c..18121c0 100644 --- a/Docs/Tasks.md +++ b/Docs/Tasks.md @@ -39,7 +39,7 @@ Reference: `Docs/Refactoring.md` for full analysis of each issue. ### Reliability and Resilience -- **Add backend lifecycle tests for resource cleanup.** +- **Add backend lifecycle tests for resource cleanup.** ✅ - Verify startup opens and initialises DB, HTTP session, scheduler, and geo cache correctly. - Verify shutdown closes those resources cleanly. @@ -77,4 +77,4 @@ Reference: `Docs/Refactoring.md` for full analysis of each issue. 3. ✅ Harden fail2ban client concurrency and packaging. 4. ✅ Convert setup guard to a safer startup-driven model. 5. ✅ Add deployment-safe configuration and production-ready CORS. -6. Add lifecycle and concurrency regression tests. +6. ✅ Add lifecycle and concurrency regression tests. diff --git a/backend/app/main.py b/backend/app/main.py index 2ba771c..5174c78 100644 --- a/backend/app/main.py +++ b/backend/app/main.py @@ -121,18 +121,18 @@ async def _lifespan(app: FastAPI) -> AsyncGenerator[None, None]: from app.services import geo_service # noqa: PLC0415 log.debug("database_directory_ensured", directory=str(db_path.parent)) - db = await open_db(settings.database_path) + startup_db = await open_db(settings.database_path) try: - await init_db(db) - await geo_service.load_cache_from_db(db) - unresolved_count = await geo_service.count_unresolved(db) + await init_db(startup_db) + await geo_service.load_cache_from_db(startup_db) + unresolved_count = await geo_service.count_unresolved(startup_db) from app.services import setup_service # noqa: PLC0415 - setup_complete = await setup_service.is_setup_complete(db) + setup_complete = await setup_service.is_setup_complete(startup_db) set_setup_complete_cache(app, setup_complete) log.debug("setup_completion_cached", completed=setup_complete) finally: - await db.close() + await startup_db.close() if unresolved_count > 0: log.warning("geo_cache_unresolved_ips", unresolved=unresolved_count) @@ -172,7 +172,6 @@ async def _lifespan(app: FastAPI) -> AsyncGenerator[None, None]: log.info("bangui_shutting_down") scheduler.shutdown(wait=False) await http_session.close() - await db.close() log.info("bangui_shut_down") diff --git a/backend/tests/test_main.py b/backend/tests/test_main.py index 3c07699..e5e9c8c 100644 --- a/backend/tests/test_main.py +++ b/backend/tests/test_main.py @@ -1,7 +1,13 @@ """Unit tests for backend application startup and middleware configuration.""" +import asyncio +from pathlib import Path +from unittest.mock import AsyncMock, MagicMock, patch + +from httpx import ASGITransport, AsyncClient + from app.config import Settings -from app.main import CORSMiddleware, create_app +from app.main import CORSMiddleware, _lifespan, create_app def test_create_app_configures_cors_from_settings() -> None: @@ -68,3 +74,102 @@ def test_create_app_disables_cors_by_default() -> None: ] assert cors_middleware == [] + + +async def test_lifespan_initialises_and_cleans_up_shared_resources(tmp_path: Path) -> None: + """The app lifespan creates and shuts down shared resources cleanly.""" + settings = Settings( + database_path=str(tmp_path / "bangui.db"), + fail2ban_socket="/tmp/fake_fail2ban.sock", + fail2ban_config_dir=str(tmp_path / "fail2ban"), + session_secret="test-lifespan-secret", + session_duration_minutes=60, + timezone="UTC", + log_level="debug", + ) + app = create_app(settings=settings) + + mock_scheduler = MagicMock() + mock_scheduler.start = MagicMock() + mock_scheduler.shutdown = MagicMock() + + mock_http_session = MagicMock() + mock_http_session.close = AsyncMock() + + with ( + patch("app.main.ensure_jail_configs"), + patch("app.main.aiohttp.ClientSession", return_value=mock_http_session), + patch("app.main.AsyncIOScheduler", return_value=mock_scheduler), + patch("app.main.init_db", new=AsyncMock()), + patch("app.services.geo_service.init_geoip"), + patch("app.services.geo_service.load_cache_from_db", new=AsyncMock(return_value=None)), + patch("app.services.geo_service.count_unresolved", new=AsyncMock(return_value=0)), + patch("app.services.setup_service.is_setup_complete", new=AsyncMock(return_value=False)), + patch("app.tasks.health_check.register"), + patch("app.tasks.blocklist_import.register"), + patch("app.tasks.geo_cache_flush.register"), + patch("app.tasks.geo_re_resolve.register"), + patch("app.tasks.history_sync.register"), + ): + async with _lifespan(app): + assert app.state.http_session is mock_http_session + assert app.state.scheduler is mock_scheduler + assert app.state.settings is settings + + mock_http_session.close.assert_awaited_once() + mock_scheduler.shutdown.assert_called_once_with(wait=False) + + +async def test_concurrent_requests_use_request_scoped_db_connections(tmp_path: Path) -> None: + """Concurrent requests each open and close their own database connection.""" + settings = Settings( + database_path=str(tmp_path / "bangui.db"), + fail2ban_socket="/tmp/fake_fail2ban.sock", + fail2ban_config_dir=str(tmp_path / "fail2ban"), + session_secret="test-concurrency-secret", + session_duration_minutes=60, + timezone="UTC", + log_level="debug", + ) + app = create_app(settings=settings) + + connections: list[MagicMock] = [] + + async def fake_open_db(database_path: str) -> MagicMock: + connection = MagicMock() + connection.close = AsyncMock() + connections.append(connection) + return connection + + mock_scheduler = MagicMock() + mock_scheduler.start = MagicMock() + mock_scheduler.shutdown = MagicMock() + + mock_http_session = MagicMock() + mock_http_session.close = AsyncMock() + + with ( + patch("app.main.open_db", new=AsyncMock(side_effect=fake_open_db)), + patch("app.db.open_db", new=AsyncMock(side_effect=fake_open_db)), + patch("app.main.init_db", new=AsyncMock()), + patch("app.main.ensure_jail_configs"), + patch("app.main.aiohttp.ClientSession", return_value=mock_http_session), + patch("app.main.AsyncIOScheduler", return_value=mock_scheduler), + patch("app.services.geo_service.init_geoip"), + patch("app.services.geo_service.load_cache_from_db", new=AsyncMock(return_value=None)), + patch("app.services.geo_service.count_unresolved", new=AsyncMock(return_value=0)), + patch("app.services.setup_service.is_setup_complete", new=AsyncMock(return_value=False)), + patch("app.tasks.health_check.register"), + patch("app.tasks.blocklist_import.register"), + patch("app.tasks.geo_cache_flush.register"), + patch("app.tasks.geo_re_resolve.register"), + patch("app.tasks.history_sync.register"), + ): + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + responses = await asyncio.gather(*(client.get("/api/setup") for _ in range(5))) + + assert len(connections) == 5 + assert len({id(connection) for connection in connections}) == 5 + assert all(response.status_code == 200 for response in responses) + assert all(connection.close.await_count == 1 for connection in connections)