From 5e5d7c34b2e422775d7f361a17572db328921f4d Mon Sep 17 00:00:00 2001 From: Lukas Date: Fri, 17 Apr 2026 17:18:49 +0200 Subject: [PATCH] Document task DB access and unify background task DB handling --- Docs/Architekture.md | 9 ++++ Docs/Tasks.md | 2 + backend/app/tasks/blocklist_import.py | 37 +++++----------- backend/app/tasks/db.py | 29 ++++++++++++ backend/app/tasks/geo_cache_flush.py | 25 +++-------- backend/app/tasks/geo_re_resolve.py | 40 ++++++----------- backend/app/tasks/history_sync.py | 8 ++-- .../tests/test_tasks/test_blocklist_import.py | 12 ++--- .../tests/test_tasks/test_geo_cache_flush.py | 7 +-- .../tests/test_tasks/test_geo_re_resolve.py | 13 ++++-- backend/tests/test_tasks/test_history_sync.py | 3 +- backend/tests/test_tasks/test_task_db.py | 44 +++++++++++++++++++ 12 files changed, 139 insertions(+), 90 deletions(-) create mode 100644 backend/app/tasks/db.py create mode 100644 backend/tests/test_tasks/test_task_db.py diff --git a/Docs/Architekture.md b/Docs/Architekture.md index 0413f98..7d2fb4b 100644 --- a/Docs/Architekture.md +++ b/Docs/Architekture.md @@ -687,6 +687,15 @@ APScheduler 4.x (async mode) manages recurring background tasks. --- +## 7.1 Background Tasks and Database Access + +- APScheduler jobs run outside FastAPI request/response scope and therefore cannot rely on ``Depends(get_db)``. +- Background tasks must open their own application database connection via ``app.db.open_db`` and close it when the work completes. +- Use a shared task helper (``app.tasks.db.task_db``) so every task follows the same async context manager pattern and avoids connection leaks. +- This pattern is intentional: task code is structurally separate from request-handling dependencies and should not attempt to reuse request-scoped DB connections. + +--- + ## 8. API Design ### 8.1 Conventions diff --git a/Docs/Tasks.md b/Docs/Tasks.md index 293de8b..719be96 100644 --- a/Docs/Tasks.md +++ b/Docs/Tasks.md @@ -406,4 +406,6 @@ Reference: `Docs/Refactoring.md` for full analysis of each issue. **Docs changes needed:** Add a "Background Tasks and Database Access" section to `Docs/Architekture.md` explaining why tasks own their connections and how to write a new task. +**Status:** Completed ✅ + **Why this is needed:** Without explanation, the inconsistency between router DI-provided connections and task-managed connections looks like an oversight. Documentation prevents future developers from incorrectly trying to inject a DB connection into a task via `Depends`, which would fail silently at runtime. diff --git a/backend/app/tasks/blocklist_import.py b/backend/app/tasks/blocklist_import.py index ae0f290..87a828d 100644 --- a/backend/app/tasks/blocklist_import.py +++ b/backend/app/tasks/blocklist_import.py @@ -17,12 +17,11 @@ from typing import TYPE_CHECKING, Any import structlog -from app.db import open_db from app.services import ban_service, blocklist_service +from app.tasks.db import task_db from app.utils.runtime_state import get_effective_settings if TYPE_CHECKING: - import aiosqlite from aiohttp import ClientSession from fastapi import FastAPI @@ -34,11 +33,6 @@ log: structlog.stdlib.BoundLogger = structlog.get_logger() JOB_ID: str = "blocklist_import" -async def _get_db(settings: Settings) -> tuple[aiosqlite.Connection, bool]: - db = await open_db(settings.database_path) - return db, True - - async def _run_import_with_resources(settings: Settings, http_session: ClientSession) -> None: """APScheduler callback that imports all enabled blocklist sources. @@ -46,17 +40,17 @@ async def _run_import_with_resources(settings: Settings, http_session: ClientSes settings: The resolved application settings used for database access. http_session: The shared aiohttp session used for blocklist downloads. """ - db, close_db = await _get_db(settings) socket_path: str = settings.fail2ban_socket log.info("blocklist_import_starting") try: - result = await blocklist_service.import_all( - db, - http_session, - socket_path, - ban_ip=ban_service.ban_ip, - ) + async with task_db(settings) as db: + result = await blocklist_service.import_all( + db, + http_session, + socket_path, + ban_ip=ban_service.ban_ip, + ) log.info( "blocklist_import_finished", total_imported=result.total_imported, @@ -65,9 +59,6 @@ async def _run_import_with_resources(settings: Settings, http_session: ClientSes ) except Exception: log.exception("blocklist_import_unexpected_error") - finally: - if close_db: - await db.close() run_import_with_resources = _run_import_with_resources @@ -91,12 +82,8 @@ async def register(app: FastAPI) -> None: ``app.state.scheduler`` will receive the job. """ settings = get_effective_settings(app) - db, close_db = await _get_db(settings) - try: + async with task_db(settings) as db: config = await blocklist_service.get_schedule(db) - finally: - if close_db: - await db.close() _apply_schedule(app, config) @@ -114,12 +101,8 @@ def reschedule(app: FastAPI) -> None: async def _do_reschedule() -> None: settings = get_effective_settings(app) - db, close_db = await _get_db(settings) - try: + async with task_db(settings) as db: config = await blocklist_service.get_schedule(db) - finally: - if close_db: - await db.close() _apply_schedule(app, config) asyncio.ensure_future(_do_reschedule()) diff --git a/backend/app/tasks/db.py b/backend/app/tasks/db.py new file mode 100644 index 0000000..56163c8 --- /dev/null +++ b/backend/app/tasks/db.py @@ -0,0 +1,29 @@ +"""Shared database helpers for APScheduler background tasks.""" + +from __future__ import annotations + +from contextlib import asynccontextmanager +from typing import TYPE_CHECKING + +from app.db import open_db + +if TYPE_CHECKING: + from collections.abc import AsyncIterator + + import aiosqlite + + from app.config import Settings + + +@asynccontextmanager +async def task_db(settings: Settings) -> AsyncIterator[aiosqlite.Connection]: + """Open a dedicated application database connection for a background task. + + Background tasks run outside FastAPI request scope and therefore must + manage their own SQLite connection instead of using FastAPI dependencies. + """ + db = await open_db(settings.database_path) + try: + yield db + finally: + await db.close() diff --git a/backend/app/tasks/geo_cache_flush.py b/backend/app/tasks/geo_cache_flush.py index a8449a9..2d3f7f8 100644 --- a/backend/app/tasks/geo_cache_flush.py +++ b/backend/app/tasks/geo_cache_flush.py @@ -11,21 +11,19 @@ at risk on an unexpected process restart. from __future__ import annotations -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING import structlog -from app.db import open_db +from app.services import geo_service +from app.tasks.db import task_db from app.utils.runtime_state import get_effective_settings -if TYPE_CHECKING: - import aiosqlite - from app.config import Settings -from app.services import geo_service - if TYPE_CHECKING: from fastapi import FastAPI + from app.config import Settings + log: structlog.stdlib.BoundLogger = structlog.get_logger() #: How often the flush job fires (seconds). Configurable tuning constant. @@ -35,23 +33,14 @@ GEO_FLUSH_INTERVAL: int = 60 JOB_ID: str = "geo_cache_flush" -async def _get_db(settings: "Settings") -> tuple[aiosqlite.Connection, bool]: - db = await open_db(settings.database_path) - return db, True - - -async def _run_flush_with_settings(settings: "Settings") -> None: +async def _run_flush_with_settings(settings: Settings) -> None: """Flush the geo service dirty set to the application database. Args: settings: The resolved application settings used for database access. """ - db, close_db = await _get_db(settings) - try: + async with task_db(settings) as db: count = await geo_service.flush_dirty(db) - finally: - if close_db: - await db.close() if count > 0: log.debug("geo_cache_flush_ran", flushed=count) diff --git a/backend/app/tasks/geo_re_resolve.py b/backend/app/tasks/geo_re_resolve.py index 53e1ed9..67847b3 100644 --- a/backend/app/tasks/geo_re_resolve.py +++ b/backend/app/tasks/geo_re_resolve.py @@ -21,18 +21,16 @@ from typing import TYPE_CHECKING import structlog -from app.db import open_db +from app.services import geo_service +from app.tasks.db import task_db from app.utils.runtime_state import get_effective_settings if TYPE_CHECKING: - import aiosqlite from aiohttp import ClientSession - from app.config import Settings -from app.services import geo_service - -if TYPE_CHECKING: from fastapi import FastAPI + from app.config import Settings + log: structlog.stdlib.BoundLogger = structlog.get_logger() #: How often the re-resolve job fires (seconds). 10 minutes. @@ -42,21 +40,14 @@ GEO_RE_RESOLVE_INTERVAL: int = 600 JOB_ID: str = "geo_re_resolve" -async def _get_db(settings: "Settings") -> tuple[aiosqlite.Connection, bool]: - db = await open_db(settings.database_path) - return db, True - - -async def _run_re_resolve_with_resources(settings: "Settings", http_session: "ClientSession") -> None: +async def _run_re_resolve_with_resources(settings: Settings, http_session: ClientSession) -> None: """Query NULL-country IPs from the database and re-resolve them. Args: settings: The resolved application settings used for database access. http_session: The shared aiohttp session used for external lookups. """ - db, close_db = await _get_db(settings) - - try: + async with task_db(settings) as db: # Fetch all IPs with NULL country_code from the persistent cache. unresolved_ips = await geo_service.get_unresolved_ips(db) @@ -73,17 +64,14 @@ async def _run_re_resolve_with_resources(settings: "Settings", http_session: "Cl # passed. This is a background task so DB writes are allowed. results = await geo_service.lookup_batch(unresolved_ips, http_session, db=db) - resolved_count: int = sum( - 1 for info in results.values() if info.country_code is not None - ) - log.info( - "geo_re_resolve_complete", - retried=len(unresolved_ips), - resolved=resolved_count, - ) - finally: - if close_db: - await db.close() + resolved_count: int = sum( + 1 for info in results.values() if info.country_code is not None + ) + log.info( + "geo_re_resolve_complete", + retried=len(unresolved_ips), + resolved=resolved_count, + ) async def _run_re_resolve(app: FastAPI) -> None: diff --git a/backend/app/tasks/history_sync.py b/backend/app/tasks/history_sync.py index 32b84c1..ac1a422 100644 --- a/backend/app/tasks/history_sync.py +++ b/backend/app/tasks/history_sync.py @@ -11,8 +11,8 @@ from typing import TYPE_CHECKING import structlog -from app.db import open_db from app.services import history_service +from app.tasks.db import task_db from app.utils.runtime_state import get_effective_settings if TYPE_CHECKING: @@ -34,15 +34,13 @@ BACKFILL_WINDOW: int = 648000 async def _run_sync_with_settings(settings: Settings) -> None: socket_path: str = settings.fail2ban_socket - db = await open_db(settings.database_path) try: - synced = await history_service.sync_from_fail2ban_db(db, socket_path) + async with task_db(settings) as db: + synced = await history_service.sync_from_fail2ban_db(db, socket_path) log.info("history_sync_complete", synced=synced) except Exception: log.exception("history_sync_failed") - finally: - await db.close() async def _run_sync(app: FastAPI) -> None: diff --git a/backend/tests/test_tasks/test_blocklist_import.py b/backend/tests/test_tasks/test_blocklist_import.py index f1134ee..980daea 100644 --- a/backend/tests/test_tasks/test_blocklist_import.py +++ b/backend/tests/test_tasks/test_blocklist_import.py @@ -98,7 +98,7 @@ class TestRunImport: result = _make_import_result(total_imported=100, total_skipped=2, errors_count=0) with patch( - "app.tasks.blocklist_import.open_db", + "app.tasks.db.open_db", new_callable=AsyncMock, return_value=app.state.db, ), patch( @@ -122,7 +122,7 @@ class TestRunImport: result = _make_import_result(total_imported=42, total_skipped=3, errors_count=1) with patch( - "app.tasks.blocklist_import.open_db", + "app.tasks.db.open_db", new_callable=AsyncMock, return_value=app.state.db, ), patch( @@ -146,7 +146,7 @@ class TestRunImport: result = _make_import_result() with patch( - "app.tasks.blocklist_import.open_db", + "app.tasks.db.open_db", new_callable=AsyncMock, return_value=app.state.db, ), patch( @@ -165,7 +165,7 @@ class TestRunImport: app = _make_app() with patch( - "app.tasks.blocklist_import.open_db", + "app.tasks.db.open_db", new_callable=AsyncMock, return_value=app.state.db, ), patch( @@ -323,7 +323,7 @@ class TestRegister: config = ScheduleConfig(frequency=ScheduleFrequency.daily, hour=3, minute=0) with patch( - "app.tasks.blocklist_import.open_db", + "app.tasks.db.open_db", new_callable=AsyncMock, return_value=app.state.db, ), patch( @@ -353,7 +353,7 @@ class TestReschedule: coro.close() with patch( - "app.tasks.blocklist_import.open_db", + "app.tasks.db.open_db", new_callable=AsyncMock, return_value=app.state.db, ), patch("asyncio.ensure_future", side_effect=_close_coro) as mock_ensure_future: diff --git a/backend/tests/test_tasks/test_geo_cache_flush.py b/backend/tests/test_tasks/test_geo_cache_flush.py index 078f341..d0569ac 100644 --- a/backend/tests/test_tasks/test_geo_cache_flush.py +++ b/backend/tests/test_tasks/test_geo_cache_flush.py @@ -34,6 +34,7 @@ def _make_app(flush_count: int = 0) -> MagicMock: app.state.db.close = AsyncMock() app.state.scheduler = MagicMock() app.state.settings = MagicMock(database_path="/tmp/fake.db") + app.state.runtime_settings = None return app @@ -51,7 +52,7 @@ class TestRunFlush: app = _make_app() with patch( - "app.tasks.geo_cache_flush.open_db", + "app.tasks.db.open_db", new_callable=AsyncMock, return_value=app.state.db, ), patch( @@ -69,7 +70,7 @@ class TestRunFlush: app = _make_app() with patch( - "app.tasks.geo_cache_flush.open_db", + "app.tasks.db.open_db", new_callable=AsyncMock, return_value=app.state.db, ), patch( @@ -89,7 +90,7 @@ class TestRunFlush: app = _make_app() with patch( - "app.tasks.geo_cache_flush.open_db", + "app.tasks.db.open_db", new_callable=AsyncMock, return_value=app.state.db, ), patch( diff --git a/backend/tests/test_tasks/test_geo_re_resolve.py b/backend/tests/test_tasks/test_geo_re_resolve.py index dfdeaab..5f188df 100644 --- a/backend/tests/test_tasks/test_geo_re_resolve.py +++ b/backend/tests/test_tasks/test_geo_re_resolve.py @@ -70,6 +70,7 @@ def _make_app( app.state.db = db app.state.http_session = http_session app.state.settings = MagicMock(database_path="/tmp/fake.db") + app.state.runtime_settings = None return app @@ -80,7 +81,7 @@ async def test_run_re_resolve_no_unresolved_ips_skips() -> None: app = _make_app(unresolved_ips=[]) with patch( - "app.tasks.geo_re_resolve.open_db", + "app.tasks.db.open_db", new_callable=AsyncMock, return_value=app.state.db, ), patch("app.tasks.geo_re_resolve.geo_service") as mock_geo: @@ -104,6 +105,7 @@ async def test_run_re_resolve_clears_neg_cache() -> None: with patch("app.tasks.geo_re_resolve.geo_service") as mock_geo: mock_geo.get_unresolved_ips = AsyncMock(return_value=ips) + mock_geo.clear_neg_cache = AsyncMock() mock_geo.lookup_batch = AsyncMock(return_value=result) await _run_re_resolve(app) @@ -122,11 +124,12 @@ async def test_run_re_resolve_calls_lookup_batch_with_db() -> None: app = _make_app(unresolved_ips=ips, lookup_result=result) with patch( - "app.tasks.geo_re_resolve.open_db", + "app.tasks.db.open_db", new_callable=AsyncMock, return_value=app.state.db, ), patch("app.tasks.geo_re_resolve.geo_service") as mock_geo: mock_geo.get_unresolved_ips = AsyncMock(return_value=ips) + mock_geo.clear_neg_cache = AsyncMock() mock_geo.lookup_batch = AsyncMock(return_value=result) await _run_re_resolve(app) @@ -150,11 +153,12 @@ async def test_run_re_resolve_logs_correct_counts(caplog: Any) -> None: app = _make_app(unresolved_ips=ips, lookup_result=result) with patch( - "app.tasks.geo_re_resolve.open_db", + "app.tasks.db.open_db", new_callable=AsyncMock, return_value=app.state.db, ), patch("app.tasks.geo_re_resolve.geo_service") as mock_geo: mock_geo.get_unresolved_ips = AsyncMock(return_value=ips) + mock_geo.clear_neg_cache = AsyncMock() mock_geo.lookup_batch = AsyncMock(return_value=result) await _run_re_resolve(app) @@ -177,11 +181,12 @@ async def test_run_re_resolve_handles_all_resolved() -> None: app = _make_app(unresolved_ips=ips, lookup_result=result) with patch( - "app.tasks.geo_re_resolve.open_db", + "app.tasks.db.open_db", new_callable=AsyncMock, return_value=app.state.db, ), patch("app.tasks.geo_re_resolve.geo_service") as mock_geo: mock_geo.get_unresolved_ips = AsyncMock(return_value=ips) + mock_geo.clear_neg_cache = AsyncMock() mock_geo.lookup_batch = AsyncMock(return_value=result) await _run_re_resolve(app) diff --git a/backend/tests/test_tasks/test_history_sync.py b/backend/tests/test_tasks/test_history_sync.py index 5c77456..33dc4f9 100644 --- a/backend/tests/test_tasks/test_history_sync.py +++ b/backend/tests/test_tasks/test_history_sync.py @@ -37,12 +37,13 @@ class TestHistorySyncTask: fake_app.state.settings = type("FakeSettings", (), {})() fake_app.state.settings.fail2ban_socket = "/tmp/fake.sock" fake_app.state.settings.database_path = "/tmp/fake.db" + fake_app.state.runtime_settings = None fake_db = AsyncMock() fake_db.close = AsyncMock() with patch( - "app.tasks.history_sync.open_db", + "app.tasks.db.open_db", new_callable=AsyncMock, return_value=fake_db, ), patch( diff --git a/backend/tests/test_tasks/test_task_db.py b/backend/tests/test_tasks/test_task_db.py new file mode 100644 index 0000000..1395407 --- /dev/null +++ b/backend/tests/test_tasks/test_task_db.py @@ -0,0 +1,44 @@ +"""Tests for the shared background task database helper.""" + +from __future__ import annotations + +from unittest.mock import AsyncMock, patch + +import pytest + +from app.tasks.db import task_db + + +class FakeSettings: + database_path = "/tmp/fake.db" + + +@pytest.mark.asyncio +async def test_task_db_opens_and_closes_connection() -> None: + """``task_db`` must open a DB connection and close it after use.""" + fake_db = AsyncMock() + fake_db.close = AsyncMock() + + with patch("app.tasks.db.open_db", new_callable=AsyncMock, return_value=fake_db) as mock_open_db: + async with task_db(FakeSettings()) as db: + assert db is fake_db + + mock_open_db.assert_awaited_once_with("/tmp/fake.db") + fake_db.close.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_task_db_closes_connection_on_exception() -> None: + """``task_db`` must close the connection even when the body raises.""" + fake_db = AsyncMock() + fake_db.close = AsyncMock() + + with patch( + "app.tasks.db.open_db", + new_callable=AsyncMock, + return_value=fake_db, + ), pytest.raises(RuntimeError, match="boom"): + async with task_db(FakeSettings()): + raise RuntimeError("boom") + + fake_db.close.assert_awaited_once()