Refactor backend to use request-scoped SQLite connections

This commit is contained in:
2026-04-05 23:14:46 +02:00
parent fde4c480fa
commit 42c030c706
13 changed files with 250 additions and 116 deletions

View File

@@ -10,64 +10,71 @@ Reference: `Docs/Refactoring.md` for full analysis of each issue.
--- ---
### TASK-001 — WorldMap: filter companion table by selected country (server-side) ### Backend Architecture
**Status:** Done - **Replace the single shared SQLite connection.** ✅
**Priority:** Medium - Current startup code opens one `aiosqlite.Connection` and reuses it for every request.
**Domain:** Full-stack (backend + frontend) - This was replaced with request-scoped connections to avoid concurrency and locking issues.
**References:** `Docs/Features.md §4`, `Docs/Web-Development.md` - Request dependencies, application lifecycle, and tests were updated to use the new pattern.
#### Background - **Refactor dependency wiring and shared resource management.**
- Remove hidden module-level import coupling between routers, services, and shared utilities.
- Introduce explicit factories or providers for shared resources such as DB, HTTP client session, scheduler, and settings.
- Ensure routers depend on injected providers rather than global state or dynamic imports.
The `GET /api/dashboard/bans/by-country` endpoint always returns the **200 most recent** ban rows in `bans` (constant `_MAX_COMPANION_BANS = 200` in `backend/app/services/ban_service.py`). `MapPage.tsx` stores a `selectedCountry` state and filters the returned rows client-side via `visibleBans`. This means the companion table can only show the fraction of a country's bans that fall within the global top-200 window. If the selected time range has, say, 1 500 bans and 300 are from China, but China's bans are not all in the top 200 overall, the table will silently display fewer than 300 rows. - **Harden fail2ban integration.**
- Remove the `sys.path` hack that locates `fail2ban-master` at runtime.
- Replace it with a deterministic packaging or configuration model so the backend does not depend on repository layout.
- Refactor `Fail2BanClient` so concurrency control is instance-based and not backed by hidden module globals.
When a country is selected the companion table **must** return the complete set of bans for that country so the user sees an accurate picture. - **Improve startup / setup guard behavior.**
- Convert `SetupRedirectMiddleware` from an on-demand DB check into a startup/initialisation guard where possible.
- Cache setup completion in a safe way and provide an explicit invalidation path if the application state changes.
- Reduce middleware responsibility and avoid DB access during normal request dispatch.
#### Desired behaviour - **Make deployment configuration explicit.**
- Move hard-coded environment assumptions such as CORS origins into settings.
- Ensure `fail2ban_socket`, `fail2ban_config_dir`, and startup commands are fully configurable via `Settings`.
- Document production-ready defaults separately from development defaults.
- No country selected → companion table shows the 200 most recent bans across all countries (existing behaviour, no change). ### Reliability and Resilience
- Country selected → the server returns **all** ban entries for that country in the selected time window; no client-side row-count cap applies.
- Deselecting a country (clicking the same country again, or the "Clear filter" button) reverts to the default 200-row unfiltered view.
- The existing `visibleBans` client-side filter in `MapPage.tsx` can remain as a defensive guard but must not be the only filter.
#### Implementation steps - **Add backend lifecycle tests for resource cleanup.**
- Verify startup opens and initialises DB, HTTP session, scheduler, and geo cache correctly.
- Verify shutdown closes those resources cleanly.
1. **Backend — router** (`backend/app/routers/dashboard.py`) - **Add concurrency/regression coverage for DB and fail2ban socket use.**
- Add `country_code: str | None = Query(default=None, description="ISO alpha-2 country code to filter companion rows.")` to `get_bans_by_country`. - Add tests that simulate multiple concurrent requests using the same DB dependency.
- Pass it to `ban_service.bans_by_country(..., country_code=country_code)`. - Add tests around fail2ban socket retries, protocol errors, and rate limiting.
2. **Backend — service** (`backend/app/services/ban_service.py`) - **Improve state caching and invalidation.**
- Add `country_code: str | None = None` keyword argument to `bans_by_country`. - Add tests for session cache invalidation on logout.
- After `geo_map` is built (existing geo-resolution step), collect IPs whose resolved country matches `country_code`. - Add tests for setup completion caching so stale state is never served.
- For the **fail2ban source**: call `fail2ban_db_repo.get_currently_banned` with `ip_filter=matched_ips` and no `limit` (remove the `_MAX_COMPANION_BANS` cap for filtered queries).
- For the **archive source**: filter `all_rows` to those whose IP is in `matched_ips` and return all of them (skip the `page_size=_MAX_COMPANION_BANS` call).
- When `country_code` is `None`, behaviour is identical to today.
3. **Backend — repository** (`backend/app/repositories/fail2ban_db_repo.py`) ### Backend Feature Work
- Add `ip_filter: list[str] | None = None` to `get_currently_banned`.
- When provided and non-empty, append `AND ip IN ({placeholders})` to the SQL `WHERE` clause, parameterised safely (never interpolated as a string).
4. **Backend — repository (archive)** (`backend/app/repositories/history_archive_repo.py`) - **Document and implement backend-safe environment-driven CORS.**
- Similarly add optional `ip_filter` to the archive companion-rows query used from `bans_by_country`. - Add support for production and local development origins through configuration.
- Avoid a hardcoded Vite origin in the core app factory.
5. **Frontend — API client** (`frontend/src/api/map.ts`) - **Centralise scheduler job registration.**
- Add optional `countryCode?: string` parameter to `fetchBansByCountry`. - Refactor APScheduler registration so background tasks are registered through a common lifecycle helper.
- When set, append `country_code=<value>` to the query string. - Ensure jobs can be discovered, replaced, and tested without requiring implicit `app.state` side effects.
6. **Frontend — hook** (`frontend/src/hooks/useMapData.ts`) - **Strengthen fail2ban error handling and reporting.**
- Add `countryCode?: string` to the function signature. - Standardise `502` responses for connection/protocol failures across all endpoints.
- Include it in the `useCallback` dependency array and pass it to `fetchBansByCountry`. - Add structured logging for retries and fatal socket failures.
- Ensure the UI can distinguish offline fail2ban from internal backend failures.
7. **Frontend — page** (`frontend/src/pages/MapPage.tsx`) - **Improve documentation of backend responsibilities.**
- Pass `selectedCountry ?? undefined` as `countryCode` to `useMapData`. - Keep `Docs/Tasks.md` aligned with the backend architecture review.
- The hook's effect will re-fetch automatically when `selectedCountry` changes; the existing `useEffect` that resets `page` to 1 already covers this. - Add references to the backend modules, resource lifecycle, and dependency model in the documentation.
#### Testing guidance ### Priority Execution Plan
- Select a country that has > 200 bans in the chosen time window; confirm the companion table shows more than the previous cap would allow. 1. Fix the global SQLite connection pattern and tests.
- With no country selected, confirm only 200 rows are returned (no regression). 2. Refactor dependency injection / explicit shared resources.
- Deselect the country; confirm the unfiltered 200-row view is restored. 3. Harden fail2ban client concurrency and packaging.
- Test with the archive source as well as the fail2ban live source. 4. Convert setup guard to a safer startup-driven model.
- Verify the `ip_filter` SQL clause is parameterised and cannot be injected. 5. Add deployment-safe configuration and production-ready CORS.
6. Add lifecycle and concurrency regression tests.
---

