From a564830abb4061da4051e5ba2c034bc18dcf4f81 Mon Sep 17 00:00:00 2001 From: Lukas Date: Tue, 14 Apr 2026 09:21:38 +0200 Subject: [PATCH] Fix blocklist service injection and centralize session cookie name --- Docs/Tasks.md | 4 +++ backend/app/dependencies.py | 17 ++++++----- backend/app/routers/auth.py | 11 ++++--- backend/app/routers/blocklist.py | 3 +- backend/app/services/blocklist_service.py | 21 +++++--------- backend/app/tasks/blocklist_import.py | 12 ++++---- backend/app/utils/constants.py | 3 ++ .../test_services/test_blocklist_service.py | 29 ++++++++++++++++--- 8 files changed, 62 insertions(+), 38 deletions(-) diff --git a/Docs/Tasks.md b/Docs/Tasks.md index 45e6f68..b15450b 100644 --- a/Docs/Tasks.md +++ b/Docs/Tasks.md @@ -337,6 +337,8 @@ An import statement inside a loop is an unconventional pattern that confuses rea ### Task 12 — Replace blocklist_service importlib workaround with proper injection +**Status:** Completed + **Severity:** Medium **Where:** @@ -511,6 +513,8 @@ None. ### Task 18 — Move _COOKIE_NAME to constants.py +**Status:** Completed + **Severity:** Low **Where:** diff --git a/backend/app/dependencies.py b/backend/app/dependencies.py index 7d7ff12..eb32ddb 100644 --- a/backend/app/dependencies.py +++ b/backend/app/dependencies.py @@ -23,6 +23,7 @@ from app.models.config import PendingRecovery from app.models.server import ServerStatus from app.repositories.protocols import SessionRepository from app.services.protocols import AuthService, JailService +from app.utils.constants import SESSION_COOKIE_NAME from app.utils.runtime_state import RuntimeState from app.utils.session_cache import SessionCache @@ -58,8 +59,6 @@ class ApplicationContext: session_cache: SessionCache | None -_COOKIE_NAME = "bangui_session" - # --------------------------------------------------------------------------- # Session validation cache # --------------------------------------------------------------------------- @@ -137,7 +136,9 @@ async def get_db( await db.close() -async def get_http_session(app_context: Annotated[ApplicationContext, Depends(get_app_context)]) -> aiohttp.ClientSession: +async def get_http_session( + app_context: Annotated[ApplicationContext, Depends(get_app_context)], +) -> aiohttp.ClientSession: """Provide the shared HTTP client session from application context. Args: @@ -209,14 +210,14 @@ async def get_auth_service() -> AuthService: """Provide the concrete authentication service implementation.""" from app.services import auth_service # noqa: PLC0415 - return cast(AuthService, auth_service) + return cast("AuthService", auth_service) async def get_jail_service() -> JailService: """Provide the concrete jail service implementation.""" from app.services import jail_service # noqa: PLC0415 - return cast(JailService, jail_service) + return cast("JailService", jail_service) async def get_session_repo() -> SessionRepository: @@ -241,7 +242,9 @@ async def get_server_status(app_context: Annotated[ApplicationContext, Depends(g return app_context.server_status -async def get_pending_recovery(app_context: Annotated[ApplicationContext, Depends(get_app_context)]) -> PendingRecovery | None: +async def get_pending_recovery( + app_context: Annotated[ApplicationContext, Depends(get_app_context)], +) -> PendingRecovery | None: """Return the current pending recovery record from application context.""" return app_context.pending_recovery @@ -277,7 +280,7 @@ async def require_auth( HTTPException: 401 if no valid session token is found. """ - token: str | None = request.cookies.get(_COOKIE_NAME) + token: str | None = request.cookies.get(SESSION_COOKIE_NAME) if not token: auth_header: str = request.headers.get("Authorization", "") if auth_header.startswith("Bearer "): diff --git a/backend/app/routers/auth.py b/backend/app/routers/auth.py index 755fa3a..d2383f1 100644 --- a/backend/app/routers/auth.py +++ b/backend/app/routers/auth.py @@ -16,18 +16,17 @@ from app.dependencies import ( AuthServiceDep, DbDep, SessionCacheDep, - SettingsDep, SessionRepoDep, + SettingsDep, ) from app.models.auth import LoginRequest, LoginResponse, LogoutResponse from app.services.auth_service import sign_session_token +from app.utils.constants import SESSION_COOKIE_NAME log: structlog.stdlib.BoundLogger = structlog.get_logger() router = APIRouter(prefix="/api/auth", tags=["auth"]) -_COOKIE_NAME = "bangui_session" - @router.post( "/login", @@ -77,7 +76,7 @@ async def login( settings.session_secret, ) response.set_cookie( - key=_COOKIE_NAME, + key=SESSION_COOKIE_NAME, value=signed_token, httponly=settings.session_cookie_httponly, samesite=settings.session_cookie_samesite, @@ -127,7 +126,7 @@ async def logout( if raw_token: session_cache.invalidate(raw_token) session_cache.invalidate(token) - response.delete_cookie(key=_COOKIE_NAME) + response.delete_cookie(key=SESSION_COOKIE_NAME) return LogoutResponse() @@ -145,7 +144,7 @@ def _extract_token(request: Request) -> str | None: Returns: The token string, or ``None`` if absent. """ - token: str | None = request.cookies.get(_COOKIE_NAME) + token: str | None = request.cookies.get(SESSION_COOKIE_NAME) if token: return token auth_header: str = request.headers.get("Authorization", "") diff --git a/backend/app/routers/blocklist.py b/backend/app/routers/blocklist.py index ad6f13c..8fb53a2 100644 --- a/backend/app/routers/blocklist.py +++ b/backend/app/routers/blocklist.py @@ -46,7 +46,7 @@ from app.models.blocklist import ( ScheduleConfig, ScheduleInfo, ) -from app.services import blocklist_service, geo_service +from app.services import blocklist_service, geo_service, jail_service from app.tasks import blocklist_import as blocklist_import_task router: APIRouter = APIRouter(prefix="/api/blocklists", tags=["Blocklists"]) @@ -141,6 +141,7 @@ async def run_import_now( socket_path, geo_is_cached=geo_service.is_cached, geo_batch_lookup=geo_service.lookup_batch, + ban_ip=jail_service.ban_ip, ) diff --git a/backend/app/services/blocklist_service.py b/backend/app/services/blocklist_service.py index 07d6ca8..f95d47f 100644 --- a/backend/app/services/blocklist_service.py +++ b/backend/app/services/blocklist_service.py @@ -15,14 +15,13 @@ under the key ``"blocklist_schedule"``. from __future__ import annotations import asyncio -import importlib import json -from collections.abc import Awaitable from typing import TYPE_CHECKING import aiohttp import structlog +from app.exceptions import JailNotFoundError from app.models.blocklist import ( BlocklistSource, ImportLogEntry, @@ -33,12 +32,11 @@ from app.models.blocklist import ( ScheduleConfig, ScheduleInfo, ) -from app.exceptions import JailNotFoundError from app.repositories import blocklist_repo, import_log_repo, settings_repo from app.utils.ip_utils import is_valid_ip, is_valid_network if TYPE_CHECKING: - from collections.abc import Callable + from collections.abc import Awaitable, Callable import aiohttp import aiosqlite @@ -299,9 +297,10 @@ async def import_source( http_session: aiohttp.ClientSession, socket_path: str, db: aiosqlite.Connection, + *, + ban_ip: Callable[[str, str, str], Awaitable[None]], geo_is_cached: Callable[[str], bool] | None = None, geo_batch_lookup: GeoBatchLookup | None = None, - ban_ip: Callable[[str, str, str], Awaitable[None]] | None = None, ) -> ImportSourceResult: """Download and apply bans from a single blocklist source. @@ -360,14 +359,7 @@ async def import_source( ban_error: str | None = None imported_ips: list[str] = [] - if ban_ip is None: - try: - jail_svc = importlib.import_module("app.services.jail_service") - ban_ip_fn = jail_svc.ban_ip - except (ModuleNotFoundError, AttributeError) as exc: - raise ValueError("ban_ip callback is required") from exc - else: - ban_ip_fn = ban_ip + ban_ip_fn = ban_ip for line in content.splitlines(): stripped = line.strip() @@ -450,9 +442,10 @@ async def import_all( db: aiosqlite.Connection, http_session: aiohttp.ClientSession, socket_path: str, + *, + ban_ip: Callable[[str, str, str], Awaitable[None]], geo_is_cached: Callable[[str], bool] | None = None, geo_batch_lookup: GeoBatchLookup | None = None, - ban_ip: Callable[[str, str, str], Awaitable[None]] | None = None, ) -> ImportRunResult: """Import all enabled blocklist sources. diff --git a/backend/app/tasks/blocklist_import.py b/backend/app/tasks/blocklist_import.py index f8b5c5e..9f57c45 100644 --- a/backend/app/tasks/blocklist_import.py +++ b/backend/app/tasks/blocklist_import.py @@ -19,29 +19,28 @@ import structlog from app.db import open_db from app.models.blocklist import ScheduleFrequency -from app.services import blocklist_service +from app.services import blocklist_service, jail_service from app.utils.runtime_state import get_effective_settings if TYPE_CHECKING: import aiosqlite from aiohttp import ClientSession - from app.config import Settings - -if TYPE_CHECKING: from fastapi import FastAPI + from app.config import Settings + log: structlog.stdlib.BoundLogger = structlog.get_logger() #: Stable APScheduler job id so the job can be replaced without duplicates. JOB_ID: str = "blocklist_import" -async def _get_db(settings: "Settings") -> tuple[aiosqlite.Connection, bool]: +async def _get_db(settings: Settings) -> tuple[aiosqlite.Connection, bool]: db = await open_db(settings.database_path) return db, True -async def _run_import_with_resources(settings: "Settings", http_session: "ClientSession") -> None: +async def _run_import_with_resources(settings: Settings, http_session: ClientSession) -> None: """APScheduler callback that imports all enabled blocklist sources. Args: @@ -57,6 +56,7 @@ async def _run_import_with_resources(settings: "Settings", http_session: "Client db, http_session, socket_path, + ban_ip=jail_service.ban_ip, ) log.info( "blocklist_import_finished", diff --git a/backend/app/utils/constants.py b/backend/app/utils/constants.py index fa3254b..f2f371f 100644 --- a/backend/app/utils/constants.py +++ b/backend/app/utils/constants.py @@ -39,6 +39,9 @@ SESSION_TOKEN_BYTES: Final[int] = 32 SESSION_TOKEN_SIGNATURE_SEPARATOR: Final[str] = "." """Separator used to append a signature to a signed session token.""" +SESSION_COOKIE_NAME: Final[str] = "bangui_session" +"""Name of the session cookie used by the browser SPA.""" + # --------------------------------------------------------------------------- # Time-range presets (used by dashboard and history endpoints) # --------------------------------------------------------------------------- diff --git a/backend/tests/test_services/test_blocklist_service.py b/backend/tests/test_services/test_blocklist_service.py index 0866163..04eedac 100644 --- a/backend/tests/test_services/test_blocklist_service.py +++ b/backend/tests/test_services/test_blocklist_service.py @@ -170,11 +170,17 @@ class TestImport: source = await blocklist_service.create_source(db, "Import Test", "https://t.test/") + from app.services import jail_service + with patch( "app.services.jail_service.ban_ip", new_callable=AsyncMock ) as mock_ban: result = await blocklist_service.import_source( - source, session, "/tmp/fake.sock", db + source, + session, + "/tmp/fake.sock", + db, + ban_ip=jail_service.ban_ip, ) assert result.ips_imported == 2 @@ -188,9 +194,15 @@ class TestImport: session = _make_session(content) source = await blocklist_service.create_source(db, "CIDR Test", "https://c.test/") + from app.services import jail_service + with patch("app.services.jail_service.ban_ip", new_callable=AsyncMock): result = await blocklist_service.import_source( - source, session, "/tmp/fake.sock", db + source, + session, + "/tmp/fake.sock", + db, + ban_ip=jail_service.ban_ip, ) assert result.ips_imported == 1 @@ -201,8 +213,14 @@ class TestImport: session = _make_session("", status=503) source = await blocklist_service.create_source(db, "Err Source", "https://err.test/") + from app.services import jail_service + result = await blocklist_service.import_source( - source, session, "/tmp/fake.sock", db + source, + session, + "/tmp/fake.sock", + db, + ban_ip=jail_service.ban_ip, ) assert result.ips_imported == 0 @@ -224,7 +242,7 @@ class TestImport: call_count += 1 raise JailNotFoundError(jail) - with patch("app.services.jail_service.ban_ip", side_effect=_raise_jail_not_found) as mocked_ban_ip: + with patch("app.services.jail_service.ban_ip", side_effect=_raise_jail_not_found): from app.services import jail_service result = await blocklist_service.import_source( @@ -349,6 +367,8 @@ class TestGeoPrewarmCacheFilter: def _mock_is_cached(ip: str) -> bool: return ip == "1.2.3.4" + from app.services import jail_service + mock_batch = AsyncMock(return_value={}) with patch("app.services.jail_service.ban_ip", new_callable=AsyncMock): result = await blocklist_service.import_source( @@ -356,6 +376,7 @@ class TestGeoPrewarmCacheFilter: session, "/tmp/fake.sock", db, + ban_ip=jail_service.ban_ip, geo_is_cached=_mock_is_cached, geo_batch_lookup=mock_batch, )