refactoring-backend #3

Merged
lukas.pupkalipinski merged 403 commits from refactoring-backend into main 2026-05-20 20:23:46 +02:00
3 changed files with 114 additions and 10 deletions
Showing only changes of commit 95f72018f7 - Show all commits

View File

@@ -39,7 +39,7 @@ Reference: `Docs/Refactoring.md` for full analysis of each issue.
### Reliability and Resilience
- **Add backend lifecycle tests for resource cleanup.**
- **Add backend lifecycle tests for resource cleanup.**
- Verify startup opens and initialises DB, HTTP session, scheduler, and geo cache correctly.
- Verify shutdown closes those resources cleanly.
@@ -77,4 +77,4 @@ Reference: `Docs/Refactoring.md` for full analysis of each issue.
3. ✅ Harden fail2ban client concurrency and packaging.
4. ✅ Convert setup guard to a safer startup-driven model.
5. ✅ Add deployment-safe configuration and production-ready CORS.
6. Add lifecycle and concurrency regression tests.
6. Add lifecycle and concurrency regression tests.

View File

@@ -121,18 +121,18 @@ async def _lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
from app.services import geo_service # noqa: PLC0415
log.debug("database_directory_ensured", directory=str(db_path.parent))
db = await open_db(settings.database_path)
startup_db = await open_db(settings.database_path)
try:
await init_db(db)
await geo_service.load_cache_from_db(db)
unresolved_count = await geo_service.count_unresolved(db)
await init_db(startup_db)
await geo_service.load_cache_from_db(startup_db)
unresolved_count = await geo_service.count_unresolved(startup_db)
from app.services import setup_service # noqa: PLC0415
setup_complete = await setup_service.is_setup_complete(db)
setup_complete = await setup_service.is_setup_complete(startup_db)
set_setup_complete_cache(app, setup_complete)
log.debug("setup_completion_cached", completed=setup_complete)
finally:
await db.close()
await startup_db.close()
if unresolved_count > 0:
log.warning("geo_cache_unresolved_ips", unresolved=unresolved_count)
@@ -172,7 +172,6 @@ async def _lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
log.info("bangui_shutting_down")
scheduler.shutdown(wait=False)
await http_session.close()
await db.close()
log.info("bangui_shut_down")

View File

@@ -1,7 +1,13 @@
"""Unit tests for backend application startup and middleware configuration."""
import asyncio
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch
from httpx import ASGITransport, AsyncClient
from app.config import Settings
from app.main import CORSMiddleware, create_app
from app.main import CORSMiddleware, _lifespan, create_app
def test_create_app_configures_cors_from_settings() -> None:
@@ -68,3 +74,102 @@ def test_create_app_disables_cors_by_default() -> None:
]
assert cors_middleware == []
async def test_lifespan_initialises_and_cleans_up_shared_resources(tmp_path: Path) -> None:
"""The app lifespan creates and shuts down shared resources cleanly."""
settings = Settings(
database_path=str(tmp_path / "bangui.db"),
fail2ban_socket="/tmp/fake_fail2ban.sock",
fail2ban_config_dir=str(tmp_path / "fail2ban"),
session_secret="test-lifespan-secret",
session_duration_minutes=60,
timezone="UTC",
log_level="debug",
)
app = create_app(settings=settings)
mock_scheduler = MagicMock()
mock_scheduler.start = MagicMock()
mock_scheduler.shutdown = MagicMock()
mock_http_session = MagicMock()
mock_http_session.close = AsyncMock()
with (
patch("app.main.ensure_jail_configs"),
patch("app.main.aiohttp.ClientSession", return_value=mock_http_session),
patch("app.main.AsyncIOScheduler", return_value=mock_scheduler),
patch("app.main.init_db", new=AsyncMock()),
patch("app.services.geo_service.init_geoip"),
patch("app.services.geo_service.load_cache_from_db", new=AsyncMock(return_value=None)),
patch("app.services.geo_service.count_unresolved", new=AsyncMock(return_value=0)),
patch("app.services.setup_service.is_setup_complete", new=AsyncMock(return_value=False)),
patch("app.tasks.health_check.register"),
patch("app.tasks.blocklist_import.register"),
patch("app.tasks.geo_cache_flush.register"),
patch("app.tasks.geo_re_resolve.register"),
patch("app.tasks.history_sync.register"),
):
async with _lifespan(app):
assert app.state.http_session is mock_http_session
assert app.state.scheduler is mock_scheduler
assert app.state.settings is settings
mock_http_session.close.assert_awaited_once()
mock_scheduler.shutdown.assert_called_once_with(wait=False)
async def test_concurrent_requests_use_request_scoped_db_connections(tmp_path: Path) -> None:
"""Concurrent requests each open and close their own database connection."""
settings = Settings(
database_path=str(tmp_path / "bangui.db"),
fail2ban_socket="/tmp/fake_fail2ban.sock",
fail2ban_config_dir=str(tmp_path / "fail2ban"),
session_secret="test-concurrency-secret",
session_duration_minutes=60,
timezone="UTC",
log_level="debug",
)
app = create_app(settings=settings)
connections: list[MagicMock] = []
async def fake_open_db(database_path: str) -> MagicMock:
connection = MagicMock()
connection.close = AsyncMock()
connections.append(connection)
return connection
mock_scheduler = MagicMock()
mock_scheduler.start = MagicMock()
mock_scheduler.shutdown = MagicMock()
mock_http_session = MagicMock()
mock_http_session.close = AsyncMock()
with (
patch("app.main.open_db", new=AsyncMock(side_effect=fake_open_db)),
patch("app.db.open_db", new=AsyncMock(side_effect=fake_open_db)),
patch("app.main.init_db", new=AsyncMock()),
patch("app.main.ensure_jail_configs"),
patch("app.main.aiohttp.ClientSession", return_value=mock_http_session),
patch("app.main.AsyncIOScheduler", return_value=mock_scheduler),
patch("app.services.geo_service.init_geoip"),
patch("app.services.geo_service.load_cache_from_db", new=AsyncMock(return_value=None)),
patch("app.services.geo_service.count_unresolved", new=AsyncMock(return_value=0)),
patch("app.services.setup_service.is_setup_complete", new=AsyncMock(return_value=False)),
patch("app.tasks.health_check.register"),
patch("app.tasks.blocklist_import.register"),
patch("app.tasks.geo_cache_flush.register"),
patch("app.tasks.geo_re_resolve.register"),
patch("app.tasks.history_sync.register"),
):
transport = ASGITransport(app=app)
async with AsyncClient(transport=transport, base_url="http://test") as client:
responses = await asyncio.gather(*(client.get("/api/setup") for _ in range(5)))
assert len(connections) == 5
assert len({id(connection) for connection in connections}) == 5
assert all(response.status_code == 200 for response in responses)
assert all(connection.close.await_count == 1 for connection in connections)