View File

@@ -125,3 +125,18 @@ async def init_db(db: aiosqlite.Connection) -> None:
await db.executescript(statement) await db.executescript(statement)
await db.commit() await db.commit()
log.info("database_schema_ready") log.info("database_schema_ready")
async def open_db(database_path: str) -> aiosqlite.Connection:
"""Open a new application SQLite connection with the standard settings.
Args:
database_path: Path to the BanGUI SQLite database.
Returns:
A configured :class:`aiosqlite.Connection` instance.
"""
db = await aiosqlite.connect(database_path)
db.row_factory = aiosqlite.Row
await db.execute("PRAGMA foreign_keys=ON;")
return db

View File

@@ -7,6 +7,7 @@ directly — to keep coupling explicit and testable.
""" """
import time import time
from collections.abc import AsyncGenerator
from typing import Annotated, Protocol, cast from typing import Annotated, Protocol, cast
import aiosqlite import aiosqlite
@@ -61,26 +62,35 @@ def invalidate_session_cache(token: str) -> None:
_session_cache.pop(token, None) _session_cache.pop(token, None)
async def get_db(request: Request) -> aiosqlite.Connection: async def get_db(request: Request) -> AsyncGenerator[aiosqlite.Connection, None]:
"""Provide the shared :class:`aiosqlite.Connection` from ``app.state``. """Provide a request-scoped :class:`aiosqlite.Connection` for the current request.
Opens a fresh connection for every request and closes it when the request
is finished. This avoids contention and locking issues from a single shared
SQLite connection across concurrent requests.
Args: Args:
request: The current FastAPI request (injected automatically). request: The current FastAPI request (injected automatically).
Returns: Yields:
The application-wide aiosqlite connection opened during startup. An open :class:`aiosqlite.Connection` for the request.
Raises:
HTTPException: 503 if the database has not been initialised.
""" """
db: aiosqlite.Connection | None = getattr(request.app.state, "db", None) from app.db import open_db # noqa: PLC0415
if db is None:
log.error("database_not_initialised") settings = cast("AppState", request.app.state).settings
try:
db = await open_db(settings.database_path)
except Exception as exc:
log.error("database_open_failed", error=str(exc))
raise HTTPException( raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE, status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="Database is not available.", detail="Database is not available.",
) ) from exc
return db
try:
yield db
finally:
await db.close()
async def get_settings(request: Request) -> Settings: async def get_settings(request: Request) -> Settings:

