Fix blocklist service injection and centralize session cookie name

This commit is contained in:
2026-04-14 09:21:38 +02:00
parent 5a9d226cca
commit a564830abb
8 changed files with 62 additions and 38 deletions

View File

@@ -337,6 +337,8 @@ An import statement inside a loop is an unconventional pattern that confuses rea
### Task 12 — Replace blocklist_service importlib workaround with proper injection ### Task 12 — Replace blocklist_service importlib workaround with proper injection
**Status:** Completed
**Severity:** Medium **Severity:** Medium
**Where:** **Where:**
@@ -511,6 +513,8 @@ None.
### Task 18 — Move _COOKIE_NAME to constants.py ### Task 18 — Move _COOKIE_NAME to constants.py
**Status:** Completed
**Severity:** Low **Severity:** Low
**Where:** **Where:**

View File

@@ -23,6 +23,7 @@ from app.models.config import PendingRecovery
from app.models.server import ServerStatus from app.models.server import ServerStatus
from app.repositories.protocols import SessionRepository from app.repositories.protocols import SessionRepository
from app.services.protocols import AuthService, JailService from app.services.protocols import AuthService, JailService
from app.utils.constants import SESSION_COOKIE_NAME
from app.utils.runtime_state import RuntimeState from app.utils.runtime_state import RuntimeState
from app.utils.session_cache import SessionCache from app.utils.session_cache import SessionCache
@@ -58,8 +59,6 @@ class ApplicationContext:
session_cache: SessionCache | None session_cache: SessionCache | None
_COOKIE_NAME = "bangui_session"
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Session validation cache # Session validation cache
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -137,7 +136,9 @@ async def get_db(
await db.close() await db.close()
async def get_http_session(app_context: Annotated[ApplicationContext, Depends(get_app_context)]) -> aiohttp.ClientSession: async def get_http_session(
app_context: Annotated[ApplicationContext, Depends(get_app_context)],
) -> aiohttp.ClientSession:
"""Provide the shared HTTP client session from application context. """Provide the shared HTTP client session from application context.
Args: Args:
@@ -209,14 +210,14 @@ async def get_auth_service() -> AuthService:
"""Provide the concrete authentication service implementation.""" """Provide the concrete authentication service implementation."""
from app.services import auth_service # noqa: PLC0415 from app.services import auth_service # noqa: PLC0415
return cast(AuthService, auth_service) return cast("AuthService", auth_service)
async def get_jail_service() -> JailService: async def get_jail_service() -> JailService:
"""Provide the concrete jail service implementation.""" """Provide the concrete jail service implementation."""
from app.services import jail_service # noqa: PLC0415 from app.services import jail_service # noqa: PLC0415
return cast(JailService, jail_service) return cast("JailService", jail_service)
async def get_session_repo() -> SessionRepository: async def get_session_repo() -> SessionRepository:
@@ -241,7 +242,9 @@ async def get_server_status(app_context: Annotated[ApplicationContext, Depends(g
return app_context.server_status return app_context.server_status
async def get_pending_recovery(app_context: Annotated[ApplicationContext, Depends(get_app_context)]) -> PendingRecovery | None: async def get_pending_recovery(
app_context: Annotated[ApplicationContext, Depends(get_app_context)],
) -> PendingRecovery | None:
"""Return the current pending recovery record from application context.""" """Return the current pending recovery record from application context."""
return app_context.pending_recovery return app_context.pending_recovery
@@ -277,7 +280,7 @@ async def require_auth(
HTTPException: 401 if no valid session token is found. HTTPException: 401 if no valid session token is found.
""" """
token: str | None = request.cookies.get(_COOKIE_NAME) token: str | None = request.cookies.get(SESSION_COOKIE_NAME)
if not token: if not token:
auth_header: str = request.headers.get("Authorization", "") auth_header: str = request.headers.get("Authorization", "")
if auth_header.startswith("Bearer "): if auth_header.startswith("Bearer "):

View File

@@ -16,18 +16,17 @@ from app.dependencies import (
AuthServiceDep, AuthServiceDep,
DbDep, DbDep,
SessionCacheDep, SessionCacheDep,
SettingsDep,
SessionRepoDep, SessionRepoDep,
SettingsDep,
) )
from app.models.auth import LoginRequest, LoginResponse, LogoutResponse from app.models.auth import LoginRequest, LoginResponse, LogoutResponse
from app.services.auth_service import sign_session_token from app.services.auth_service import sign_session_token
from app.utils.constants import SESSION_COOKIE_NAME
log: structlog.stdlib.BoundLogger = structlog.get_logger() log: structlog.stdlib.BoundLogger = structlog.get_logger()
router = APIRouter(prefix="/api/auth", tags=["auth"]) router = APIRouter(prefix="/api/auth", tags=["auth"])
_COOKIE_NAME = "bangui_session"
@router.post( @router.post(
"/login", "/login",
@@ -77,7 +76,7 @@ async def login(
settings.session_secret, settings.session_secret,
) )
response.set_cookie( response.set_cookie(
key=_COOKIE_NAME, key=SESSION_COOKIE_NAME,
value=signed_token, value=signed_token,
httponly=settings.session_cookie_httponly, httponly=settings.session_cookie_httponly,
samesite=settings.session_cookie_samesite, samesite=settings.session_cookie_samesite,
@@ -127,7 +126,7 @@ async def logout(
if raw_token: if raw_token:
session_cache.invalidate(raw_token) session_cache.invalidate(raw_token)
session_cache.invalidate(token) session_cache.invalidate(token)
response.delete_cookie(key=_COOKIE_NAME) response.delete_cookie(key=SESSION_COOKIE_NAME)
return LogoutResponse() return LogoutResponse()
@@ -145,7 +144,7 @@ def _extract_token(request: Request) -> str | None:
Returns: Returns:
The token string, or ``None`` if absent. The token string, or ``None`` if absent.
""" """
token: str | None = request.cookies.get(_COOKIE_NAME) token: str | None = request.cookies.get(SESSION_COOKIE_NAME)
if token: if token:
return token return token
auth_header: str = request.headers.get("Authorization", "") auth_header: str = request.headers.get("Authorization", "")

View File

@@ -46,7 +46,7 @@ from app.models.blocklist import (
ScheduleConfig, ScheduleConfig,
ScheduleInfo, ScheduleInfo,
) )
from app.services import blocklist_service, geo_service from app.services import blocklist_service, geo_service, jail_service
from app.tasks import blocklist_import as blocklist_import_task from app.tasks import blocklist_import as blocklist_import_task
router: APIRouter = APIRouter(prefix="/api/blocklists", tags=["Blocklists"]) router: APIRouter = APIRouter(prefix="/api/blocklists", tags=["Blocklists"])
@@ -141,6 +141,7 @@ async def run_import_now(
socket_path, socket_path,
geo_is_cached=geo_service.is_cached, geo_is_cached=geo_service.is_cached,
geo_batch_lookup=geo_service.lookup_batch, geo_batch_lookup=geo_service.lookup_batch,
ban_ip=jail_service.ban_ip,
) )

View File

@@ -15,14 +15,13 @@ under the key ``"blocklist_schedule"``.
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
import importlib
import json import json
from collections.abc import Awaitable
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
import aiohttp import aiohttp
import structlog import structlog
from app.exceptions import JailNotFoundError
from app.models.blocklist import ( from app.models.blocklist import (
BlocklistSource, BlocklistSource,
ImportLogEntry, ImportLogEntry,
@@ -33,12 +32,11 @@ from app.models.blocklist import (
ScheduleConfig, ScheduleConfig,
ScheduleInfo, ScheduleInfo,
) )
from app.exceptions import JailNotFoundError
from app.repositories import blocklist_repo, import_log_repo, settings_repo from app.repositories import blocklist_repo, import_log_repo, settings_repo
from app.utils.ip_utils import is_valid_ip, is_valid_network from app.utils.ip_utils import is_valid_ip, is_valid_network
if TYPE_CHECKING: if TYPE_CHECKING:
from collections.abc import Callable from collections.abc import Awaitable, Callable
import aiohttp import aiohttp
import aiosqlite import aiosqlite
@@ -299,9 +297,10 @@ async def import_source(
http_session: aiohttp.ClientSession, http_session: aiohttp.ClientSession,
socket_path: str, socket_path: str,
db: aiosqlite.Connection, db: aiosqlite.Connection,
*,
ban_ip: Callable[[str, str, str], Awaitable[None]],
geo_is_cached: Callable[[str], bool] | None = None, geo_is_cached: Callable[[str], bool] | None = None,
geo_batch_lookup: GeoBatchLookup | None = None, geo_batch_lookup: GeoBatchLookup | None = None,
ban_ip: Callable[[str, str, str], Awaitable[None]] | None = None,
) -> ImportSourceResult: ) -> ImportSourceResult:
"""Download and apply bans from a single blocklist source. """Download and apply bans from a single blocklist source.
@@ -360,14 +359,7 @@ async def import_source(
ban_error: str | None = None ban_error: str | None = None
imported_ips: list[str] = [] imported_ips: list[str] = []
if ban_ip is None: ban_ip_fn = ban_ip
try:
jail_svc = importlib.import_module("app.services.jail_service")
ban_ip_fn = jail_svc.ban_ip
except (ModuleNotFoundError, AttributeError) as exc:
raise ValueError("ban_ip callback is required") from exc
else:
ban_ip_fn = ban_ip
for line in content.splitlines(): for line in content.splitlines():
stripped = line.strip() stripped = line.strip()
@@ -450,9 +442,10 @@ async def import_all(
db: aiosqlite.Connection, db: aiosqlite.Connection,
http_session: aiohttp.ClientSession, http_session: aiohttp.ClientSession,
socket_path: str, socket_path: str,
*,
ban_ip: Callable[[str, str, str], Awaitable[None]],
geo_is_cached: Callable[[str], bool] | None = None, geo_is_cached: Callable[[str], bool] | None = None,
geo_batch_lookup: GeoBatchLookup | None = None, geo_batch_lookup: GeoBatchLookup | None = None,
ban_ip: Callable[[str, str, str], Awaitable[None]] | None = None,
) -> ImportRunResult: ) -> ImportRunResult:
"""Import all enabled blocklist sources. """Import all enabled blocklist sources.

View File

@@ -19,29 +19,28 @@ import structlog
from app.db import open_db from app.db import open_db
from app.models.blocklist import ScheduleFrequency from app.models.blocklist import ScheduleFrequency
from app.services import blocklist_service from app.services import blocklist_service, jail_service
from app.utils.runtime_state import get_effective_settings from app.utils.runtime_state import get_effective_settings
if TYPE_CHECKING: if TYPE_CHECKING:
import aiosqlite import aiosqlite
from aiohttp import ClientSession from aiohttp import ClientSession
from app.config import Settings
if TYPE_CHECKING:
from fastapi import FastAPI from fastapi import FastAPI
from app.config import Settings
log: structlog.stdlib.BoundLogger = structlog.get_logger() log: structlog.stdlib.BoundLogger = structlog.get_logger()
#: Stable APScheduler job id so the job can be replaced without duplicates. #: Stable APScheduler job id so the job can be replaced without duplicates.
JOB_ID: str = "blocklist_import" JOB_ID: str = "blocklist_import"
async def _get_db(settings: "Settings") -> tuple[aiosqlite.Connection, bool]: async def _get_db(settings: Settings) -> tuple[aiosqlite.Connection, bool]:
db = await open_db(settings.database_path) db = await open_db(settings.database_path)
return db, True return db, True
async def _run_import_with_resources(settings: "Settings", http_session: "ClientSession") -> None: async def _run_import_with_resources(settings: Settings, http_session: ClientSession) -> None:
"""APScheduler callback that imports all enabled blocklist sources. """APScheduler callback that imports all enabled blocklist sources.
Args: Args:
@@ -57,6 +56,7 @@ async def _run_import_with_resources(settings: "Settings", http_session: "Client
db, db,
http_session, http_session,
socket_path, socket_path,
ban_ip=jail_service.ban_ip,
) )
log.info( log.info(
"blocklist_import_finished", "blocklist_import_finished",

View File

@@ -39,6 +39,9 @@ SESSION_TOKEN_BYTES: Final[int] = 32
SESSION_TOKEN_SIGNATURE_SEPARATOR: Final[str] = "." SESSION_TOKEN_SIGNATURE_SEPARATOR: Final[str] = "."
"""Separator used to append a signature to a signed session token.""" """Separator used to append a signature to a signed session token."""
SESSION_COOKIE_NAME: Final[str] = "bangui_session"
"""Name of the session cookie used by the browser SPA."""
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Time-range presets (used by dashboard and history endpoints) # Time-range presets (used by dashboard and history endpoints)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------

View File

@@ -170,11 +170,17 @@ class TestImport:
source = await blocklist_service.create_source(db, "Import Test", "https://t.test/") source = await blocklist_service.create_source(db, "Import Test", "https://t.test/")
from app.services import jail_service
with patch( with patch(
"app.services.jail_service.ban_ip", new_callable=AsyncMock "app.services.jail_service.ban_ip", new_callable=AsyncMock
) as mock_ban: ) as mock_ban:
result = await blocklist_service.import_source( result = await blocklist_service.import_source(
source, session, "/tmp/fake.sock", db source,
session,
"/tmp/fake.sock",
db,
ban_ip=jail_service.ban_ip,
) )
assert result.ips_imported == 2 assert result.ips_imported == 2
@@ -188,9 +194,15 @@ class TestImport:
session = _make_session(content) session = _make_session(content)
source = await blocklist_service.create_source(db, "CIDR Test", "https://c.test/") source = await blocklist_service.create_source(db, "CIDR Test", "https://c.test/")
from app.services import jail_service
with patch("app.services.jail_service.ban_ip", new_callable=AsyncMock): with patch("app.services.jail_service.ban_ip", new_callable=AsyncMock):
result = await blocklist_service.import_source( result = await blocklist_service.import_source(
source, session, "/tmp/fake.sock", db source,
session,
"/tmp/fake.sock",
db,
ban_ip=jail_service.ban_ip,
) )
assert result.ips_imported == 1 assert result.ips_imported == 1
@@ -201,8 +213,14 @@ class TestImport:
session = _make_session("", status=503) session = _make_session("", status=503)
source = await blocklist_service.create_source(db, "Err Source", "https://err.test/") source = await blocklist_service.create_source(db, "Err Source", "https://err.test/")
from app.services import jail_service
result = await blocklist_service.import_source( result = await blocklist_service.import_source(
source, session, "/tmp/fake.sock", db source,
session,
"/tmp/fake.sock",
db,
ban_ip=jail_service.ban_ip,
) )
assert result.ips_imported == 0 assert result.ips_imported == 0
@@ -224,7 +242,7 @@ class TestImport:
call_count += 1 call_count += 1
raise JailNotFoundError(jail) raise JailNotFoundError(jail)
with patch("app.services.jail_service.ban_ip", side_effect=_raise_jail_not_found) as mocked_ban_ip: with patch("app.services.jail_service.ban_ip", side_effect=_raise_jail_not_found):
from app.services import jail_service from app.services import jail_service
result = await blocklist_service.import_source( result = await blocklist_service.import_source(
@@ -349,6 +367,8 @@ class TestGeoPrewarmCacheFilter:
def _mock_is_cached(ip: str) -> bool: def _mock_is_cached(ip: str) -> bool:
return ip == "1.2.3.4" return ip == "1.2.3.4"
from app.services import jail_service
mock_batch = AsyncMock(return_value={}) mock_batch = AsyncMock(return_value={})
with patch("app.services.jail_service.ban_ip", new_callable=AsyncMock): with patch("app.services.jail_service.ban_ip", new_callable=AsyncMock):
result = await blocklist_service.import_source( result = await blocklist_service.import_source(
@@ -356,6 +376,7 @@ class TestGeoPrewarmCacheFilter:
session, session,
"/tmp/fake.sock", "/tmp/fake.sock",
db, db,
ban_ip=jail_service.ban_ip,
geo_is_cached=_mock_is_cached, geo_is_cached=_mock_is_cached,
geo_batch_lookup=mock_batch, geo_batch_lookup=mock_batch,
) )