diff --git a/Docs/Tasks.md b/Docs/Tasks.md index 03c03b6..bd9124e 100644 --- a/Docs/Tasks.md +++ b/Docs/Tasks.md @@ -20,7 +20,7 @@ Reference: `Docs/Refactoring.md` for full analysis of each issue. - Fix: move service-dependent helpers into `app/services/` or extract shared logic into a new `app/helpers/` layer, keeping `app/utils/` purely independent. - Expected outcome: lower coupling between utility and service layers, cleaner dependency direction, and better maintainability. -- Background task modules in `backend/app/tasks/` check `app.state.db` before opening a DB connection, but no code sets `app.state.db`. That means the fast-path branch is dead and the DB handling is misleading. +- Status: **done** — background task modules in `backend/app/tasks/` no longer rely on the dead `app.state.db` fast-path and now open/close dedicated task-local DB connections using `app.state.settings.database_path`. - Fix: remove the unused `app.state.db` branch and always open/close a dedicated task-local connection, or intentionally add a shared DB connection to `app.state` and manage its lifecycle. - Expected outcome: background jobs have predictable DB lifecycle, avoid hidden bugs from stale connection assumptions, and task code is simpler. diff --git a/backend/app/tasks/blocklist_import.py b/backend/app/tasks/blocklist_import.py index 94609d6..1bc96b3 100644 --- a/backend/app/tasks/blocklist_import.py +++ b/backend/app/tasks/blocklist_import.py @@ -13,6 +13,7 @@ existing entry without creating duplicates. from __future__ import annotations +import inspect from typing import TYPE_CHECKING, Any import structlog @@ -34,10 +35,6 @@ JOB_ID: str = "blocklist_import" async def _get_db(app: Any) -> tuple[aiosqlite.Connection, bool]: - existing_db = getattr(app.state, "db", None) - if existing_db is not None: - return existing_db, False - db = await open_db(app.state.settings.database_path) return db, True @@ -102,12 +99,16 @@ def register(app: FastAPI) -> None: # APScheduler is synchronous at registration time; use asyncio to read # the stored schedule from the DB before registering. + coro = None try: loop = asyncio.get_event_loop() - loop.run_until_complete(_do_register()) + coro = _do_register() + loop.run_until_complete(coro) except RuntimeError: # If the current thread already has a running loop (uvicorn), schedule # the registration as a coroutine. + if coro is not None and inspect.getcoroutinestate(coro) != inspect.CORO_CLOSED: + coro.close() asyncio.ensure_future(_do_register()) diff --git a/backend/app/tasks/geo_cache_flush.py b/backend/app/tasks/geo_cache_flush.py index 4a5be29..5d818c9 100644 --- a/backend/app/tasks/geo_cache_flush.py +++ b/backend/app/tasks/geo_cache_flush.py @@ -34,10 +34,6 @@ JOB_ID: str = "geo_cache_flush" async def _get_db(app: Any) -> tuple[aiosqlite.Connection, bool]: - existing_db = getattr(app.state, "db", None) - if existing_db is not None: - return existing_db, False - db = await open_db(app.state.settings.database_path) return db, True diff --git a/backend/app/tasks/geo_re_resolve.py b/backend/app/tasks/geo_re_resolve.py index fcbfce2..5ad07fb 100644 --- a/backend/app/tasks/geo_re_resolve.py +++ b/backend/app/tasks/geo_re_resolve.py @@ -40,10 +40,6 @@ JOB_ID: str = "geo_re_resolve" async def _get_db(app: FastAPI) -> tuple[aiosqlite.Connection, bool]: - existing_db = getattr(app.state, "db", None) - if existing_db is not None: - return existing_db, False - db = await open_db(app.state.settings.database_path) return db, True diff --git a/backend/app/tasks/history_sync.py b/backend/app/tasks/history_sync.py index 162407c..5c8b6e4 100644 --- a/backend/app/tasks/history_sync.py +++ b/backend/app/tasks/history_sync.py @@ -34,10 +34,6 @@ BACKFILL_WINDOW: int = 648000 async def _get_db(app: FastAPI) -> tuple[aiosqlite.Connection, bool]: - existing_db = getattr(app.state, "db", None) - if existing_db is not None: - return existing_db, False - db = await open_db(app.state.settings.database_path) return db, True diff --git a/backend/tests/test_tasks/test_blocklist_import.py b/backend/tests/test_tasks/test_blocklist_import.py index b512601..a55fc4a 100644 --- a/backend/tests/test_tasks/test_blocklist_import.py +++ b/backend/tests/test_tasks/test_blocklist_import.py @@ -36,8 +36,12 @@ def _make_app( """ app = MagicMock() app.state.db = MagicMock() + app.state.db.close = AsyncMock() app.state.http_session = MagicMock() - app.state.settings.fail2ban_socket = "/var/run/fail2ban/fail2ban.sock" + app.state.settings = MagicMock( + fail2ban_socket="/var/run/fail2ban/fail2ban.sock", + database_path="/tmp/fake.db", + ) return app @@ -93,6 +97,10 @@ class TestRunImport: result = _make_import_result(total_imported=100, total_skipped=2, errors_count=0) with patch( + "app.tasks.blocklist_import.open_db", + new_callable=AsyncMock, + return_value=app.state.db, + ), patch( "app.tasks.blocklist_import.blocklist_service.import_all", new_callable=AsyncMock, return_value=result, @@ -112,6 +120,10 @@ class TestRunImport: result = _make_import_result(total_imported=42, total_skipped=3, errors_count=1) with patch( + "app.tasks.blocklist_import.open_db", + new_callable=AsyncMock, + return_value=app.state.db, + ), patch( "app.tasks.blocklist_import.blocklist_service.import_all", new_callable=AsyncMock, return_value=result, @@ -132,6 +144,10 @@ class TestRunImport: result = _make_import_result() with patch( + "app.tasks.blocklist_import.open_db", + new_callable=AsyncMock, + return_value=app.state.db, + ), patch( "app.tasks.blocklist_import.blocklist_service.import_all", new_callable=AsyncMock, return_value=result, @@ -147,6 +163,10 @@ class TestRunImport: app = _make_app() with patch( + "app.tasks.blocklist_import.open_db", + new_callable=AsyncMock, + return_value=app.state.db, + ), patch( "app.tasks.blocklist_import.blocklist_service.import_all", new_callable=AsyncMock, side_effect=RuntimeError("unexpected failure"), @@ -288,12 +308,18 @@ class TestRegister: app = MagicMock() app.state.db = MagicMock() + app.state.db.close = AsyncMock() + app.state.settings = MagicMock(database_path="/tmp/fake.db") app.state.scheduler = MagicMock() app.state.scheduler.get_job.return_value = None config = ScheduleConfig(frequency=ScheduleFrequency.daily, hour=3, minute=0) with patch( + "app.tasks.blocklist_import.open_db", + new_callable=AsyncMock, + return_value=app.state.db, + ), patch( "app.tasks.blocklist_import.blocklist_service.get_schedule", new_callable=AsyncMock, return_value=config, @@ -314,6 +340,8 @@ class TestRegister: app = MagicMock() app.state.db = MagicMock() + app.state.db.close = AsyncMock() + app.state.settings = MagicMock(database_path="/tmp/fake.db") app.state.scheduler = MagicMock() config = ScheduleConfig(frequency=ScheduleFrequency.daily) @@ -321,14 +349,22 @@ class TestRegister: mock_loop = MagicMock() mock_loop.run_until_complete.side_effect = RuntimeError("already running") + def _close_coro(coro: Any) -> None: + coro.close() + with ( + patch( + "app.tasks.blocklist_import.open_db", + new_callable=AsyncMock, + return_value=app.state.db, + ), patch( "app.tasks.blocklist_import.blocklist_service.get_schedule", new_callable=AsyncMock, return_value=config, ), patch("asyncio.get_event_loop", return_value=mock_loop), - patch("asyncio.ensure_future") as mock_ensure_future, + patch("asyncio.ensure_future", side_effect=_close_coro) as mock_ensure_future, ): register(app) @@ -344,9 +380,18 @@ class TestReschedule: app = MagicMock() app.state.db = MagicMock() + app.state.db.close = AsyncMock() + app.state.settings = MagicMock(database_path="/tmp/fake.db") app.state.scheduler = MagicMock() - with patch("asyncio.ensure_future") as mock_ensure_future: + def _close_coro(coro: Any) -> None: + coro.close() + + with patch( + "app.tasks.blocklist_import.open_db", + new_callable=AsyncMock, + return_value=app.state.db, + ), patch("asyncio.ensure_future", side_effect=_close_coro) as mock_ensure_future: reschedule(app) mock_ensure_future.assert_called_once() diff --git a/backend/tests/test_tasks/test_geo_cache_flush.py b/backend/tests/test_tasks/test_geo_cache_flush.py index ab65a39..be7ed5d 100644 --- a/backend/tests/test_tasks/test_geo_cache_flush.py +++ b/backend/tests/test_tasks/test_geo_cache_flush.py @@ -31,7 +31,9 @@ def _make_app(flush_count: int = 0) -> MagicMock: """ app = MagicMock() app.state.db = MagicMock() + app.state.db.close = AsyncMock() app.state.scheduler = MagicMock() + app.state.settings = MagicMock(database_path="/tmp/fake.db") return app @@ -49,6 +51,10 @@ class TestRunFlush: app = _make_app() with patch( + "app.tasks.geo_cache_flush.open_db", + new_callable=AsyncMock, + return_value=app.state.db, + ), patch( "app.tasks.geo_cache_flush.geo_service.flush_dirty", new_callable=AsyncMock, return_value=0, @@ -63,6 +69,10 @@ class TestRunFlush: app = _make_app() with patch( + "app.tasks.geo_cache_flush.open_db", + new_callable=AsyncMock, + return_value=app.state.db, + ), patch( "app.tasks.geo_cache_flush.geo_service.flush_dirty", new_callable=AsyncMock, return_value=15, @@ -79,6 +89,10 @@ class TestRunFlush: app = _make_app() with patch( + "app.tasks.geo_cache_flush.open_db", + new_callable=AsyncMock, + return_value=app.state.db, + ), patch( "app.tasks.geo_cache_flush.geo_service.flush_dirty", new_callable=AsyncMock, return_value=0, diff --git a/backend/tests/test_tasks/test_geo_re_resolve.py b/backend/tests/test_tasks/test_geo_re_resolve.py index afd1ee2..dfdeaab 100644 --- a/backend/tests/test_tasks/test_geo_re_resolve.py +++ b/backend/tests/test_tasks/test_geo_re_resolve.py @@ -69,6 +69,7 @@ def _make_app( app = MagicMock() app.state.db = db app.state.http_session = http_session + app.state.settings = MagicMock(database_path="/tmp/fake.db") return app @@ -78,7 +79,11 @@ async def test_run_re_resolve_no_unresolved_ips_skips() -> None: """The task should return immediately when no NULL-country IPs exist.""" app = _make_app(unresolved_ips=[]) - with patch("app.tasks.geo_re_resolve.geo_service") as mock_geo: + with patch( + "app.tasks.geo_re_resolve.open_db", + new_callable=AsyncMock, + return_value=app.state.db, + ), patch("app.tasks.geo_re_resolve.geo_service") as mock_geo: mock_geo.get_unresolved_ips = AsyncMock(return_value=[]) await _run_re_resolve(app) @@ -116,7 +121,11 @@ async def test_run_re_resolve_calls_lookup_batch_with_db() -> None: } app = _make_app(unresolved_ips=ips, lookup_result=result) - with patch("app.tasks.geo_re_resolve.geo_service") as mock_geo: + with patch( + "app.tasks.geo_re_resolve.open_db", + new_callable=AsyncMock, + return_value=app.state.db, + ), patch("app.tasks.geo_re_resolve.geo_service") as mock_geo: mock_geo.get_unresolved_ips = AsyncMock(return_value=ips) mock_geo.lookup_batch = AsyncMock(return_value=result) @@ -140,7 +149,11 @@ async def test_run_re_resolve_logs_correct_counts(caplog: Any) -> None: } app = _make_app(unresolved_ips=ips, lookup_result=result) - with patch("app.tasks.geo_re_resolve.geo_service") as mock_geo: + with patch( + "app.tasks.geo_re_resolve.open_db", + new_callable=AsyncMock, + return_value=app.state.db, + ), patch("app.tasks.geo_re_resolve.geo_service") as mock_geo: mock_geo.get_unresolved_ips = AsyncMock(return_value=ips) mock_geo.lookup_batch = AsyncMock(return_value=result) @@ -163,7 +176,11 @@ async def test_run_re_resolve_handles_all_resolved() -> None: } app = _make_app(unresolved_ips=ips, lookup_result=result) - with patch("app.tasks.geo_re_resolve.geo_service") as mock_geo: + with patch( + "app.tasks.geo_re_resolve.open_db", + new_callable=AsyncMock, + return_value=app.state.db, + ), patch("app.tasks.geo_re_resolve.geo_service") as mock_geo: mock_geo.get_unresolved_ips = AsyncMock(return_value=ips) mock_geo.lookup_batch = AsyncMock(return_value=result) diff --git a/backend/tests/test_tasks/test_history_sync.py b/backend/tests/test_tasks/test_history_sync.py index de167d6..d5a29a8 100644 --- a/backend/tests/test_tasks/test_history_sync.py +++ b/backend/tests/test_tasks/test_history_sync.py @@ -36,8 +36,10 @@ class TestHistorySyncTask: fake_app.state = type("FakeState", (), {})() fake_app.state.settings = type("FakeSettings", (), {})() fake_app.state.settings.fail2ban_socket = "/tmp/fake.sock" + fake_app.state.settings.database_path = "/tmp/fake.db" fake_app.state.db = MagicMock() + fake_app.state.db.close = AsyncMock() async def fake_get_history_page(*, db_path: str, since: int, page: int, page_size: int, **kwargs): assert since == 1001 @@ -47,6 +49,10 @@ class TestHistorySyncTask: return "/tmp/fake.sqlite3" with patch( + "app.tasks.history_sync.open_db", + new_callable=AsyncMock, + return_value=fake_app.state.db, + ), patch( "app.tasks.history_sync._get_last_archive_ts", new=AsyncMock(return_value=1000), ), patch(