Refactor backend to use request-scoped SQLite connections

This commit is contained in:
2026-04-05 23:14:46 +02:00
parent fde4c480fa
commit 42c030c706
13 changed files with 250 additions and 116 deletions

View File

@@ -125,3 +125,18 @@ async def init_db(db: aiosqlite.Connection) -> None:
await db.executescript(statement)
await db.commit()
log.info("database_schema_ready")
async def open_db(database_path: str) -> aiosqlite.Connection:
"""Open a new application SQLite connection with the standard settings.
Args:
database_path: Path to the BanGUI SQLite database.
Returns:
A configured :class:`aiosqlite.Connection` instance.
"""
db = await aiosqlite.connect(database_path)
db.row_factory = aiosqlite.Row
await db.execute("PRAGMA foreign_keys=ON;")
return db

View File

@@ -7,6 +7,7 @@ directly — to keep coupling explicit and testable.
"""
import time
from collections.abc import AsyncGenerator
from typing import Annotated, Protocol, cast
import aiosqlite
@@ -61,26 +62,35 @@ def invalidate_session_cache(token: str) -> None:
_session_cache.pop(token, None)
async def get_db(request: Request) -> aiosqlite.Connection:
"""Provide the shared :class:`aiosqlite.Connection` from ``app.state``.
async def get_db(request: Request) -> AsyncGenerator[aiosqlite.Connection, None]:
"""Provide a request-scoped :class:`aiosqlite.Connection` for the current request.
Opens a fresh connection for every request and closes it when the request
is finished. This avoids contention and locking issues from a single shared
SQLite connection across concurrent requests.
Args:
request: The current FastAPI request (injected automatically).
Returns:
The application-wide aiosqlite connection opened during startup.
Raises:
HTTPException: 503 if the database has not been initialised.
Yields:
An open :class:`aiosqlite.Connection` for the request.
"""
db: aiosqlite.Connection | None = getattr(request.app.state, "db", None)
if db is None:
log.error("database_not_initialised")
from app.db import open_db # noqa: PLC0415
settings = cast("AppState", request.app.state).settings
try:
db = await open_db(settings.database_path)
except Exception as exc:
log.error("database_open_failed", error=str(exc))
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="Database is not available.",
)
return db
) from exc
try:
yield db
finally:
await db.close()
async def get_settings(request: Request) -> Settings:

View File

