Fix blocklist service injection and centralize session cookie name
This commit is contained in:
@@ -337,6 +337,8 @@ An import statement inside a loop is an unconventional pattern that confuses rea
|
||||
|
||||
### Task 12 — Replace blocklist_service importlib workaround with proper injection
|
||||
|
||||
**Status:** Completed
|
||||
|
||||
**Severity:** Medium
|
||||
|
||||
**Where:**
|
||||
@@ -511,6 +513,8 @@ None.
|
||||
|
||||
### Task 18 — Move _COOKIE_NAME to constants.py
|
||||
|
||||
**Status:** Completed
|
||||
|
||||
**Severity:** Low
|
||||
|
||||
**Where:**
|
||||
|
||||
@@ -23,6 +23,7 @@ from app.models.config import PendingRecovery
|
||||
from app.models.server import ServerStatus
|
||||
from app.repositories.protocols import SessionRepository
|
||||
from app.services.protocols import AuthService, JailService
|
||||
from app.utils.constants import SESSION_COOKIE_NAME
|
||||
from app.utils.runtime_state import RuntimeState
|
||||
from app.utils.session_cache import SessionCache
|
||||
|
||||
@@ -58,8 +59,6 @@ class ApplicationContext:
|
||||
session_cache: SessionCache | None
|
||||
|
||||
|
||||
_COOKIE_NAME = "bangui_session"
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Session validation cache
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -137,7 +136,9 @@ async def get_db(
|
||||
await db.close()
|
||||
|
||||
|
||||
async def get_http_session(app_context: Annotated[ApplicationContext, Depends(get_app_context)]) -> aiohttp.ClientSession:
|
||||
async def get_http_session(
|
||||
app_context: Annotated[ApplicationContext, Depends(get_app_context)],
|
||||
) -> aiohttp.ClientSession:
|
||||
"""Provide the shared HTTP client session from application context.
|
||||
|
||||
Args:
|
||||
@@ -209,14 +210,14 @@ async def get_auth_service() -> AuthService:
|
||||
"""Provide the concrete authentication service implementation."""
|
||||
from app.services import auth_service # noqa: PLC0415
|
||||
|
||||
return cast(AuthService, auth_service)
|
||||
return cast("AuthService", auth_service)
|
||||
|
||||
|
||||
async def get_jail_service() -> JailService:
|
||||
"""Provide the concrete jail service implementation."""
|
||||
from app.services import jail_service # noqa: PLC0415
|
||||
|
||||
return cast(JailService, jail_service)
|
||||
return cast("JailService", jail_service)
|
||||
|
||||
|
||||
async def get_session_repo() -> SessionRepository:
|
||||
@@ -241,7 +242,9 @@ async def get_server_status(app_context: Annotated[ApplicationContext, Depends(g
|
||||
return app_context.server_status
|
||||
|
||||
|
||||
async def get_pending_recovery(app_context: Annotated[ApplicationContext, Depends(get_app_context)]) -> PendingRecovery | None:
|
||||
async def get_pending_recovery(
|
||||
app_context: Annotated[ApplicationContext, Depends(get_app_context)],
|
||||
) -> PendingRecovery | None:
|
||||
"""Return the current pending recovery record from application context."""
|
||||
return app_context.pending_recovery
|
||||
|
||||
@@ -277,7 +280,7 @@ async def require_auth(
|
||||
HTTPException: 401 if no valid session token is found.
|
||||
"""
|
||||
|
||||
token: str | None = request.cookies.get(_COOKIE_NAME)
|
||||
token: str | None = request.cookies.get(SESSION_COOKIE_NAME)
|
||||
if not token:
|
||||
auth_header: str = request.headers.get("Authorization", "")
|
||||
if auth_header.startswith("Bearer "):
|
||||
|
||||
@@ -16,18 +16,17 @@ from app.dependencies import (
|
||||
AuthServiceDep,
|
||||
DbDep,
|
||||
SessionCacheDep,
|
||||
SettingsDep,
|
||||
SessionRepoDep,
|
||||
SettingsDep,
|
||||
)
|
||||
from app.models.auth import LoginRequest, LoginResponse, LogoutResponse
|
||||
from app.services.auth_service import sign_session_token
|
||||
from app.utils.constants import SESSION_COOKIE_NAME
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
|
||||
router = APIRouter(prefix="/api/auth", tags=["auth"])
|
||||
|
||||
_COOKIE_NAME = "bangui_session"
|
||||
|
||||
|
||||
@router.post(
|
||||
"/login",
|
||||
@@ -77,7 +76,7 @@ async def login(
|
||||
settings.session_secret,
|
||||
)
|
||||
response.set_cookie(
|
||||
key=_COOKIE_NAME,
|
||||
key=SESSION_COOKIE_NAME,
|
||||
value=signed_token,
|
||||
httponly=settings.session_cookie_httponly,
|
||||
samesite=settings.session_cookie_samesite,
|
||||
@@ -127,7 +126,7 @@ async def logout(
|
||||
if raw_token:
|
||||
session_cache.invalidate(raw_token)
|
||||
session_cache.invalidate(token)
|
||||
response.delete_cookie(key=_COOKIE_NAME)
|
||||
response.delete_cookie(key=SESSION_COOKIE_NAME)
|
||||
return LogoutResponse()
|
||||
|
||||
|
||||
@@ -145,7 +144,7 @@ def _extract_token(request: Request) -> str | None:
|
||||
Returns:
|
||||
The token string, or ``None`` if absent.
|
||||
"""
|
||||
token: str | None = request.cookies.get(_COOKIE_NAME)
|
||||
token: str | None = request.cookies.get(SESSION_COOKIE_NAME)
|
||||
if token:
|
||||
return token
|
||||
auth_header: str = request.headers.get("Authorization", "")
|
||||
|
||||
@@ -46,7 +46,7 @@ from app.models.blocklist import (
|
||||
ScheduleConfig,
|
||||
ScheduleInfo,
|
||||
)
|
||||
from app.services import blocklist_service, geo_service
|
||||
from app.services import blocklist_service, geo_service, jail_service
|
||||
from app.tasks import blocklist_import as blocklist_import_task
|
||||
|
||||
router: APIRouter = APIRouter(prefix="/api/blocklists", tags=["Blocklists"])
|
||||
@@ -141,6 +141,7 @@ async def run_import_now(
|
||||
socket_path,
|
||||
geo_is_cached=geo_service.is_cached,
|
||||
geo_batch_lookup=geo_service.lookup_batch,
|
||||
ban_ip=jail_service.ban_ip,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -15,14 +15,13 @@ under the key ``"blocklist_schedule"``.
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import importlib
|
||||
import json
|
||||
from collections.abc import Awaitable
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import aiohttp
|
||||
import structlog
|
||||
|
||||
from app.exceptions import JailNotFoundError
|
||||
from app.models.blocklist import (
|
||||
BlocklistSource,
|
||||
ImportLogEntry,
|
||||
@@ -33,12 +32,11 @@ from app.models.blocklist import (
|
||||
ScheduleConfig,
|
||||
ScheduleInfo,
|
||||
)
|
||||
from app.exceptions import JailNotFoundError
|
||||
from app.repositories import blocklist_repo, import_log_repo, settings_repo
|
||||
from app.utils.ip_utils import is_valid_ip, is_valid_network
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Awaitable, Callable
|
||||
|
||||
import aiohttp
|
||||
import aiosqlite
|
||||
@@ -299,9 +297,10 @@ async def import_source(
|
||||
http_session: aiohttp.ClientSession,
|
||||
socket_path: str,
|
||||
db: aiosqlite.Connection,
|
||||
*,
|
||||
ban_ip: Callable[[str, str, str], Awaitable[None]],
|
||||
geo_is_cached: Callable[[str], bool] | None = None,
|
||||
geo_batch_lookup: GeoBatchLookup | None = None,
|
||||
ban_ip: Callable[[str, str, str], Awaitable[None]] | None = None,
|
||||
) -> ImportSourceResult:
|
||||
"""Download and apply bans from a single blocklist source.
|
||||
|
||||
@@ -360,14 +359,7 @@ async def import_source(
|
||||
ban_error: str | None = None
|
||||
imported_ips: list[str] = []
|
||||
|
||||
if ban_ip is None:
|
||||
try:
|
||||
jail_svc = importlib.import_module("app.services.jail_service")
|
||||
ban_ip_fn = jail_svc.ban_ip
|
||||
except (ModuleNotFoundError, AttributeError) as exc:
|
||||
raise ValueError("ban_ip callback is required") from exc
|
||||
else:
|
||||
ban_ip_fn = ban_ip
|
||||
ban_ip_fn = ban_ip
|
||||
|
||||
for line in content.splitlines():
|
||||
stripped = line.strip()
|
||||
@@ -450,9 +442,10 @@ async def import_all(
|
||||
db: aiosqlite.Connection,
|
||||
http_session: aiohttp.ClientSession,
|
||||
socket_path: str,
|
||||
*,
|
||||
ban_ip: Callable[[str, str, str], Awaitable[None]],
|
||||
geo_is_cached: Callable[[str], bool] | None = None,
|
||||
geo_batch_lookup: GeoBatchLookup | None = None,
|
||||
ban_ip: Callable[[str, str, str], Awaitable[None]] | None = None,
|
||||
) -> ImportRunResult:
|
||||
"""Import all enabled blocklist sources.
|
||||
|
||||
|
||||
@@ -19,29 +19,28 @@ import structlog
|
||||
|
||||
from app.db import open_db
|
||||
from app.models.blocklist import ScheduleFrequency
|
||||
from app.services import blocklist_service
|
||||
from app.services import blocklist_service, jail_service
|
||||
from app.utils.runtime_state import get_effective_settings
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import aiosqlite
|
||||
from aiohttp import ClientSession
|
||||
from app.config import Settings
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from fastapi import FastAPI
|
||||
|
||||
from app.config import Settings
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
|
||||
#: Stable APScheduler job id so the job can be replaced without duplicates.
|
||||
JOB_ID: str = "blocklist_import"
|
||||
|
||||
|
||||
async def _get_db(settings: "Settings") -> tuple[aiosqlite.Connection, bool]:
|
||||
async def _get_db(settings: Settings) -> tuple[aiosqlite.Connection, bool]:
|
||||
db = await open_db(settings.database_path)
|
||||
return db, True
|
||||
|
||||
|
||||
async def _run_import_with_resources(settings: "Settings", http_session: "ClientSession") -> None:
|
||||
async def _run_import_with_resources(settings: Settings, http_session: ClientSession) -> None:
|
||||
"""APScheduler callback that imports all enabled blocklist sources.
|
||||
|
||||
Args:
|
||||
@@ -57,6 +56,7 @@ async def _run_import_with_resources(settings: "Settings", http_session: "Client
|
||||
db,
|
||||
http_session,
|
||||
socket_path,
|
||||
ban_ip=jail_service.ban_ip,
|
||||
)
|
||||
log.info(
|
||||
"blocklist_import_finished",
|
||||
|
||||
@@ -39,6 +39,9 @@ SESSION_TOKEN_BYTES: Final[int] = 32
|
||||
SESSION_TOKEN_SIGNATURE_SEPARATOR: Final[str] = "."
|
||||
"""Separator used to append a signature to a signed session token."""
|
||||
|
||||
SESSION_COOKIE_NAME: Final[str] = "bangui_session"
|
||||
"""Name of the session cookie used by the browser SPA."""
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Time-range presets (used by dashboard and history endpoints)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -170,11 +170,17 @@ class TestImport:
|
||||
|
||||
source = await blocklist_service.create_source(db, "Import Test", "https://t.test/")
|
||||
|
||||
from app.services import jail_service
|
||||
|
||||
with patch(
|
||||
"app.services.jail_service.ban_ip", new_callable=AsyncMock
|
||||
) as mock_ban:
|
||||
result = await blocklist_service.import_source(
|
||||
source, session, "/tmp/fake.sock", db
|
||||
source,
|
||||
session,
|
||||
"/tmp/fake.sock",
|
||||
db,
|
||||
ban_ip=jail_service.ban_ip,
|
||||
)
|
||||
|
||||
assert result.ips_imported == 2
|
||||
@@ -188,9 +194,15 @@ class TestImport:
|
||||
session = _make_session(content)
|
||||
source = await blocklist_service.create_source(db, "CIDR Test", "https://c.test/")
|
||||
|
||||
from app.services import jail_service
|
||||
|
||||
with patch("app.services.jail_service.ban_ip", new_callable=AsyncMock):
|
||||
result = await blocklist_service.import_source(
|
||||
source, session, "/tmp/fake.sock", db
|
||||
source,
|
||||
session,
|
||||
"/tmp/fake.sock",
|
||||
db,
|
||||
ban_ip=jail_service.ban_ip,
|
||||
)
|
||||
|
||||
assert result.ips_imported == 1
|
||||
@@ -201,8 +213,14 @@ class TestImport:
|
||||
session = _make_session("", status=503)
|
||||
source = await blocklist_service.create_source(db, "Err Source", "https://err.test/")
|
||||
|
||||
from app.services import jail_service
|
||||
|
||||
result = await blocklist_service.import_source(
|
||||
source, session, "/tmp/fake.sock", db
|
||||
source,
|
||||
session,
|
||||
"/tmp/fake.sock",
|
||||
db,
|
||||
ban_ip=jail_service.ban_ip,
|
||||
)
|
||||
|
||||
assert result.ips_imported == 0
|
||||
@@ -224,7 +242,7 @@ class TestImport:
|
||||
call_count += 1
|
||||
raise JailNotFoundError(jail)
|
||||
|
||||
with patch("app.services.jail_service.ban_ip", side_effect=_raise_jail_not_found) as mocked_ban_ip:
|
||||
with patch("app.services.jail_service.ban_ip", side_effect=_raise_jail_not_found):
|
||||
from app.services import jail_service
|
||||
|
||||
result = await blocklist_service.import_source(
|
||||
@@ -349,6 +367,8 @@ class TestGeoPrewarmCacheFilter:
|
||||
def _mock_is_cached(ip: str) -> bool:
|
||||
return ip == "1.2.3.4"
|
||||
|
||||
from app.services import jail_service
|
||||
|
||||
mock_batch = AsyncMock(return_value={})
|
||||
with patch("app.services.jail_service.ban_ip", new_callable=AsyncMock):
|
||||
result = await blocklist_service.import_source(
|
||||
@@ -356,6 +376,7 @@ class TestGeoPrewarmCacheFilter:
|
||||
session,
|
||||
"/tmp/fake.sock",
|
||||
db,
|
||||
ban_ip=jail_service.ban_ip,
|
||||
geo_is_cached=_mock_is_cached,
|
||||
geo_batch_lookup=mock_batch,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user