Fix blocklist service injection and centralize session cookie name

This commit is contained in:
2026-04-14 09:21:38 +02:00
parent 5a9d226cca
commit a564830abb
8 changed files with 62 additions and 38 deletions

View File

@@ -23,6 +23,7 @@ from app.models.config import PendingRecovery
from app.models.server import ServerStatus
from app.repositories.protocols import SessionRepository
from app.services.protocols import AuthService, JailService
from app.utils.constants import SESSION_COOKIE_NAME
from app.utils.runtime_state import RuntimeState
from app.utils.session_cache import SessionCache
@@ -58,8 +59,6 @@ class ApplicationContext:
session_cache: SessionCache | None
_COOKIE_NAME = "bangui_session"
# ---------------------------------------------------------------------------
# Session validation cache
# ---------------------------------------------------------------------------
@@ -137,7 +136,9 @@ async def get_db(
await db.close()
async def get_http_session(app_context: Annotated[ApplicationContext, Depends(get_app_context)]) -> aiohttp.ClientSession:
async def get_http_session(
app_context: Annotated[ApplicationContext, Depends(get_app_context)],
) -> aiohttp.ClientSession:
"""Provide the shared HTTP client session from application context.
Args:
@@ -209,14 +210,14 @@ async def get_auth_service() -> AuthService:
"""Provide the concrete authentication service implementation."""
from app.services import auth_service # noqa: PLC0415
return cast(AuthService, auth_service)
return cast("AuthService", auth_service)
async def get_jail_service() -> JailService:
"""Provide the concrete jail service implementation."""
from app.services import jail_service # noqa: PLC0415
return cast(JailService, jail_service)
return cast("JailService", jail_service)
async def get_session_repo() -> SessionRepository:
@@ -241,7 +242,9 @@ async def get_server_status(app_context: Annotated[ApplicationContext, Depends(g
return app_context.server_status
async def get_pending_recovery(app_context: Annotated[ApplicationContext, Depends(get_app_context)]) -> PendingRecovery | None:
async def get_pending_recovery(
app_context: Annotated[ApplicationContext, Depends(get_app_context)],
) -> PendingRecovery | None:
"""Return the current pending recovery record from application context."""
return app_context.pending_recovery
@@ -277,7 +280,7 @@ async def require_auth(
HTTPException: 401 if no valid session token is found.
"""
token: str | None = request.cookies.get(_COOKIE_NAME)
token: str | None = request.cookies.get(SESSION_COOKIE_NAME)
if not token:
auth_header: str = request.headers.get("Authorization", "")
if auth_header.startswith("Bearer "):

View File

@@ -16,18 +16,17 @@ from app.dependencies import (
AuthServiceDep,
DbDep,
SessionCacheDep,
SettingsDep,
SessionRepoDep,
SettingsDep,
)
from app.models.auth import LoginRequest, LoginResponse, LogoutResponse
from app.services.auth_service import sign_session_token
from app.utils.constants import SESSION_COOKIE_NAME
log: structlog.stdlib.BoundLogger = structlog.get_logger()
router = APIRouter(prefix="/api/auth", tags=["auth"])
_COOKIE_NAME = "bangui_session"
@router.post(
"/login",
@@ -77,7 +76,7 @@ async def login(
settings.session_secret,
)
response.set_cookie(
key=_COOKIE_NAME,
key=SESSION_COOKIE_NAME,
value=signed_token,
httponly=settings.session_cookie_httponly,
samesite=settings.session_cookie_samesite,
@@ -127,7 +126,7 @@ async def logout(
if raw_token:
session_cache.invalidate(raw_token)
session_cache.invalidate(token)
response.delete_cookie(key=_COOKIE_NAME)
response.delete_cookie(key=SESSION_COOKIE_NAME)
return LogoutResponse()
@@ -145,7 +144,7 @@ def _extract_token(request: Request) -> str | None:
Returns:
The token string, or ``None`` if absent.
"""
token: str | None = request.cookies.get(_COOKIE_NAME)
token: str | None = request.cookies.get(SESSION_COOKIE_NAME)
if token:
return token
auth_header: str = request.headers.get("Authorization", "")

View File

@@ -46,7 +46,7 @@ from app.models.blocklist import (
ScheduleConfig,
ScheduleInfo,
)
from app.services import blocklist_service, geo_service
from app.services import blocklist_service, geo_service, jail_service
from app.tasks import blocklist_import as blocklist_import_task
router: APIRouter = APIRouter(prefix="/api/blocklists", tags=["Blocklists"])
@@ -141,6 +141,7 @@ async def run_import_now(
socket_path,
geo_is_cached=geo_service.is_cached,
geo_batch_lookup=geo_service.lookup_batch,
ban_ip=jail_service.ban_ip,
)

View File

@@ -15,14 +15,13 @@ under the key ``"blocklist_schedule"``.
from __future__ import annotations
import asyncio
import importlib
import json
from collections.abc import Awaitable
from typing import TYPE_CHECKING
import aiohttp
import structlog
from app.exceptions import JailNotFoundError
from app.models.blocklist import (
BlocklistSource,
ImportLogEntry,
@@ -33,12 +32,11 @@ from app.models.blocklist import (
ScheduleConfig,
ScheduleInfo,
)
from app.exceptions import JailNotFoundError
from app.repositories import blocklist_repo, import_log_repo, settings_repo
from app.utils.ip_utils import is_valid_ip, is_valid_network
if TYPE_CHECKING:
from collections.abc import Callable
from collections.abc import Awaitable, Callable
import aiohttp
import aiosqlite
@@ -299,9 +297,10 @@ async def import_source(
http_session: aiohttp.ClientSession,
socket_path: str,
db: aiosqlite.Connection,
*,
ban_ip: Callable[[str, str, str], Awaitable[None]],
geo_is_cached: Callable[[str], bool] | None = None,
geo_batch_lookup: GeoBatchLookup | None = None,
ban_ip: Callable[[str, str, str], Awaitable[None]] | None = None,
) -> ImportSourceResult:
"""Download and apply bans from a single blocklist source.
@@ -360,14 +359,7 @@ async def import_source(
ban_error: str | None = None
imported_ips: list[str] = []
if ban_ip is None:
try:
jail_svc = importlib.import_module("app.services.jail_service")
ban_ip_fn = jail_svc.ban_ip
except (ModuleNotFoundError, AttributeError) as exc:
raise ValueError("ban_ip callback is required") from exc
else:
ban_ip_fn = ban_ip
ban_ip_fn = ban_ip
for line in content.splitlines():
stripped = line.strip()
@@ -450,9 +442,10 @@ async def import_all(
db: aiosqlite.Connection,
http_session: aiohttp.ClientSession,
socket_path: str,
*,
ban_ip: Callable[[str, str, str], Awaitable[None]],
geo_is_cached: Callable[[str], bool] | None = None,
geo_batch_lookup: GeoBatchLookup | None = None,
ban_ip: Callable[[str, str, str], Awaitable[None]] | None = None,
) -> ImportRunResult:
"""Import all enabled blocklist sources.

View File

@@ -19,29 +19,28 @@ import structlog
from app.db import open_db
from app.models.blocklist import ScheduleFrequency
from app.services import blocklist_service
from app.services import blocklist_service, jail_service
from app.utils.runtime_state import get_effective_settings
if TYPE_CHECKING:
import aiosqlite
from aiohttp import ClientSession
from app.config import Settings
if TYPE_CHECKING:
from fastapi import FastAPI
from app.config import Settings
log: structlog.stdlib.BoundLogger = structlog.get_logger()
#: Stable APScheduler job id so the job can be replaced without duplicates.
JOB_ID: str = "blocklist_import"
async def _get_db(settings: "Settings") -> tuple[aiosqlite.Connection, bool]:
async def _get_db(settings: Settings) -> tuple[aiosqlite.Connection, bool]:
db = await open_db(settings.database_path)
return db, True
async def _run_import_with_resources(settings: "Settings", http_session: "ClientSession") -> None:
async def _run_import_with_resources(settings: Settings, http_session: ClientSession) -> None:
"""APScheduler callback that imports all enabled blocklist sources.
Args:
@@ -57,6 +56,7 @@ async def _run_import_with_resources(settings: "Settings", http_session: "Client
db,
http_session,
socket_path,
ban_ip=jail_service.ban_ip,
)
log.info(
"blocklist_import_finished",

View File

@@ -39,6 +39,9 @@ SESSION_TOKEN_BYTES: Final[int] = 32
SESSION_TOKEN_SIGNATURE_SEPARATOR: Final[str] = "."
"""Separator used to append a signature to a signed session token."""
SESSION_COOKIE_NAME: Final[str] = "bangui_session"
"""Name of the session cookie used by the browser SPA."""
# ---------------------------------------------------------------------------
# Time-range presets (used by dashboard and history endpoints)
# ---------------------------------------------------------------------------