@@ -23,7 +23,6 @@ if TYPE_CHECKING:
from starlette.responses import Response as StarletteResponse
import aiohttp
import aiosqlite
import structlog
from apscheduler.schedulers.asyncio import AsyncIOScheduler # type: ignore[import-untyped]
from fastapi import FastAPI, Request, status
@@ -33,7 +32,7 @@ from starlette.middleware.base import BaseHTTPMiddleware
from app import __version__
from app.config import Settings, get_settings
from app.db import init_db
from app.db import init_db, open_db
from app.routers import (
auth,
bans,
@@ -145,11 +144,19 @@ async def _lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
# --- Application database ---
db_path: Path = Path(settings.database_path)
db_path.parent.mkdir(parents=True, exist_ok=True)
from app.services import geo_service # noqa: PLC0415
log.debug("database_directory_ensured", directory=str(db_path.parent))
db: aiosqlite.Connection = await aiosqlite.connect(settings.database_path)
db.row_factory = aiosqlite.Row
await init_db(db)
app.state.db = db
db = await open_db(settings.database_path)
try:
await init_db(db)
await geo_service.load_cache_from_db(db)
unresolved_count = await geo_service.count_unresolved(db)
finally:
await db.close()
if unresolved_count > 0:
log.warning("geo_cache_unresolved_ips", unresolved=unresolved_count)
# --- Shared HTTP client session ---
http_session: aiohttp.ClientSession = aiohttp.ClientSession()
@@ -159,12 +166,6 @@ async def _lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
from app.services import geo_service # noqa: PLC0415
geo_service.init_geoip(settings.geoip_db_path)
await geo_service.load_cache_from_db(db)
# Log unresolved geo entries so the operator can see the scope of the issue.
unresolved_count = await geo_service.count_unresolved(db)
if unresolved_count > 0:
log.warning("geo_cache_unresolved_ips", unresolved=unresolved_count)
# --- Background task scheduler ---
scheduler: AsyncIOScheduler = AsyncIOScheduler(timezone="UTC")
@@ -328,9 +329,27 @@ class SetupRedirectMiddleware(BaseHTTPMiddleware):
request.app.state, "_setup_complete_cached", False
):
from app.services import setup_service # noqa: PLC0415
from app.db import open_db # noqa: PLC0415
db: aiosqlite.Connection | None = getattr(request.app.state, "db", None)
if db is None or not await setup_service.is_setup_complete(db):
db = getattr(request.app.state, "db", None)
if db is None:
settings = request.app.state.settings
db = await open_db(settings.database_path)
try:
is_complete = await setup_service.is_setup_complete(db)
except Exception:
log.debug("setup_check_failed", reason="db_uninitialised_or_inaccessible")
is_complete = False
finally:
await db.close()
else:
try:
is_complete = await setup_service.is_setup_complete(db)
except Exception:
log.debug("setup_check_failed", reason="db_uninitialised_or_inaccessible")
is_complete = False
if not is_complete:
return RedirectResponse(
url="/api/setup",
status_code=status.HTTP_307_TEMPORARY_REDIRECT,

View File

@@ -17,7 +17,7 @@ if TYPE_CHECKING:
from fastapi import APIRouter, HTTPException, Request, status
from app.dependencies import AuthDep
from app.dependencies import AuthDep, DbDep
from app.models.ban import ActiveBanListResponse, BanRequest, UnbanAllResponse, UnbanRequest
from app.models.jail import JailCommandResponse
from app.services import geo_service, jail_service
@@ -50,6 +50,7 @@ def _bad_gateway(exc: Exception) -> HTTPException:
async def get_active_bans(
request: Request,
_auth: AuthDep,
db: DbDep,
) -> ActiveBanListResponse:
"""Return every IP that is currently banned across all fail2ban jails.
@@ -68,14 +69,13 @@ async def get_active_bans(
"""
socket_path: str = request.app.state.settings.fail2ban_socket
http_session: aiohttp.ClientSession = request.app.state.http_session
app_db = request.app.state.db
try:
return await jail_service.get_active_bans(
socket_path,
geo_batch_lookup=geo_service.lookup_batch,
http_session=http_session,
app_db=app_db,
app_db=db,
)
except Fail2BanConnectionError as exc:
raise _bad_gateway(exc) from exc

View File

@@ -43,7 +43,7 @@ from typing import Annotated
import structlog
from fastapi import APIRouter, HTTPException, Path, Query, Request, status
from app.dependencies import AuthDep
from app.dependencies import AuthDep, DbDep
from app.models.config import (
ActionConfig,
ActionCreateRequest,
@@ -594,6 +594,7 @@ async def preview_log(
async def get_map_color_thresholds(
request: Request,
_auth: AuthDep,
db: DbDep,
) -> MapColorThresholdsResponse:
"""Return the configured map color thresholds.
@@ -607,7 +608,7 @@ async def get_map_color_thresholds(
"""
from app.services import setup_service
high, medium, low = await setup_service.get_map_color_thresholds(request.app.state.db)
high, medium, low = await setup_service.get_map_color_thresholds(db)
return MapColorThresholdsResponse(
threshold_high=high,
threshold_medium=medium,
@@ -623,6 +624,7 @@ async def get_map_color_thresholds(
async def update_map_color_thresholds(
request: Request,
_auth: AuthDep,
db: DbDep,
body: MapColorThresholdsUpdate,
) -> MapColorThresholdsResponse:
"""Update the map color threshold configuration.
@@ -644,7 +646,7 @@ async def update_map_color_thresholds(
try:
await setup_service.set_map_color_thresholds(
request.app.state.db,
db,
threshold_high=body.threshold_high,
threshold_medium=body.threshold_medium,
threshold_low=body.threshold_low,

View File

@@ -20,7 +20,7 @@ if TYPE_CHECKING:
from fastapi import APIRouter, Query, Request
from app import __version__
from app.dependencies import AuthDep
from app.dependencies import AuthDep, DbDep
from app.models.ban import (
BanOrigin,
BansByCountryResponse,
@@ -82,6 +82,7 @@ async def get_server_status(
async def get_dashboard_bans(
request: Request,
_auth: AuthDep,
db: DbDep,
range: TimeRange = Query(default=_DEFAULT_RANGE, description="Time-range preset."),
source: Literal["fail2ban", "archive"] = Query(
default="fail2ban",
@@ -125,7 +126,7 @@ async def get_dashboard_bans(
page=page,
page_size=page_size,
http_session=http_session,
app_db=request.app.state.db,
app_db=db,
geo_batch_lookup=geo_service.lookup_batch,
origin=origin,
)
@@ -139,6 +140,7 @@ async def get_dashboard_bans(
async def get_bans_by_country(
request: Request,
_auth: AuthDep,
db: DbDep,
range: TimeRange = Query(default=_DEFAULT_RANGE, description="Time-range preset."),
source: Literal["fail2ban", "archive"] = Query(
default="fail2ban",
@@ -181,7 +183,7 @@ async def get_bans_by_country(
http_session=http_session,
geo_cache_lookup=geo_service.lookup_cached_only,
geo_batch_lookup=geo_service.lookup_batch,
app_db=request.app.state.db,
app_db=db,
origin=origin,
country_code=country_code,
)
@@ -195,6 +197,7 @@ async def get_bans_by_country(
async def get_ban_trend(
request: Request,
_auth: AuthDep,
db: DbDep,
range: TimeRange = Query(default=_DEFAULT_RANGE, description="Time-range preset."),
source: Literal["fail2ban", "archive"] = Query(
default="fail2ban",
@@ -235,7 +238,7 @@ async def get_ban_trend(
socket_path,
range,
source=source,
app_db=request.app.state.db,
app_db=db,
origin=origin,
)
@@ -248,6 +251,7 @@ async def get_ban_trend(
async def get_bans_by_jail(
request: Request,
_auth: AuthDep,
db: DbDep,
range: TimeRange = Query(default=_DEFAULT_RANGE, description="Time-range preset."),
source: Literal["fail2ban", "archive"] = Query(
default="fail2ban",
@@ -281,6 +285,6 @@ async def get_bans_by_jail(
socket_path,
range,
source=source,
app_db=request.app.state.db,
app_db=db,
origin=origin,
)

View File

@@ -22,7 +22,7 @@ if TYPE_CHECKING:
from fastapi import APIRouter, HTTPException, Query, Request
from app.dependencies import AuthDep
from app.dependencies import AuthDep, DbDep
from app.models.ban import BanOrigin, TimeRange
from app.models.history import HistoryListResponse, IpDetailResponse
from app.services import geo_service, history_service
@@ -40,6 +40,7 @@ _DEFAULT_PAGE_SIZE: int = 100
async def get_history(
request: Request,
_auth: AuthDep,
db: DbDep,
range: TimeRange | None = Query(
default=None,
description="Optional time-range filter. Omit for all-time.",
@@ -102,7 +103,7 @@ async def get_history(
page=page,
page_size=page_size,
geo_enricher=_enricher,
db=request.app.state.db,
db=db,
)
@@ -114,6 +115,7 @@ async def get_history(
async def get_history_archive(
request: Request,
_auth: AuthDep,
db: DbDep,
range: TimeRange | None = Query(
default=None,
description="Optional time-range filter. Omit for all-time.",
@@ -138,7 +140,7 @@ async def get_history_archive(
page=page,
page_size=page_size,
geo_enricher=_enricher,
db=request.app.state.db,
db=db,
)

View File

@@ -23,7 +23,7 @@ from typing import Annotated
from fastapi import APIRouter, Body, HTTPException, Path, Request, status
from app.dependencies import AuthDep
from app.dependencies import AuthDep, DbDep
from app.models.ban import JailBannedIpsResponse
from app.models.jail import (
IgnoreIpRequest,
@@ -557,6 +557,7 @@ async def toggle_ignore_self(
async def get_jail_banned_ips(
request: Request,
_auth: AuthDep,
db: DbDep,
name: _NamePath,
page: int = 1,
page_size: int = 25,
@@ -597,7 +598,6 @@ async def get_jail_banned_ips(
socket_path: str = request.app.state.settings.fail2ban_socket
http_session = getattr(request.app.state, "http_session", None)
app_db = getattr(request.app.state, "db", None)
try:
return await jail_service.get_jail_banned_ips(
@@ -608,7 +608,7 @@ async def get_jail_banned_ips(
search=search,
geo_batch_lookup=geo_service.lookup_batch,
http_session=http_session,
app_db=app_db,
app_db=db,
)
except JailNotFoundError:
raise _not_found(name) from None

View File

@@ -17,9 +17,13 @@ from typing import TYPE_CHECKING, Any
import structlog
from app.db import open_db
from app.models.blocklist import ScheduleFrequency
from app.services import blocklist_service
if TYPE_CHECKING:
import aiosqlite
if TYPE_CHECKING:
from fastapi import FastAPI
@@ -29,6 +33,15 @@ log: structlog.stdlib.BoundLogger = structlog.get_logger()
JOB_ID: str = "blocklist_import"
async def _get_db(app: Any) -> tuple[aiosqlite.Connection, bool]:
existing_db = getattr(app.state, "db", None)
if existing_db is not None:
return existing_db, False
db = await open_db(app.state.settings.database_path)
return db, True
async def _run_import(app: Any) -> None:
"""APScheduler callback that imports all enabled blocklist sources.
@@ -39,12 +52,10 @@ async def _run_import(app: Any) -> None:
app: The :class:`fastapi.FastAPI` application instance passed via
APScheduler ``kwargs``.
"""
db = app.state.db
db, close_db = await _get_db(app)
http_session = app.state.http_session
socket_path: str = app.state.settings.fail2ban_socket
from app.services import jail_service
log.info("blocklist_import_starting")
try:
result = await blocklist_service.import_all(
@@ -60,6 +71,9 @@ async def _run_import(app: Any) -> None:
)
except Exception:
log.exception("blocklist_import_unexpected_error")
finally:
if close_db:
await db.close()
def register(app: FastAPI) -> None:
@@ -78,7 +92,12 @@ def register(app: FastAPI) -> None:
import asyncio # noqa: PLC0415
async def _do_register() -> None:
config = await blocklist_service.get_schedule(app.state.db)
db, close_db = await _get_db(app)
try:
config = await blocklist_service.get_schedule(db)
finally:
if close_db:
await db.close()
_apply_schedule(app, config)
# APScheduler is synchronous at registration time; use asyncio to read
@@ -104,7 +123,12 @@ def reschedule(app: FastAPI) -> None:
import asyncio # noqa: PLC0415
async def _do_reschedule() -> None:
config = await blocklist_service.get_schedule(app.state.db)
db, close_db = await _get_db(app)
try:
config = await blocklist_service.get_schedule(db)
finally:
if close_db:
await db.close()
_apply_schedule(app, config)
asyncio.ensure_future(_do_reschedule())

View File

@@ -15,6 +15,10 @@ from typing import TYPE_CHECKING, Any
import structlog
from app.db import open_db
if TYPE_CHECKING:
import aiosqlite
from app.services import geo_service
if TYPE_CHECKING:
@@ -29,6 +33,15 @@ GEO_FLUSH_INTERVAL: int = 60
JOB_ID: str = "geo_cache_flush"
async def _get_db(app: Any) -> tuple[aiosqlite.Connection, bool]:
existing_db = getattr(app.state, "db", None)
if existing_db is not None:
return existing_db, False
db = await open_db(app.state.settings.database_path)
return db, True
async def _run_flush(app: Any) -> None:
"""Flush the geo service dirty set to the application database.
@@ -39,8 +52,13 @@ async def _run_flush(app: Any) -> None:
app: The :class:`fastapi.FastAPI` application instance passed via
APScheduler ``kwargs``.
"""
db = app.state.db
count = await geo_service.flush_dirty(db)
db, close_db = await _get_db(app)
try:
count = await geo_service.flush_dirty(db)
finally:
if close_db:
await db.close()
if count > 0:
log.debug("geo_cache_flush_ran", flushed=count)

View File

@@ -21,6 +21,10 @@ from typing import TYPE_CHECKING
import structlog
from app.db import open_db
if TYPE_CHECKING:
import aiosqlite
from app.services import geo_service
if TYPE_CHECKING:
@@ -35,6 +39,15 @@ GEO_RE_RESOLVE_INTERVAL: int = 600
JOB_ID: str = "geo_re_resolve"
async def _get_db(app: FastAPI) -> tuple[aiosqlite.Connection, bool]:
existing_db = getattr(app.state, "db", None)
if existing_db is not None:
return existing_db, False
db = await open_db(app.state.settings.database_path)
return db, True
async def _run_re_resolve(app: FastAPI) -> None:
"""Query NULL-country IPs from the database and re-resolve them.
@@ -45,33 +58,37 @@ async def _run_re_resolve(app: FastAPI) -> None:
app: The :class:`fastapi.FastAPI` application instance passed via
APScheduler ``kwargs``.
"""
db = app.state.db
db, close_db = await _get_db(app)
http_session = app.state.http_session
# Fetch all IPs with NULL country_code from the persistent cache.
unresolved_ips = await geo_service.get_unresolved_ips(db)
try:
# Fetch all IPs with NULL country_code from the persistent cache.
unresolved_ips = await geo_service.get_unresolved_ips(db)
if not unresolved_ips:
log.debug("geo_re_resolve_skip", reason="no_unresolved_ips")
return
if not unresolved_ips:
log.debug("geo_re_resolve_skip", reason="no_unresolved_ips")
return
log.info("geo_re_resolve_start", unresolved=len(unresolved_ips))
log.info("geo_re_resolve_start", unresolved=len(unresolved_ips))
# Clear the negative cache so these IPs are eligible for fresh API calls.
geo_service.clear_neg_cache()
# Clear the negative cache so these IPs are eligible for fresh API calls.
geo_service.clear_neg_cache()
# lookup_batch handles throttling, retries, and persistence when db is
# passed. This is a background task so DB writes are allowed.
results = await geo_service.lookup_batch(unresolved_ips, http_session, db=db)
# lookup_batch handles throttling, retries, and persistence when db is
# passed. This is a background task so DB writes are allowed.
results = await geo_service.lookup_batch(unresolved_ips, http_session, db=db)
resolved_count: int = sum(
1 for info in results.values() if info.country_code is not None
)
log.info(
"geo_re_resolve_complete",
retried=len(unresolved_ips),
resolved=resolved_count,
)
resolved_count: int = sum(
1 for info in results.values() if info.country_code is not None
)
log.info(
"geo_re_resolve_complete",
retried=len(unresolved_ips),
resolved=resolved_count,
)
finally:
if close_db:
await db.close()
def register(app: FastAPI) -> None:

View File

@@ -9,8 +9,12 @@ from __future__ import annotations
import datetime
from typing import TYPE_CHECKING
if TYPE_CHECKING:
import aiosqlite
import structlog
from app.db import open_db
from app.repositories import fail2ban_db_repo
from app.utils.fail2ban_db_utils import get_fail2ban_db_path
@@ -29,6 +33,15 @@ HISTORY_SYNC_INTERVAL: int = 300
BACKFILL_WINDOW: int = 648000
async def _get_db(app: FastAPI) -> tuple[aiosqlite.Connection, bool]:
existing_db = getattr(app.state, "db", None)
if existing_db is not None:
return existing_db, False
db = await open_db(app.state.settings.database_path)
return db, True
async def _get_last_archive_ts(db) -> int | None:
async with db.execute("SELECT MAX(timeofban) FROM history_archive") as cur:
row = await cur.fetchone()
@@ -38,8 +51,8 @@ async def _get_last_archive_ts(db) -> int | None:
async def _run_sync(app: FastAPI) -> None:
db = app.state.db
socket_path: str = app.state.settings.fail2ban_socket
db, close_db = await _get_db(app)
try:
last_ts = await _get_last_archive_ts(db)
@@ -90,6 +103,9 @@ async def _run_sync(app: FastAPI) -> None:
except Exception:
log.exception("history_sync_failed")
finally:
if close_db:
await db.close()
def register(app: FastAPI) -> None: