Fix blocklist service injection and centralize session cookie name
This commit is contained in:
@@ -337,6 +337,8 @@ An import statement inside a loop is an unconventional pattern that confuses rea
|
|||||||
|
|
||||||
### Task 12 — Replace blocklist_service importlib workaround with proper injection
|
### Task 12 — Replace blocklist_service importlib workaround with proper injection
|
||||||
|
|
||||||
|
**Status:** Completed
|
||||||
|
|
||||||
**Severity:** Medium
|
**Severity:** Medium
|
||||||
|
|
||||||
**Where:**
|
**Where:**
|
||||||
@@ -511,6 +513,8 @@ None.
|
|||||||
|
|
||||||
### Task 18 — Move _COOKIE_NAME to constants.py
|
### Task 18 — Move _COOKIE_NAME to constants.py
|
||||||
|
|
||||||
|
**Status:** Completed
|
||||||
|
|
||||||
**Severity:** Low
|
**Severity:** Low
|
||||||
|
|
||||||
**Where:**
|
**Where:**
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ from app.models.config import PendingRecovery
|
|||||||
from app.models.server import ServerStatus
|
from app.models.server import ServerStatus
|
||||||
from app.repositories.protocols import SessionRepository
|
from app.repositories.protocols import SessionRepository
|
||||||
from app.services.protocols import AuthService, JailService
|
from app.services.protocols import AuthService, JailService
|
||||||
|
from app.utils.constants import SESSION_COOKIE_NAME
|
||||||
from app.utils.runtime_state import RuntimeState
|
from app.utils.runtime_state import RuntimeState
|
||||||
from app.utils.session_cache import SessionCache
|
from app.utils.session_cache import SessionCache
|
||||||
|
|
||||||
@@ -58,8 +59,6 @@ class ApplicationContext:
|
|||||||
session_cache: SessionCache | None
|
session_cache: SessionCache | None
|
||||||
|
|
||||||
|
|
||||||
_COOKIE_NAME = "bangui_session"
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Session validation cache
|
# Session validation cache
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -137,7 +136,9 @@ async def get_db(
|
|||||||
await db.close()
|
await db.close()
|
||||||
|
|
||||||
|
|
||||||
async def get_http_session(app_context: Annotated[ApplicationContext, Depends(get_app_context)]) -> aiohttp.ClientSession:
|
async def get_http_session(
|
||||||
|
app_context: Annotated[ApplicationContext, Depends(get_app_context)],
|
||||||
|
) -> aiohttp.ClientSession:
|
||||||
"""Provide the shared HTTP client session from application context.
|
"""Provide the shared HTTP client session from application context.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -209,14 +210,14 @@ async def get_auth_service() -> AuthService:
|
|||||||
"""Provide the concrete authentication service implementation."""
|
"""Provide the concrete authentication service implementation."""
|
||||||
from app.services import auth_service # noqa: PLC0415
|
from app.services import auth_service # noqa: PLC0415
|
||||||
|
|
||||||
return cast(AuthService, auth_service)
|
return cast("AuthService", auth_service)
|
||||||
|
|
||||||
|
|
||||||
async def get_jail_service() -> JailService:
|
async def get_jail_service() -> JailService:
|
||||||
"""Provide the concrete jail service implementation."""
|
"""Provide the concrete jail service implementation."""
|
||||||
from app.services import jail_service # noqa: PLC0415
|
from app.services import jail_service # noqa: PLC0415
|
||||||
|
|
||||||
return cast(JailService, jail_service)
|
return cast("JailService", jail_service)
|
||||||
|
|
||||||
|
|
||||||
async def get_session_repo() -> SessionRepository:
|
async def get_session_repo() -> SessionRepository:
|
||||||
@@ -241,7 +242,9 @@ async def get_server_status(app_context: Annotated[ApplicationContext, Depends(g
|
|||||||
return app_context.server_status
|
return app_context.server_status
|
||||||
|
|
||||||
|
|
||||||
async def get_pending_recovery(app_context: Annotated[ApplicationContext, Depends(get_app_context)]) -> PendingRecovery | None:
|
async def get_pending_recovery(
|
||||||
|
app_context: Annotated[ApplicationContext, Depends(get_app_context)],
|
||||||
|
) -> PendingRecovery | None:
|
||||||
"""Return the current pending recovery record from application context."""
|
"""Return the current pending recovery record from application context."""
|
||||||
return app_context.pending_recovery
|
return app_context.pending_recovery
|
||||||
|
|
||||||
@@ -277,7 +280,7 @@ async def require_auth(
|
|||||||
HTTPException: 401 if no valid session token is found.
|
HTTPException: 401 if no valid session token is found.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
token: str | None = request.cookies.get(_COOKIE_NAME)
|
token: str | None = request.cookies.get(SESSION_COOKIE_NAME)
|
||||||
if not token:
|
if not token:
|
||||||
auth_header: str = request.headers.get("Authorization", "")
|
auth_header: str = request.headers.get("Authorization", "")
|
||||||
if auth_header.startswith("Bearer "):
|
if auth_header.startswith("Bearer "):
|
||||||
|
|||||||
@@ -16,18 +16,17 @@ from app.dependencies import (
|
|||||||
AuthServiceDep,
|
AuthServiceDep,
|
||||||
DbDep,
|
DbDep,
|
||||||
SessionCacheDep,
|
SessionCacheDep,
|
||||||
SettingsDep,
|
|
||||||
SessionRepoDep,
|
SessionRepoDep,
|
||||||
|
SettingsDep,
|
||||||
)
|
)
|
||||||
from app.models.auth import LoginRequest, LoginResponse, LogoutResponse
|
from app.models.auth import LoginRequest, LoginResponse, LogoutResponse
|
||||||
from app.services.auth_service import sign_session_token
|
from app.services.auth_service import sign_session_token
|
||||||
|
from app.utils.constants import SESSION_COOKIE_NAME
|
||||||
|
|
||||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||||
|
|
||||||
router = APIRouter(prefix="/api/auth", tags=["auth"])
|
router = APIRouter(prefix="/api/auth", tags=["auth"])
|
||||||
|
|
||||||
_COOKIE_NAME = "bangui_session"
|
|
||||||
|
|
||||||
|
|
||||||
@router.post(
|
@router.post(
|
||||||
"/login",
|
"/login",
|
||||||
@@ -77,7 +76,7 @@ async def login(
|
|||||||
settings.session_secret,
|
settings.session_secret,
|
||||||
)
|
)
|
||||||
response.set_cookie(
|
response.set_cookie(
|
||||||
key=_COOKIE_NAME,
|
key=SESSION_COOKIE_NAME,
|
||||||
value=signed_token,
|
value=signed_token,
|
||||||
httponly=settings.session_cookie_httponly,
|
httponly=settings.session_cookie_httponly,
|
||||||
samesite=settings.session_cookie_samesite,
|
samesite=settings.session_cookie_samesite,
|
||||||
@@ -127,7 +126,7 @@ async def logout(
|
|||||||
if raw_token:
|
if raw_token:
|
||||||
session_cache.invalidate(raw_token)
|
session_cache.invalidate(raw_token)
|
||||||
session_cache.invalidate(token)
|
session_cache.invalidate(token)
|
||||||
response.delete_cookie(key=_COOKIE_NAME)
|
response.delete_cookie(key=SESSION_COOKIE_NAME)
|
||||||
return LogoutResponse()
|
return LogoutResponse()
|
||||||
|
|
||||||
|
|
||||||
@@ -145,7 +144,7 @@ def _extract_token(request: Request) -> str | None:
|
|||||||
Returns:
|
Returns:
|
||||||
The token string, or ``None`` if absent.
|
The token string, or ``None`` if absent.
|
||||||
"""
|
"""
|
||||||
token: str | None = request.cookies.get(_COOKIE_NAME)
|
token: str | None = request.cookies.get(SESSION_COOKIE_NAME)
|
||||||
if token:
|
if token:
|
||||||
return token
|
return token
|
||||||
auth_header: str = request.headers.get("Authorization", "")
|
auth_header: str = request.headers.get("Authorization", "")
|
||||||
|
|||||||
@@ -46,7 +46,7 @@ from app.models.blocklist import (
|
|||||||
ScheduleConfig,
|
ScheduleConfig,
|
||||||
ScheduleInfo,
|
ScheduleInfo,
|
||||||
)
|
)
|
||||||
from app.services import blocklist_service, geo_service
|
from app.services import blocklist_service, geo_service, jail_service
|
||||||
from app.tasks import blocklist_import as blocklist_import_task
|
from app.tasks import blocklist_import as blocklist_import_task
|
||||||
|
|
||||||
router: APIRouter = APIRouter(prefix="/api/blocklists", tags=["Blocklists"])
|
router: APIRouter = APIRouter(prefix="/api/blocklists", tags=["Blocklists"])
|
||||||
@@ -141,6 +141,7 @@ async def run_import_now(
|
|||||||
socket_path,
|
socket_path,
|
||||||
geo_is_cached=geo_service.is_cached,
|
geo_is_cached=geo_service.is_cached,
|
||||||
geo_batch_lookup=geo_service.lookup_batch,
|
geo_batch_lookup=geo_service.lookup_batch,
|
||||||
|
ban_ip=jail_service.ban_ip,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -15,14 +15,13 @@ under the key ``"blocklist_schedule"``.
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import importlib
|
|
||||||
import json
|
import json
|
||||||
from collections.abc import Awaitable
|
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
import structlog
|
import structlog
|
||||||
|
|
||||||
|
from app.exceptions import JailNotFoundError
|
||||||
from app.models.blocklist import (
|
from app.models.blocklist import (
|
||||||
BlocklistSource,
|
BlocklistSource,
|
||||||
ImportLogEntry,
|
ImportLogEntry,
|
||||||
@@ -33,12 +32,11 @@ from app.models.blocklist import (
|
|||||||
ScheduleConfig,
|
ScheduleConfig,
|
||||||
ScheduleInfo,
|
ScheduleInfo,
|
||||||
)
|
)
|
||||||
from app.exceptions import JailNotFoundError
|
|
||||||
from app.repositories import blocklist_repo, import_log_repo, settings_repo
|
from app.repositories import blocklist_repo, import_log_repo, settings_repo
|
||||||
from app.utils.ip_utils import is_valid_ip, is_valid_network
|
from app.utils.ip_utils import is_valid_ip, is_valid_network
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from collections.abc import Callable
|
from collections.abc import Awaitable, Callable
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
import aiosqlite
|
import aiosqlite
|
||||||
@@ -299,9 +297,10 @@ async def import_source(
|
|||||||
http_session: aiohttp.ClientSession,
|
http_session: aiohttp.ClientSession,
|
||||||
socket_path: str,
|
socket_path: str,
|
||||||
db: aiosqlite.Connection,
|
db: aiosqlite.Connection,
|
||||||
|
*,
|
||||||
|
ban_ip: Callable[[str, str, str], Awaitable[None]],
|
||||||
geo_is_cached: Callable[[str], bool] | None = None,
|
geo_is_cached: Callable[[str], bool] | None = None,
|
||||||
geo_batch_lookup: GeoBatchLookup | None = None,
|
geo_batch_lookup: GeoBatchLookup | None = None,
|
||||||
ban_ip: Callable[[str, str, str], Awaitable[None]] | None = None,
|
|
||||||
) -> ImportSourceResult:
|
) -> ImportSourceResult:
|
||||||
"""Download and apply bans from a single blocklist source.
|
"""Download and apply bans from a single blocklist source.
|
||||||
|
|
||||||
@@ -360,14 +359,7 @@ async def import_source(
|
|||||||
ban_error: str | None = None
|
ban_error: str | None = None
|
||||||
imported_ips: list[str] = []
|
imported_ips: list[str] = []
|
||||||
|
|
||||||
if ban_ip is None:
|
ban_ip_fn = ban_ip
|
||||||
try:
|
|
||||||
jail_svc = importlib.import_module("app.services.jail_service")
|
|
||||||
ban_ip_fn = jail_svc.ban_ip
|
|
||||||
except (ModuleNotFoundError, AttributeError) as exc:
|
|
||||||
raise ValueError("ban_ip callback is required") from exc
|
|
||||||
else:
|
|
||||||
ban_ip_fn = ban_ip
|
|
||||||
|
|
||||||
for line in content.splitlines():
|
for line in content.splitlines():
|
||||||
stripped = line.strip()
|
stripped = line.strip()
|
||||||
@@ -450,9 +442,10 @@ async def import_all(
|
|||||||
db: aiosqlite.Connection,
|
db: aiosqlite.Connection,
|
||||||
http_session: aiohttp.ClientSession,
|
http_session: aiohttp.ClientSession,
|
||||||
socket_path: str,
|
socket_path: str,
|
||||||
|
*,
|
||||||
|
ban_ip: Callable[[str, str, str], Awaitable[None]],
|
||||||
geo_is_cached: Callable[[str], bool] | None = None,
|
geo_is_cached: Callable[[str], bool] | None = None,
|
||||||
geo_batch_lookup: GeoBatchLookup | None = None,
|
geo_batch_lookup: GeoBatchLookup | None = None,
|
||||||
ban_ip: Callable[[str, str, str], Awaitable[None]] | None = None,
|
|
||||||
) -> ImportRunResult:
|
) -> ImportRunResult:
|
||||||
"""Import all enabled blocklist sources.
|
"""Import all enabled blocklist sources.
|
||||||
|
|
||||||
|
|||||||
@@ -19,29 +19,28 @@ import structlog
|
|||||||
|
|
||||||
from app.db import open_db
|
from app.db import open_db
|
||||||
from app.models.blocklist import ScheduleFrequency
|
from app.models.blocklist import ScheduleFrequency
|
||||||
from app.services import blocklist_service
|
from app.services import blocklist_service, jail_service
|
||||||
from app.utils.runtime_state import get_effective_settings
|
from app.utils.runtime_state import get_effective_settings
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
import aiosqlite
|
import aiosqlite
|
||||||
from aiohttp import ClientSession
|
from aiohttp import ClientSession
|
||||||
from app.config import Settings
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
|
|
||||||
|
from app.config import Settings
|
||||||
|
|
||||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||||
|
|
||||||
#: Stable APScheduler job id so the job can be replaced without duplicates.
|
#: Stable APScheduler job id so the job can be replaced without duplicates.
|
||||||
JOB_ID: str = "blocklist_import"
|
JOB_ID: str = "blocklist_import"
|
||||||
|
|
||||||
|
|
||||||
async def _get_db(settings: "Settings") -> tuple[aiosqlite.Connection, bool]:
|
async def _get_db(settings: Settings) -> tuple[aiosqlite.Connection, bool]:
|
||||||
db = await open_db(settings.database_path)
|
db = await open_db(settings.database_path)
|
||||||
return db, True
|
return db, True
|
||||||
|
|
||||||
|
|
||||||
async def _run_import_with_resources(settings: "Settings", http_session: "ClientSession") -> None:
|
async def _run_import_with_resources(settings: Settings, http_session: ClientSession) -> None:
|
||||||
"""APScheduler callback that imports all enabled blocklist sources.
|
"""APScheduler callback that imports all enabled blocklist sources.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -57,6 +56,7 @@ async def _run_import_with_resources(settings: "Settings", http_session: "Client
|
|||||||
db,
|
db,
|
||||||
http_session,
|
http_session,
|
||||||
socket_path,
|
socket_path,
|
||||||
|
ban_ip=jail_service.ban_ip,
|
||||||
)
|
)
|
||||||
log.info(
|
log.info(
|
||||||
"blocklist_import_finished",
|
"blocklist_import_finished",
|
||||||
|
|||||||
@@ -39,6 +39,9 @@ SESSION_TOKEN_BYTES: Final[int] = 32
|
|||||||
SESSION_TOKEN_SIGNATURE_SEPARATOR: Final[str] = "."
|
SESSION_TOKEN_SIGNATURE_SEPARATOR: Final[str] = "."
|
||||||
"""Separator used to append a signature to a signed session token."""
|
"""Separator used to append a signature to a signed session token."""
|
||||||
|
|
||||||
|
SESSION_COOKIE_NAME: Final[str] = "bangui_session"
|
||||||
|
"""Name of the session cookie used by the browser SPA."""
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Time-range presets (used by dashboard and history endpoints)
|
# Time-range presets (used by dashboard and history endpoints)
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|||||||
@@ -170,11 +170,17 @@ class TestImport:
|
|||||||
|
|
||||||
source = await blocklist_service.create_source(db, "Import Test", "https://t.test/")
|
source = await blocklist_service.create_source(db, "Import Test", "https://t.test/")
|
||||||
|
|
||||||
|
from app.services import jail_service
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"app.services.jail_service.ban_ip", new_callable=AsyncMock
|
"app.services.jail_service.ban_ip", new_callable=AsyncMock
|
||||||
) as mock_ban:
|
) as mock_ban:
|
||||||
result = await blocklist_service.import_source(
|
result = await blocklist_service.import_source(
|
||||||
source, session, "/tmp/fake.sock", db
|
source,
|
||||||
|
session,
|
||||||
|
"/tmp/fake.sock",
|
||||||
|
db,
|
||||||
|
ban_ip=jail_service.ban_ip,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert result.ips_imported == 2
|
assert result.ips_imported == 2
|
||||||
@@ -188,9 +194,15 @@ class TestImport:
|
|||||||
session = _make_session(content)
|
session = _make_session(content)
|
||||||
source = await blocklist_service.create_source(db, "CIDR Test", "https://c.test/")
|
source = await blocklist_service.create_source(db, "CIDR Test", "https://c.test/")
|
||||||
|
|
||||||
|
from app.services import jail_service
|
||||||
|
|
||||||
with patch("app.services.jail_service.ban_ip", new_callable=AsyncMock):
|
with patch("app.services.jail_service.ban_ip", new_callable=AsyncMock):
|
||||||
result = await blocklist_service.import_source(
|
result = await blocklist_service.import_source(
|
||||||
source, session, "/tmp/fake.sock", db
|
source,
|
||||||
|
session,
|
||||||
|
"/tmp/fake.sock",
|
||||||
|
db,
|
||||||
|
ban_ip=jail_service.ban_ip,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert result.ips_imported == 1
|
assert result.ips_imported == 1
|
||||||
@@ -201,8 +213,14 @@ class TestImport:
|
|||||||
session = _make_session("", status=503)
|
session = _make_session("", status=503)
|
||||||
source = await blocklist_service.create_source(db, "Err Source", "https://err.test/")
|
source = await blocklist_service.create_source(db, "Err Source", "https://err.test/")
|
||||||
|
|
||||||
|
from app.services import jail_service
|
||||||
|
|
||||||
result = await blocklist_service.import_source(
|
result = await blocklist_service.import_source(
|
||||||
source, session, "/tmp/fake.sock", db
|
source,
|
||||||
|
session,
|
||||||
|
"/tmp/fake.sock",
|
||||||
|
db,
|
||||||
|
ban_ip=jail_service.ban_ip,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert result.ips_imported == 0
|
assert result.ips_imported == 0
|
||||||
@@ -224,7 +242,7 @@ class TestImport:
|
|||||||
call_count += 1
|
call_count += 1
|
||||||
raise JailNotFoundError(jail)
|
raise JailNotFoundError(jail)
|
||||||
|
|
||||||
with patch("app.services.jail_service.ban_ip", side_effect=_raise_jail_not_found) as mocked_ban_ip:
|
with patch("app.services.jail_service.ban_ip", side_effect=_raise_jail_not_found):
|
||||||
from app.services import jail_service
|
from app.services import jail_service
|
||||||
|
|
||||||
result = await blocklist_service.import_source(
|
result = await blocklist_service.import_source(
|
||||||
@@ -349,6 +367,8 @@ class TestGeoPrewarmCacheFilter:
|
|||||||
def _mock_is_cached(ip: str) -> bool:
|
def _mock_is_cached(ip: str) -> bool:
|
||||||
return ip == "1.2.3.4"
|
return ip == "1.2.3.4"
|
||||||
|
|
||||||
|
from app.services import jail_service
|
||||||
|
|
||||||
mock_batch = AsyncMock(return_value={})
|
mock_batch = AsyncMock(return_value={})
|
||||||
with patch("app.services.jail_service.ban_ip", new_callable=AsyncMock):
|
with patch("app.services.jail_service.ban_ip", new_callable=AsyncMock):
|
||||||
result = await blocklist_service.import_source(
|
result = await blocklist_service.import_source(
|
||||||
@@ -356,6 +376,7 @@ class TestGeoPrewarmCacheFilter:
|
|||||||
session,
|
session,
|
||||||
"/tmp/fake.sock",
|
"/tmp/fake.sock",
|
||||||
db,
|
db,
|
||||||
|
ban_ip=jail_service.ban_ip,
|
||||||
geo_is_cached=_mock_is_cached,
|
geo_is_cached=_mock_is_cached,
|
||||||
geo_batch_lookup=mock_batch,
|
geo_batch_lookup=mock_batch,
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user