View File

@@ -23,7 +23,6 @@ if TYPE_CHECKING:
from starlette.responses import Response as StarletteResponse from starlette.responses import Response as StarletteResponse
import aiohttp import aiohttp
import aiosqlite
import structlog import structlog
from apscheduler.schedulers.asyncio import AsyncIOScheduler # type: ignore[import-untyped] from apscheduler.schedulers.asyncio import AsyncIOScheduler # type: ignore[import-untyped]
from fastapi import FastAPI, Request, status from fastapi import FastAPI, Request, status
@@ -33,7 +32,7 @@ from starlette.middleware.base import BaseHTTPMiddleware
from app import __version__ from app import __version__
from app.config import Settings, get_settings from app.config import Settings, get_settings
from app.db import init_db from app.db import init_db, open_db
from app.routers import ( from app.routers import (
auth, auth,
bans, bans,
@@ -145,11 +144,19 @@ async def _lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
# --- Application database --- # --- Application database ---
db_path: Path = Path(settings.database_path) db_path: Path = Path(settings.database_path)
db_path.parent.mkdir(parents=True, exist_ok=True) db_path.parent.mkdir(parents=True, exist_ok=True)
from app.services import geo_service # noqa: PLC0415
log.debug("database_directory_ensured", directory=str(db_path.parent)) log.debug("database_directory_ensured", directory=str(db_path.parent))
db: aiosqlite.Connection = await aiosqlite.connect(settings.database_path) db = await open_db(settings.database_path)
db.row_factory = aiosqlite.Row try:
await init_db(db) await init_db(db)
app.state.db = db await geo_service.load_cache_from_db(db)
unresolved_count = await geo_service.count_unresolved(db)
finally:
await db.close()
if unresolved_count > 0:
log.warning("geo_cache_unresolved_ips", unresolved=unresolved_count)
# --- Shared HTTP client session --- # --- Shared HTTP client session ---
http_session: aiohttp.ClientSession = aiohttp.ClientSession() http_session: aiohttp.ClientSession = aiohttp.ClientSession()
@@ -159,12 +166,6 @@ async def _lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
from app.services import geo_service # noqa: PLC0415 from app.services import geo_service # noqa: PLC0415
geo_service.init_geoip(settings.geoip_db_path) geo_service.init_geoip(settings.geoip_db_path)
await geo_service.load_cache_from_db(db)
# Log unresolved geo entries so the operator can see the scope of the issue.
unresolved_count = await geo_service.count_unresolved(db)
if unresolved_count > 0:
log.warning("geo_cache_unresolved_ips", unresolved=unresolved_count)
# --- Background task scheduler --- # --- Background task scheduler ---
scheduler: AsyncIOScheduler = AsyncIOScheduler(timezone="UTC") scheduler: AsyncIOScheduler = AsyncIOScheduler(timezone="UTC")
@@ -328,9 +329,27 @@ class SetupRedirectMiddleware(BaseHTTPMiddleware):
request.app.state, "_setup_complete_cached", False request.app.state, "_setup_complete_cached", False
): ):
from app.services import setup_service # noqa: PLC0415 from app.services import setup_service # noqa: PLC0415
from app.db import open_db # noqa: PLC0415
db: aiosqlite.Connection | None = getattr(request.app.state, "db", None) db = getattr(request.app.state, "db", None)
if db is None or not await setup_service.is_setup_complete(db): if db is None:
settings = request.app.state.settings
db = await open_db(settings.database_path)
try:
is_complete = await setup_service.is_setup_complete(db)
except Exception:
log.debug("setup_check_failed", reason="db_uninitialised_or_inaccessible")
is_complete = False
finally:
await db.close()
else:
try:
is_complete = await setup_service.is_setup_complete(db)
except Exception:
log.debug("setup_check_failed", reason="db_uninitialised_or_inaccessible")
is_complete = False
if not is_complete:
return RedirectResponse( return RedirectResponse(
url="/api/setup", url="/api/setup",
status_code=status.HTTP_307_TEMPORARY_REDIRECT, status_code=status.HTTP_307_TEMPORARY_REDIRECT,

View File

@@ -17,7 +17,7 @@ if TYPE_CHECKING:
from fastapi import APIRouter, HTTPException, Request, status from fastapi import APIRouter, HTTPException, Request, status
from app.dependencies import AuthDep from app.dependencies import AuthDep, DbDep
from app.models.ban import ActiveBanListResponse, BanRequest, UnbanAllResponse, UnbanRequest from app.models.ban import ActiveBanListResponse, BanRequest, UnbanAllResponse, UnbanRequest
from app.models.jail import JailCommandResponse from app.models.jail import JailCommandResponse
from app.services import geo_service, jail_service from app.services import geo_service, jail_service
@@ -50,6 +50,7 @@ def _bad_gateway(exc: Exception) -> HTTPException:
async def get_active_bans( async def get_active_bans(
request: Request, request: Request,
_auth: AuthDep, _auth: AuthDep,
db: DbDep,
) -> ActiveBanListResponse: ) -> ActiveBanListResponse:
"""Return every IP that is currently banned across all fail2ban jails. """Return every IP that is currently banned across all fail2ban jails.
@@ -68,14 +69,13 @@ async def get_active_bans(
""" """
socket_path: str = request.app.state.settings.fail2ban_socket socket_path: str = request.app.state.settings.fail2ban_socket
http_session: aiohttp.ClientSession = request.app.state.http_session http_session: aiohttp.ClientSession = request.app.state.http_session
app_db = request.app.state.db
try: try:
return await jail_service.get_active_bans( return await jail_service.get_active_bans(
socket_path, socket_path,
geo_batch_lookup=geo_service.lookup_batch, geo_batch_lookup=geo_service.lookup_batch,
http_session=http_session, http_session=http_session,
app_db=app_db, app_db=db,
) )
except Fail2BanConnectionError as exc: except Fail2BanConnectionError as exc:
raise _bad_gateway(exc) from exc raise _bad_gateway(exc) from exc

View File

@@ -43,7 +43,7 @@ from typing import Annotated
import structlog import structlog
from fastapi import APIRouter, HTTPException, Path, Query, Request, status from fastapi import APIRouter, HTTPException, Path, Query, Request, status
from app.dependencies import AuthDep from app.dependencies import AuthDep, DbDep
from app.models.config import ( from app.models.config import (
ActionConfig, ActionConfig,
ActionCreateRequest, ActionCreateRequest,
@@ -594,6 +594,7 @@ async def preview_log(
async def get_map_color_thresholds( async def get_map_color_thresholds(
request: Request, request: Request,
_auth: AuthDep, _auth: AuthDep,
db: DbDep,
) -> MapColorThresholdsResponse: ) -> MapColorThresholdsResponse:
"""Return the configured map color thresholds. """Return the configured map color thresholds.
@@ -607,7 +608,7 @@ async def get_map_color_thresholds(
""" """
from app.services import setup_service from app.services import setup_service
high, medium, low = await setup_service.get_map_color_thresholds(request.app.state.db) high, medium, low = await setup_service.get_map_color_thresholds(db)
return MapColorThresholdsResponse( return MapColorThresholdsResponse(
threshold_high=high, threshold_high=high,
threshold_medium=medium, threshold_medium=medium,
@@ -623,6 +624,7 @@ async def get_map_color_thresholds(
async def update_map_color_thresholds( async def update_map_color_thresholds(
request: Request, request: Request,
_auth: AuthDep, _auth: AuthDep,
db: DbDep,
body: MapColorThresholdsUpdate, body: MapColorThresholdsUpdate,
) -> MapColorThresholdsResponse: ) -> MapColorThresholdsResponse:
"""Update the map color threshold configuration. """Update the map color threshold configuration.
@@ -644,7 +646,7 @@ async def update_map_color_thresholds(
try: try:
await setup_service.set_map_color_thresholds( await setup_service.set_map_color_thresholds(
request.app.state.db, db,
threshold_high=body.threshold_high, threshold_high=body.threshold_high,
threshold_medium=body.threshold_medium, threshold_medium=body.threshold_medium,
threshold_low=body.threshold_low, threshold_low=body.threshold_low,

View File

@@ -20,7 +20,7 @@ if TYPE_CHECKING:
from fastapi import APIRouter, Query, Request from fastapi import APIRouter, Query, Request
from app import __version__ from app import __version__
from app.dependencies import AuthDep from app.dependencies import AuthDep, DbDep
from app.models.ban import ( from app.models.ban import (
BanOrigin, BanOrigin,
BansByCountryResponse, BansByCountryResponse,
@@ -82,6 +82,7 @@ async def get_server_status(
async def get_dashboard_bans( async def get_dashboard_bans(
request: Request, request: Request,
_auth: AuthDep, _auth: AuthDep,
db: DbDep,
range: TimeRange = Query(default=_DEFAULT_RANGE, description="Time-range preset."), range: TimeRange = Query(default=_DEFAULT_RANGE, description="Time-range preset."),
source: Literal["fail2ban", "archive"] = Query( source: Literal["fail2ban", "archive"] = Query(
default="fail2ban", default="fail2ban",
@@ -125,7 +126,7 @@ async def get_dashboard_bans(
page=page, page=page,
page_size=page_size, page_size=page_size,
http_session=http_session, http_session=http_session,
app_db=request.app.state.db, app_db=db,
geo_batch_lookup=geo_service.lookup_batch, geo_batch_lookup=geo_service.lookup_batch,
origin=origin, origin=origin,
) )
@@ -139,6 +140,7 @@ async def get_dashboard_bans(
async def get_bans_by_country( async def get_bans_by_country(
request: Request, request: Request,
_auth: AuthDep, _auth: AuthDep,
db: DbDep,
range: TimeRange = Query(default=_DEFAULT_RANGE, description="Time-range preset."), range: TimeRange = Query(default=_DEFAULT_RANGE, description="Time-range preset."),
source: Literal["fail2ban", "archive"] = Query( source: Literal["fail2ban", "archive"] = Query(
default="fail2ban", default="fail2ban",
@@ -181,7 +183,7 @@ async def get_bans_by_country(
http_session=http_session, http_session=http_session,
geo_cache_lookup=geo_service.lookup_cached_only, geo_cache_lookup=geo_service.lookup_cached_only,
geo_batch_lookup=geo_service.lookup_batch, geo_batch_lookup=geo_service.lookup_batch,
app_db=request.app.state.db, app_db=db,
origin=origin, origin=origin,
country_code=country_code, country_code=country_code,
) )
@@ -195,6 +197,7 @@ async def get_bans_by_country(
async def get_ban_trend( async def get_ban_trend(
request: Request, request: Request,
_auth: AuthDep, _auth: AuthDep,
db: DbDep,
range: TimeRange = Query(default=_DEFAULT_RANGE, description="Time-range preset."), range: TimeRange = Query(default=_DEFAULT_RANGE, description="Time-range preset."),
source: Literal["fail2ban", "archive"] = Query( source: Literal["fail2ban", "archive"] = Query(
default="fail2ban", default="fail2ban",
@@ -235,7 +238,7 @@ async def get_ban_trend(
socket_path, socket_path,
range, range,
source=source, source=source,
app_db=request.app.state.db, app_db=db,
origin=origin, origin=origin,
) )
@@ -248,6 +251,7 @@ async def get_ban_trend(
async def get_bans_by_jail( async def get_bans_by_jail(
request: Request, request: Request,
_auth: AuthDep, _auth: AuthDep,
db: DbDep,
range: TimeRange = Query(default=_DEFAULT_RANGE, description="Time-range preset."), range: TimeRange = Query(default=_DEFAULT_RANGE, description="Time-range preset."),
source: Literal["fail2ban", "archive"] = Query( source: Literal["fail2ban", "archive"] = Query(
default="fail2ban", default="fail2ban",
@@ -281,6 +285,6 @@ async def get_bans_by_jail(
socket_path, socket_path,
range, range,
source=source, source=source,
app_db=request.app.state.db, app_db=db,
origin=origin, origin=origin,
) )

View File

@@ -22,7 +22,7 @@ if TYPE_CHECKING:
from fastapi import APIRouter, HTTPException, Query, Request from fastapi import APIRouter, HTTPException, Query, Request
from app.dependencies import AuthDep from app.dependencies import AuthDep, DbDep
from app.models.ban import BanOrigin, TimeRange from app.models.ban import BanOrigin, TimeRange
from app.models.history import HistoryListResponse, IpDetailResponse from app.models.history import HistoryListResponse, IpDetailResponse
from app.services import geo_service, history_service from app.services import geo_service, history_service
@@ -40,6 +40,7 @@ _DEFAULT_PAGE_SIZE: int = 100
async def get_history( async def get_history(
request: Request, request: Request,
_auth: AuthDep, _auth: AuthDep,
db: DbDep,
range: TimeRange | None = Query( range: TimeRange | None = Query(
default=None, default=None,
description="Optional time-range filter. Omit for all-time.", description="Optional time-range filter. Omit for all-time.",
@@ -102,7 +103,7 @@ async def get_history(
page=page, page=page,
page_size=page_size, page_size=page_size,
geo_enricher=_enricher, geo_enricher=_enricher,
db=request.app.state.db, db=db,
) )
@@ -114,6 +115,7 @@ async def get_history(
async def get_history_archive( async def get_history_archive(
request: Request, request: Request,
_auth: AuthDep, _auth: AuthDep,
db: DbDep,
range: TimeRange | None = Query( range: TimeRange | None = Query(
default=None, default=None,
description="Optional time-range filter. Omit for all-time.", description="Optional time-range filter. Omit for all-time.",
@@ -138,7 +140,7 @@ async def get_history_archive(
page=page, page=page,
page_size=page_size, page_size=page_size,
geo_enricher=_enricher, geo_enricher=_enricher,
db=request.app.state.db, db=db,
) )

View File

@@ -23,7 +23,7 @@ from typing import Annotated
from fastapi import APIRouter, Body, HTTPException, Path, Request, status from fastapi import APIRouter, Body, HTTPException, Path, Request, status
from app.dependencies import AuthDep from app.dependencies import AuthDep, DbDep
from app.models.ban import JailBannedIpsResponse from app.models.ban import JailBannedIpsResponse
from app.models.jail import ( from app.models.jail import (
IgnoreIpRequest, IgnoreIpRequest,
@@ -557,6 +557,7 @@ async def toggle_ignore_self(
async def get_jail_banned_ips( async def get_jail_banned_ips(
request: Request, request: Request,
_auth: AuthDep, _auth: AuthDep,
db: DbDep,
name: _NamePath, name: _NamePath,
page: int = 1, page: int = 1,
page_size: int = 25, page_size: int = 25,
@@ -597,7 +598,6 @@ async def get_jail_banned_ips(
socket_path: str = request.app.state.settings.fail2ban_socket socket_path: str = request.app.state.settings.fail2ban_socket
http_session = getattr(request.app.state, "http_session", None) http_session = getattr(request.app.state, "http_session", None)
app_db = getattr(request.app.state, "db", None)
try: try:
return await jail_service.get_jail_banned_ips( return await jail_service.get_jail_banned_ips(
@@ -608,7 +608,7 @@ async def get_jail_banned_ips(
search=search, search=search,
geo_batch_lookup=geo_service.lookup_batch, geo_batch_lookup=geo_service.lookup_batch,
http_session=http_session, http_session=http_session,
app_db=app_db, app_db=db,
) )
except JailNotFoundError: except JailNotFoundError:
raise _not_found(name) from None raise _not_found(name) from None

View File

@@ -17,9 +17,13 @@ from typing import TYPE_CHECKING, Any
import structlog import structlog
from app.db import open_db
from app.models.blocklist import ScheduleFrequency from app.models.blocklist import ScheduleFrequency
from app.services import blocklist_service from app.services import blocklist_service
if TYPE_CHECKING:
import aiosqlite
if TYPE_CHECKING: if TYPE_CHECKING:
from fastapi import FastAPI from fastapi import FastAPI
@@ -29,6 +33,15 @@ log: structlog.stdlib.BoundLogger = structlog.get_logger()
JOB_ID: str = "blocklist_import" JOB_ID: str = "blocklist_import"
async def _get_db(app: Any) -> tuple[aiosqlite.Connection, bool]:
existing_db = getattr(app.state, "db", None)
if existing_db is not None:
return existing_db, False
db = await open_db(app.state.settings.database_path)
return db, True
async def _run_import(app: Any) -> None: async def _run_import(app: Any) -> None:
"""APScheduler callback that imports all enabled blocklist sources. """APScheduler callback that imports all enabled blocklist sources.
@@ -39,12 +52,10 @@ async def _run_import(app: Any) -> None:
app: The :class:`fastapi.FastAPI` application instance passed via app: The :class:`fastapi.FastAPI` application instance passed via
APScheduler ``kwargs``. APScheduler ``kwargs``.
""" """
db = app.state.db db, close_db = await _get_db(app)
http_session = app.state.http_session http_session = app.state.http_session
socket_path: str = app.state.settings.fail2ban_socket socket_path: str = app.state.settings.fail2ban_socket
from app.services import jail_service
log.info("blocklist_import_starting") log.info("blocklist_import_starting")
try: try:
result = await blocklist_service.import_all( result = await blocklist_service.import_all(
@@ -60,6 +71,9 @@ async def _run_import(app: Any) -> None:
) )
except Exception: except Exception:
log.exception("blocklist_import_unexpected_error") log.exception("blocklist_import_unexpected_error")
finally:
if close_db:
await db.close()
def register(app: FastAPI) -> None: def register(app: FastAPI) -> None:
@@ -78,7 +92,12 @@ def register(app: FastAPI) -> None:
import asyncio # noqa: PLC0415 import asyncio # noqa: PLC0415
async def _do_register() -> None: async def _do_register() -> None:
config = await blocklist_service.get_schedule(app.state.db) db, close_db = await _get_db(app)
try:
config = await blocklist_service.get_schedule(db)
finally:
if close_db:
await db.close()
_apply_schedule(app, config) _apply_schedule(app, config)
# APScheduler is synchronous at registration time; use asyncio to read # APScheduler is synchronous at registration time; use asyncio to read
@@ -104,7 +123,12 @@ def reschedule(app: FastAPI) -> None:
import asyncio # noqa: PLC0415 import asyncio # noqa: PLC0415
async def _do_reschedule() -> None: async def _do_reschedule() -> None:
config = await blocklist_service.get_schedule(app.state.db) db, close_db = await _get_db(app)
try:
config = await blocklist_service.get_schedule(db)
finally:
if close_db:
await db.close()
_apply_schedule(app, config) _apply_schedule(app, config)
asyncio.ensure_future(_do_reschedule()) asyncio.ensure_future(_do_reschedule())

View File

@@ -15,6 +15,10 @@ from typing import TYPE_CHECKING, Any
import structlog import structlog
from app.db import open_db
if TYPE_CHECKING:
import aiosqlite
from app.services import geo_service from app.services import geo_service
if TYPE_CHECKING: if TYPE_CHECKING:
@@ -29,6 +33,15 @@ GEO_FLUSH_INTERVAL: int = 60
JOB_ID: str = "geo_cache_flush" JOB_ID: str = "geo_cache_flush"
async def _get_db(app: Any) -> tuple[aiosqlite.Connection, bool]:
existing_db = getattr(app.state, "db", None)
if existing_db is not None:
return existing_db, False
db = await open_db(app.state.settings.database_path)
return db, True
async def _run_flush(app: Any) -> None: async def _run_flush(app: Any) -> None:
"""Flush the geo service dirty set to the application database. """Flush the geo service dirty set to the application database.
@@ -39,8 +52,13 @@ async def _run_flush(app: Any) -> None:
app: The :class:`fastapi.FastAPI` application instance passed via app: The :class:`fastapi.FastAPI` application instance passed via
APScheduler ``kwargs``. APScheduler ``kwargs``.
""" """
db = app.state.db db, close_db = await _get_db(app)
try:
count = await geo_service.flush_dirty(db) count = await geo_service.flush_dirty(db)
finally:
if close_db:
await db.close()
if count > 0: if count > 0:
log.debug("geo_cache_flush_ran", flushed=count) log.debug("geo_cache_flush_ran", flushed=count)

View File

@@ -21,6 +21,10 @@ from typing import TYPE_CHECKING
import structlog import structlog
from app.db import open_db
if TYPE_CHECKING:
import aiosqlite
from app.services import geo_service from app.services import geo_service
if TYPE_CHECKING: if TYPE_CHECKING:
@@ -35,6 +39,15 @@ GEO_RE_RESOLVE_INTERVAL: int = 600
JOB_ID: str = "geo_re_resolve" JOB_ID: str = "geo_re_resolve"
async def _get_db(app: FastAPI) -> tuple[aiosqlite.Connection, bool]:
existing_db = getattr(app.state, "db", None)
if existing_db is not None:
return existing_db, False
db = await open_db(app.state.settings.database_path)
return db, True
async def _run_re_resolve(app: FastAPI) -> None: async def _run_re_resolve(app: FastAPI) -> None:
"""Query NULL-country IPs from the database and re-resolve them. """Query NULL-country IPs from the database and re-resolve them.
@@ -45,9 +58,10 @@ async def _run_re_resolve(app: FastAPI) -> None:
app: The :class:`fastapi.FastAPI` application instance passed via app: The :class:`fastapi.FastAPI` application instance passed via
APScheduler ``kwargs``. APScheduler ``kwargs``.
""" """
db = app.state.db db, close_db = await _get_db(app)
http_session = app.state.http_session http_session = app.state.http_session
try:
# Fetch all IPs with NULL country_code from the persistent cache. # Fetch all IPs with NULL country_code from the persistent cache.
unresolved_ips = await geo_service.get_unresolved_ips(db) unresolved_ips = await geo_service.get_unresolved_ips(db)
@@ -72,6 +86,9 @@ async def _run_re_resolve(app: FastAPI) -> None:
retried=len(unresolved_ips), retried=len(unresolved_ips),
resolved=resolved_count, resolved=resolved_count,
) )
finally:
if close_db:
await db.close()
def register(app: FastAPI) -> None: def register(app: FastAPI) -> None:

View File

@@ -9,8 +9,12 @@ from __future__ import annotations
import datetime import datetime
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
if TYPE_CHECKING:
import aiosqlite
import structlog import structlog
from app.db import open_db
from app.repositories import fail2ban_db_repo from app.repositories import fail2ban_db_repo
from app.utils.fail2ban_db_utils import get_fail2ban_db_path from app.utils.fail2ban_db_utils import get_fail2ban_db_path
@@ -29,6 +33,15 @@ HISTORY_SYNC_INTERVAL: int = 300
BACKFILL_WINDOW: int = 648000 BACKFILL_WINDOW: int = 648000
async def _get_db(app: FastAPI) -> tuple[aiosqlite.Connection, bool]:
existing_db = getattr(app.state, "db", None)
if existing_db is not None:
return existing_db, False
db = await open_db(app.state.settings.database_path)
return db, True
async def _get_last_archive_ts(db) -> int | None: async def _get_last_archive_ts(db) -> int | None:
async with db.execute("SELECT MAX(timeofban) FROM history_archive") as cur: async with db.execute("SELECT MAX(timeofban) FROM history_archive") as cur:
row = await cur.fetchone() row = await cur.fetchone()
@@ -38,8 +51,8 @@ async def _get_last_archive_ts(db) -> int | None:
async def _run_sync(app: FastAPI) -> None: async def _run_sync(app: FastAPI) -> None:
db = app.state.db
socket_path: str = app.state.settings.fail2ban_socket socket_path: str = app.state.settings.fail2ban_socket
db, close_db = await _get_db(app)
try: try:
last_ts = await _get_last_archive_ts(db) last_ts = await _get_last_archive_ts(db)
@@ -90,6 +103,9 @@ async def _run_sync(app: FastAPI) -> None:
except Exception: except Exception:
log.exception("history_sync_failed") log.exception("history_sync_failed")
finally:
if close_db:
await db.close()
def register(app: FastAPI) -> None: def register(app: FastAPI) -> None: