Add backend lifecycle regression tests and fix lifespan cleanup

This commit is contained in:
2026-04-06 20:56:57 +02:00
parent c2982116a8
commit 95f72018f7
3 changed files with 114 additions and 10 deletions

View File

@@ -1,7 +1,13 @@
"""Unit tests for backend application startup and middleware configuration."""
import asyncio
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch
from httpx import ASGITransport, AsyncClient
from app.config import Settings
from app.main import CORSMiddleware, create_app
from app.main import CORSMiddleware, _lifespan, create_app
def test_create_app_configures_cors_from_settings() -> None:
@@ -68,3 +74,102 @@ def test_create_app_disables_cors_by_default() -> None:
]
assert cors_middleware == []
async def test_lifespan_initialises_and_cleans_up_shared_resources(tmp_path: Path) -> None:
"""The app lifespan creates and shuts down shared resources cleanly."""
settings = Settings(
database_path=str(tmp_path / "bangui.db"),
fail2ban_socket="/tmp/fake_fail2ban.sock",
fail2ban_config_dir=str(tmp_path / "fail2ban"),
session_secret="test-lifespan-secret",
session_duration_minutes=60,
timezone="UTC",
log_level="debug",
)
app = create_app(settings=settings)
mock_scheduler = MagicMock()
mock_scheduler.start = MagicMock()
mock_scheduler.shutdown = MagicMock()
mock_http_session = MagicMock()
mock_http_session.close = AsyncMock()
with (
patch("app.main.ensure_jail_configs"),
patch("app.main.aiohttp.ClientSession", return_value=mock_http_session),
patch("app.main.AsyncIOScheduler", return_value=mock_scheduler),
patch("app.main.init_db", new=AsyncMock()),
patch("app.services.geo_service.init_geoip"),
patch("app.services.geo_service.load_cache_from_db", new=AsyncMock(return_value=None)),
patch("app.services.geo_service.count_unresolved", new=AsyncMock(return_value=0)),
patch("app.services.setup_service.is_setup_complete", new=AsyncMock(return_value=False)),
patch("app.tasks.health_check.register"),
patch("app.tasks.blocklist_import.register"),
patch("app.tasks.geo_cache_flush.register"),
patch("app.tasks.geo_re_resolve.register"),
patch("app.tasks.history_sync.register"),
):
async with _lifespan(app):
assert app.state.http_session is mock_http_session
assert app.state.scheduler is mock_scheduler
assert app.state.settings is settings
mock_http_session.close.assert_awaited_once()
mock_scheduler.shutdown.assert_called_once_with(wait=False)
async def test_concurrent_requests_use_request_scoped_db_connections(tmp_path: Path) -> None:
"""Concurrent requests each open and close their own database connection."""
settings = Settings(
database_path=str(tmp_path / "bangui.db"),
fail2ban_socket="/tmp/fake_fail2ban.sock",
fail2ban_config_dir=str(tmp_path / "fail2ban"),
session_secret="test-concurrency-secret",
session_duration_minutes=60,
timezone="UTC",
log_level="debug",
)
app = create_app(settings=settings)
connections: list[MagicMock] = []
async def fake_open_db(database_path: str) -> MagicMock:
connection = MagicMock()
connection.close = AsyncMock()
connections.append(connection)
return connection
mock_scheduler = MagicMock()
mock_scheduler.start = MagicMock()
mock_scheduler.shutdown = MagicMock()
mock_http_session = MagicMock()
mock_http_session.close = AsyncMock()
with (
patch("app.main.open_db", new=AsyncMock(side_effect=fake_open_db)),
patch("app.db.open_db", new=AsyncMock(side_effect=fake_open_db)),
patch("app.main.init_db", new=AsyncMock()),
patch("app.main.ensure_jail_configs"),
patch("app.main.aiohttp.ClientSession", return_value=mock_http_session),
patch("app.main.AsyncIOScheduler", return_value=mock_scheduler),
patch("app.services.geo_service.init_geoip"),
patch("app.services.geo_service.load_cache_from_db", new=AsyncMock(return_value=None)),
patch("app.services.geo_service.count_unresolved", new=AsyncMock(return_value=0)),
patch("app.services.setup_service.is_setup_complete", new=AsyncMock(return_value=False)),
patch("app.tasks.health_check.register"),
patch("app.tasks.blocklist_import.register"),
patch("app.tasks.geo_cache_flush.register"),
patch("app.tasks.geo_re_resolve.register"),
patch("app.tasks.history_sync.register"),
):
transport = ASGITransport(app=app)
async with AsyncClient(transport=transport, base_url="http://test") as client:
responses = await asyncio.gather(*(client.get("/api/setup") for _ in range(5)))
assert len(connections) == 5
assert len({id(connection) for connection in connections}) == 5
assert all(response.status_code == 200 for response in responses)
assert all(connection.close.await_count == 1 for connection in connections)