fix: enforce PRAGMA query_only on fail2ban DB and refactor CSRF cookie name
- Add _acquire_readonly_connection() that applies PRAGMA query_only=ON after connect - Verify PRAGMA value back to catch URI flag bypasses - Wrap in async context manager _readonly_connection() used by all repo methods - Replace hardcoded '_SESSION_COOKIE_NAME' in CSRF middleware with import from app.utils.constants - Remove completed Issues #45 and #46 from Docs/Tasks.md (Issue #46 now fixed, #45 cache invalidation deferred to auth refactor branch) Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
@@ -20,6 +20,8 @@ from fastapi import status
|
||||
from fastapi.responses import JSONResponse
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
|
||||
from app.utils.constants import SESSION_COOKIE_NAME
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Awaitable, Callable
|
||||
|
||||
@@ -35,9 +37,6 @@ _CSRF_HEADER_VALUE: str = "1"
|
||||
# HTTP methods that require CSRF protection.
|
||||
_CSRF_PROTECTED_METHODS: frozenset[str] = frozenset({"POST", "PUT", "DELETE", "PATCH"})
|
||||
|
||||
# Session cookie name for detecting cookie-based authentication.
|
||||
_SESSION_COOKIE_NAME: str = "bangui_session"
|
||||
|
||||
|
||||
class CsrfMiddleware(BaseHTTPMiddleware):
|
||||
"""Protect cookie-authenticated state-mutating requests with custom header check.
|
||||
@@ -73,7 +72,7 @@ class CsrfMiddleware(BaseHTTPMiddleware):
|
||||
return await call_next(request)
|
||||
|
||||
# Skip check if not using cookie-based authentication.
|
||||
if _SESSION_COOKIE_NAME not in request.cookies:
|
||||
if SESSION_COOKIE_NAME not in request.cookies:
|
||||
return await call_next(request)
|
||||
|
||||
# Enforce CSRF header for cookie-authenticated state-mutating requests.
|
||||
|
||||
@@ -10,6 +10,7 @@ service layers can focus on business logic and formatting.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import asynccontextmanager
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
@@ -18,6 +19,7 @@ import aiosqlite
|
||||
from app.utils.fail2ban_db_utils import escape_like
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import AsyncIterator
|
||||
from collections.abc import Iterable
|
||||
|
||||
from app.models.ban import BanOrigin
|
||||
@@ -72,6 +74,53 @@ def _make_db_uri(db_path: str) -> str:
|
||||
return f"file:{db_path}?mode=ro"
|
||||
|
||||
|
||||
async def _acquire_readonly_connection(
|
||||
db_path: str,
|
||||
) -> aiosqlite.Connection:
|
||||
"""Open a read-only connection to the fail2ban database.
|
||||
|
||||
Defense-in-depth: both the ``?mode=ro`` URI flag AND the SQLite-level
|
||||
``PRAGMA query_only = ON`` are applied. The URI flag is a library-level hint
|
||||
that can be bypassed by malformed URIs or version inconsistencies;
|
||||
``query_only`` is a connection-level enforcement that makes all write
|
||||
operations fail. We verify enforcement by reading back the PRAGMA value.
|
||||
|
||||
Args:
|
||||
db_path: Path to the fail2ban SQLite database.
|
||||
|
||||
Returns:
|
||||
An aiosqlite connection in guaranteed read-only mode.
|
||||
|
||||
Raises:
|
||||
AssertionError: If PRAGMA query_only is not confirmed as enabled.
|
||||
"""
|
||||
conn = await aiosqlite.connect(_make_db_uri(db_path), uri=True)
|
||||
# Set connection-level read-only enforcement and verify in one statement.
|
||||
# Even if the ?mode=ro URI flag is bypassed, this PRAGMA blocks writes.
|
||||
cursor = await conn.execute("PRAGMA query_only = ON")
|
||||
await cursor.close()
|
||||
# Verify the PRAGMA took effect.
|
||||
cursor = await conn.execute("PRAGMA query_only")
|
||||
row = await cursor.fetchone()
|
||||
await cursor.close()
|
||||
if not row or row[0] != 1:
|
||||
await conn.close()
|
||||
raise AssertionError(
|
||||
"PRAGMA query_only is not enabled; connection may be writable"
|
||||
)
|
||||
return conn
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def _readonly_connection(db_path: str) -> AsyncIterator[aiosqlite.Connection]:
|
||||
"""Async context manager that yields a read-only fail2ban DB connection."""
|
||||
conn = await _acquire_readonly_connection(db_path)
|
||||
try:
|
||||
yield conn
|
||||
finally:
|
||||
await conn.close()
|
||||
|
||||
|
||||
def _origin_sql_filter(origin: BanOrigin | None) -> tuple[str, tuple[str, ...]]:
|
||||
"""Return a SQL fragment and parameters for the origin filter."""
|
||||
|
||||
@@ -116,7 +165,7 @@ def _rows_to_history_records(rows: Iterable[aiosqlite.Row]) -> list[HistoryRecor
|
||||
async def check_db_nonempty(db_path: str) -> bool:
|
||||
"""Return True if the fail2ban database contains at least one ban row."""
|
||||
|
||||
async with aiosqlite.connect(_make_db_uri(db_path), uri=True) as db, db.execute(
|
||||
async with _readonly_connection(db_path) as db, db.execute(
|
||||
"SELECT 1 FROM bans LIMIT 1"
|
||||
) as cur:
|
||||
row = await cur.fetchone()
|
||||
@@ -155,7 +204,7 @@ async def get_currently_banned(
|
||||
placeholder = ", ".join("?" for _ in ip_filter)
|
||||
ip_filter_clause = f" AND ip IN ({placeholder})"
|
||||
|
||||
async with aiosqlite.connect(_make_db_uri(db_path), uri=True) as db:
|
||||
async with _readonly_connection(db_path) as db:
|
||||
db.row_factory = aiosqlite.Row
|
||||
|
||||
async with db.execute(
|
||||
@@ -195,7 +244,7 @@ async def get_ban_counts_by_bucket(
|
||||
|
||||
origin_clause, origin_params = _origin_sql_filter(origin)
|
||||
|
||||
async with aiosqlite.connect(_make_db_uri(db_path), uri=True) as db:
|
||||
async with _readonly_connection(db_path) as db:
|
||||
db.row_factory = aiosqlite.Row
|
||||
async with db.execute(
|
||||
"SELECT CAST((timeofban - ?) / ? AS INTEGER) AS bucket_idx, "
|
||||
@@ -225,7 +274,7 @@ async def get_ban_event_counts(
|
||||
|
||||
origin_clause, origin_params = _origin_sql_filter(origin)
|
||||
|
||||
async with aiosqlite.connect(_make_db_uri(db_path), uri=True) as db:
|
||||
async with _readonly_connection(db_path) as db:
|
||||
db.row_factory = aiosqlite.Row
|
||||
async with db.execute(
|
||||
"SELECT ip, COUNT(*) AS event_count "
|
||||
@@ -250,7 +299,7 @@ async def get_bans_by_jail(
|
||||
|
||||
origin_clause, origin_params = _origin_sql_filter(origin)
|
||||
|
||||
async with aiosqlite.connect(_make_db_uri(db_path), uri=True) as db:
|
||||
async with _readonly_connection(db_path) as db:
|
||||
db.row_factory = aiosqlite.Row
|
||||
|
||||
async with db.execute(
|
||||
@@ -283,7 +332,7 @@ async def get_bans_table_summary(
|
||||
empty the min/max values will be ``None``.
|
||||
"""
|
||||
|
||||
async with aiosqlite.connect(_make_db_uri(db_path), uri=True) as db:
|
||||
async with _readonly_connection(db_path) as db:
|
||||
db.row_factory = aiosqlite.Row
|
||||
async with db.execute(
|
||||
"SELECT COUNT(*), MIN(timeofban), MAX(timeofban) FROM bans"
|
||||
@@ -337,7 +386,7 @@ async def get_history_page(
|
||||
effective_page_size: int = page_size
|
||||
offset: int = (page - 1) * effective_page_size
|
||||
|
||||
async with aiosqlite.connect(_make_db_uri(db_path), uri=True) as db:
|
||||
async with _readonly_connection(db_path) as db:
|
||||
db.row_factory = aiosqlite.Row
|
||||
|
||||
async with db.execute(
|
||||
@@ -362,7 +411,7 @@ async def get_history_page(
|
||||
async def get_history_for_ip(db_path: str, ip: str) -> list[HistoryRecord]:
|
||||
"""Return the full ban timeline for a specific IP."""
|
||||
|
||||
async with aiosqlite.connect(_make_db_uri(db_path), uri=True) as db:
|
||||
async with _readonly_connection(db_path) as db:
|
||||
db.row_factory = aiosqlite.Row
|
||||
async with db.execute(
|
||||
"SELECT jail, ip, timeofban, bancount, data "
|
||||
|
||||
Reference in New Issue
Block a user