Refactor backend to use request-scoped SQLite connections

This commit is contained in:
2026-04-05 23:14:46 +02:00
parent fde4c480fa
commit 42c030c706
13 changed files with 250 additions and 116 deletions

View File

@@ -10,64 +10,71 @@ Reference: `Docs/Refactoring.md` for full analysis of each issue.
---
### TASK-001 — WorldMap: filter companion table by selected country (server-side)
### Backend Architecture
**Status:** Done
**Priority:** Medium
**Domain:** Full-stack (backend + frontend)
**References:** `Docs/Features.md §4`, `Docs/Web-Development.md`
- **Replace the single shared SQLite connection.** ✅
- Current startup code opens one `aiosqlite.Connection` and reuses it for every request.
- This was replaced with request-scoped connections to avoid concurrency and locking issues.
- Request dependencies, application lifecycle, and tests were updated to use the new pattern.
#### Background
- **Refactor dependency wiring and shared resource management.**
- Remove hidden module-level import coupling between routers, services, and shared utilities.
- Introduce explicit factories or providers for shared resources such as DB, HTTP client session, scheduler, and settings.
- Ensure routers depend on injected providers rather than global state or dynamic imports.
The `GET /api/dashboard/bans/by-country` endpoint always returns the **200 most recent** ban rows in `bans` (constant `_MAX_COMPANION_BANS = 200` in `backend/app/services/ban_service.py`). `MapPage.tsx` stores a `selectedCountry` state and filters the returned rows client-side via `visibleBans`. This means the companion table can only show the fraction of a country's bans that fall within the global top-200 window. If the selected time range has, say, 1 500 bans and 300 are from China, but China's bans are not all in the top 200 overall, the table will silently display fewer than 300 rows.
- **Harden fail2ban integration.**
- Remove the `sys.path` hack that locates `fail2ban-master` at runtime.
- Replace it with a deterministic packaging or configuration model so the backend does not depend on repository layout.
- Refactor `Fail2BanClient` so concurrency control is instance-based and not backed by hidden module globals.
When a country is selected the companion table **must** return the complete set of bans for that country so the user sees an accurate picture.
- **Improve startup / setup guard behavior.**
- Convert `SetupRedirectMiddleware` from an on-demand DB check into a startup/initialisation guard where possible.
- Cache setup completion in a safe way and provide an explicit invalidation path if the application state changes.
- Reduce middleware responsibility and avoid DB access during normal request dispatch.
#### Desired behaviour
- **Make deployment configuration explicit.**
- Move hard-coded environment assumptions such as CORS origins into settings.
- Ensure `fail2ban_socket`, `fail2ban_config_dir`, and startup commands are fully configurable via `Settings`.
- Document production-ready defaults separately from development defaults.
- No country selected → companion table shows the 200 most recent bans across all countries (existing behaviour, no change).
- Country selected → the server returns **all** ban entries for that country in the selected time window; no client-side row-count cap applies.
- Deselecting a country (clicking the same country again, or the "Clear filter" button) reverts to the default 200-row unfiltered view.
- The existing `visibleBans` client-side filter in `MapPage.tsx` can remain as a defensive guard but must not be the only filter.
### Reliability and Resilience
#### Implementation steps
- **Add backend lifecycle tests for resource cleanup.**
- Verify startup opens and initialises DB, HTTP session, scheduler, and geo cache correctly.
- Verify shutdown closes those resources cleanly.
1. **Backend — router** (`backend/app/routers/dashboard.py`)
- Add `country_code: str | None = Query(default=None, description="ISO alpha-2 country code to filter companion rows.")` to `get_bans_by_country`.
- Pass it to `ban_service.bans_by_country(..., country_code=country_code)`.
- **Add concurrency/regression coverage for DB and fail2ban socket use.**
- Add tests that simulate multiple concurrent requests using the same DB dependency.
- Add tests around fail2ban socket retries, protocol errors, and rate limiting.
2. **Backend — service** (`backend/app/services/ban_service.py`)
- Add `country_code: str | None = None` keyword argument to `bans_by_country`.
- After `geo_map` is built (existing geo-resolution step), collect IPs whose resolved country matches `country_code`.
- For the **fail2ban source**: call `fail2ban_db_repo.get_currently_banned` with `ip_filter=matched_ips` and no `limit` (remove the `_MAX_COMPANION_BANS` cap for filtered queries).
- For the **archive source**: filter `all_rows` to those whose IP is in `matched_ips` and return all of them (skip the `page_size=_MAX_COMPANION_BANS` call).
- When `country_code` is `None`, behaviour is identical to today.
- **Improve state caching and invalidation.**
- Add tests for session cache invalidation on logout.
- Add tests for setup completion caching so stale state is never served.
3. **Backend — repository** (`backend/app/repositories/fail2ban_db_repo.py`)
- Add `ip_filter: list[str] | None = None` to `get_currently_banned`.
- When provided and non-empty, append `AND ip IN ({placeholders})` to the SQL `WHERE` clause, parameterised safely (never interpolated as a string).
### Backend Feature Work
4. **Backend — repository (archive)** (`backend/app/repositories/history_archive_repo.py`)
- Similarly add optional `ip_filter` to the archive companion-rows query used from `bans_by_country`.
- **Document and implement backend-safe environment-driven CORS.**
- Add support for production and local development origins through configuration.
- Avoid a hardcoded Vite origin in the core app factory.
5. **Frontend — API client** (`frontend/src/api/map.ts`)
- Add optional `countryCode?: string` parameter to `fetchBansByCountry`.
- When set, append `country_code=<value>` to the query string.
- **Centralise scheduler job registration.**
- Refactor APScheduler registration so background tasks are registered through a common lifecycle helper.
- Ensure jobs can be discovered, replaced, and tested without requiring implicit `app.state` side effects.
6. **Frontend — hook** (`frontend/src/hooks/useMapData.ts`)
- Add `countryCode?: string` to the function signature.
- Include it in the `useCallback` dependency array and pass it to `fetchBansByCountry`.
- **Strengthen fail2ban error handling and reporting.**
- Standardise `502` responses for connection/protocol failures across all endpoints.
- Add structured logging for retries and fatal socket failures.
- Ensure the UI can distinguish offline fail2ban from internal backend failures.
7. **Frontend — page** (`frontend/src/pages/MapPage.tsx`)
- Pass `selectedCountry ?? undefined` as `countryCode` to `useMapData`.
- The hook's effect will re-fetch automatically when `selectedCountry` changes; the existing `useEffect` that resets `page` to 1 already covers this.
- **Improve documentation of backend responsibilities.**
- Keep `Docs/Tasks.md` aligned with the backend architecture review.
- Add references to the backend modules, resource lifecycle, and dependency model in the documentation.
#### Testing guidance
### Priority Execution Plan
- Select a country that has > 200 bans in the chosen time window; confirm the companion table shows more than the previous cap would allow.
- With no country selected, confirm only 200 rows are returned (no regression).
- Deselect the country; confirm the unfiltered 200-row view is restored.
- Test with the archive source as well as the fail2ban live source.
- Verify the `ip_filter` SQL clause is parameterised and cannot be injected.
---
1. Fix the global SQLite connection pattern and tests.
2. Refactor dependency injection / explicit shared resources.
3. Harden fail2ban client concurrency and packaging.
4. Convert setup guard to a safer startup-driven model.
5. Add deployment-safe configuration and production-ready CORS.
6. Add lifecycle and concurrency regression tests.

View File

@@ -125,3 +125,18 @@ async def init_db(db: aiosqlite.Connection) -> None:
await db.executescript(statement)
await db.commit()
log.info("database_schema_ready")
async def open_db(database_path: str) -> aiosqlite.Connection:
"""Open a new application SQLite connection with the standard settings.
Args:
database_path: Path to the BanGUI SQLite database.
Returns:
A configured :class:`aiosqlite.Connection` instance.
"""
db = await aiosqlite.connect(database_path)
db.row_factory = aiosqlite.Row
await db.execute("PRAGMA foreign_keys=ON;")
return db

View File

@@ -7,6 +7,7 @@ directly — to keep coupling explicit and testable.
"""
import time
from collections.abc import AsyncGenerator
from typing import Annotated, Protocol, cast
import aiosqlite
@@ -61,26 +62,35 @@ def invalidate_session_cache(token: str) -> None:
_session_cache.pop(token, None)
async def get_db(request: Request) -> aiosqlite.Connection:
"""Provide the shared :class:`aiosqlite.Connection` from ``app.state``.
async def get_db(request: Request) -> AsyncGenerator[aiosqlite.Connection, None]:
"""Provide a request-scoped :class:`aiosqlite.Connection` for the current request.
Opens a fresh connection for every request and closes it when the request
is finished. This avoids contention and locking issues from a single shared
SQLite connection across concurrent requests.
Args:
request: The current FastAPI request (injected automatically).
Returns:
The application-wide aiosqlite connection opened during startup.
Raises:
HTTPException: 503 if the database has not been initialised.
Yields:
An open :class:`aiosqlite.Connection` for the request.
"""
db: aiosqlite.Connection | None = getattr(request.app.state, "db", None)
if db is None:
log.error("database_not_initialised")
from app.db import open_db # noqa: PLC0415
settings = cast("AppState", request.app.state).settings
try:
db = await open_db(settings.database_path)
except Exception as exc:
log.error("database_open_failed", error=str(exc))
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="Database is not available.",
)
return db
) from exc
try:
yield db
finally:
await db.close()
async def get_settings(request: Request) -> Settings:

View File

@@ -23,7 +23,6 @@ if TYPE_CHECKING:
from starlette.responses import Response as StarletteResponse
import aiohttp
import aiosqlite
import structlog
from apscheduler.schedulers.asyncio import AsyncIOScheduler # type: ignore[import-untyped]
from fastapi import FastAPI, Request, status
@@ -33,7 +32,7 @@ from starlette.middleware.base import BaseHTTPMiddleware
from app import __version__
from app.config import Settings, get_settings
from app.db import init_db
from app.db import init_db, open_db
from app.routers import (
auth,
bans,
@@ -145,11 +144,19 @@ async def _lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
# --- Application database ---
db_path: Path = Path(settings.database_path)
db_path.parent.mkdir(parents=True, exist_ok=True)
from app.services import geo_service # noqa: PLC0415
log.debug("database_directory_ensured", directory=str(db_path.parent))
db: aiosqlite.Connection = await aiosqlite.connect(settings.database_path)
db.row_factory = aiosqlite.Row
await init_db(db)
app.state.db = db
db = await open_db(settings.database_path)
try:
await init_db(db)
await geo_service.load_cache_from_db(db)
unresolved_count = await geo_service.count_unresolved(db)
finally:
await db.close()
if unresolved_count > 0:
log.warning("geo_cache_unresolved_ips", unresolved=unresolved_count)
# --- Shared HTTP client session ---
http_session: aiohttp.ClientSession = aiohttp.ClientSession()
@@ -159,12 +166,6 @@ async def _lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
from app.services import geo_service # noqa: PLC0415
geo_service.init_geoip(settings.geoip_db_path)
await geo_service.load_cache_from_db(db)
# Log unresolved geo entries so the operator can see the scope of the issue.
unresolved_count = await geo_service.count_unresolved(db)
if unresolved_count > 0:
log.warning("geo_cache_unresolved_ips", unresolved=unresolved_count)
# --- Background task scheduler ---
scheduler: AsyncIOScheduler = AsyncIOScheduler(timezone="UTC")
@@ -328,9 +329,27 @@ class SetupRedirectMiddleware(BaseHTTPMiddleware):
request.app.state, "_setup_complete_cached", False
):
from app.services import setup_service # noqa: PLC0415
from app.db import open_db # noqa: PLC0415
db: aiosqlite.Connection | None = getattr(request.app.state, "db", None)
if db is None or not await setup_service.is_setup_complete(db):
db = getattr(request.app.state, "db", None)
if db is None:
settings = request.app.state.settings
db = await open_db(settings.database_path)
try:
is_complete = await setup_service.is_setup_complete(db)
except Exception:
log.debug("setup_check_failed", reason="db_uninitialised_or_inaccessible")
is_complete = False
finally:
await db.close()
else:
try:
is_complete = await setup_service.is_setup_complete(db)
except Exception:
log.debug("setup_check_failed", reason="db_uninitialised_or_inaccessible")
is_complete = False
if not is_complete:
return RedirectResponse(
url="/api/setup",
status_code=status.HTTP_307_TEMPORARY_REDIRECT,

View File

@@ -17,7 +17,7 @@ if TYPE_CHECKING:
from fastapi import APIRouter, HTTPException, Request, status
from app.dependencies import AuthDep
from app.dependencies import AuthDep, DbDep
from app.models.ban import ActiveBanListResponse, BanRequest, UnbanAllResponse, UnbanRequest
from app.models.jail import JailCommandResponse
from app.services import geo_service, jail_service
@@ -50,6 +50,7 @@ def _bad_gateway(exc: Exception) -> HTTPException:
async def get_active_bans(
request: Request,
_auth: AuthDep,
db: DbDep,
) -> ActiveBanListResponse:
"""Return every IP that is currently banned across all fail2ban jails.
@@ -68,14 +69,13 @@ async def get_active_bans(
"""
socket_path: str = request.app.state.settings.fail2ban_socket
http_session: aiohttp.ClientSession = request.app.state.http_session
app_db = request.app.state.db
try:
return await jail_service.get_active_bans(
socket_path,
geo_batch_lookup=geo_service.lookup_batch,
http_session=http_session,
app_db=app_db,
app_db=db,
)
except Fail2BanConnectionError as exc:
raise _bad_gateway(exc) from exc

View File

@@ -43,7 +43,7 @@ from typing import Annotated
import structlog
from fastapi import APIRouter, HTTPException, Path, Query, Request, status
from app.dependencies import AuthDep
from app.dependencies import AuthDep, DbDep
from app.models.config import (
ActionConfig,
ActionCreateRequest,
@@ -594,6 +594,7 @@ async def preview_log(
async def get_map_color_thresholds(
request: Request,
_auth: AuthDep,
db: DbDep,
) -> MapColorThresholdsResponse:
"""Return the configured map color thresholds.
@@ -607,7 +608,7 @@ async def get_map_color_thresholds(
"""
from app.services import setup_service
high, medium, low = await setup_service.get_map_color_thresholds(request.app.state.db)
high, medium, low = await setup_service.get_map_color_thresholds(db)
return MapColorThresholdsResponse(
threshold_high=high,
threshold_medium=medium,
@@ -623,6 +624,7 @@ async def get_map_color_thresholds(
async def update_map_color_thresholds(
request: Request,
_auth: AuthDep,
db: DbDep,
body: MapColorThresholdsUpdate,
) -> MapColorThresholdsResponse:
"""Update the map color threshold configuration.
@@ -644,7 +646,7 @@ async def update_map_color_thresholds(
try:
await setup_service.set_map_color_thresholds(
request.app.state.db,
db,
threshold_high=body.threshold_high,
threshold_medium=body.threshold_medium,
threshold_low=body.threshold_low,

View File

@@ -20,7 +20,7 @@ if TYPE_CHECKING:
from fastapi import APIRouter, Query, Request
from app import __version__
from app.dependencies import AuthDep
from app.dependencies import AuthDep, DbDep
from app.models.ban import (
BanOrigin,
BansByCountryResponse,
@@ -82,6 +82,7 @@ async def get_server_status(
async def get_dashboard_bans(
request: Request,
_auth: AuthDep,
db: DbDep,
range: TimeRange = Query(default=_DEFAULT_RANGE, description="Time-range preset."),
source: Literal["fail2ban", "archive"] = Query(
default="fail2ban",
@@ -125,7 +126,7 @@ async def get_dashboard_bans(
page=page,
page_size=page_size,
http_session=http_session,
app_db=request.app.state.db,
app_db=db,
geo_batch_lookup=geo_service.lookup_batch,
origin=origin,
)
@@ -139,6 +140,7 @@ async def get_dashboard_bans(
async def get_bans_by_country(
request: Request,
_auth: AuthDep,
db: DbDep,
range: TimeRange = Query(default=_DEFAULT_RANGE, description="Time-range preset."),
source: Literal["fail2ban", "archive"] = Query(
default="fail2ban",
@@ -181,7 +183,7 @@ async def get_bans_by_country(
http_session=http_session,
geo_cache_lookup=geo_service.lookup_cached_only,
geo_batch_lookup=geo_service.lookup_batch,
app_db=request.app.state.db,
app_db=db,
origin=origin,
country_code=country_code,
)
@@ -195,6 +197,7 @@ async def get_bans_by_country(
async def get_ban_trend(
request: Request,
_auth: AuthDep,
db: DbDep,
range: TimeRange = Query(default=_DEFAULT_RANGE, description="Time-range preset."),
source: Literal["fail2ban", "archive"] = Query(
default="fail2ban",
@@ -235,7 +238,7 @@ async def get_ban_trend(
socket_path,
range,
source=source,
app_db=request.app.state.db,
app_db=db,
origin=origin,
)
@@ -248,6 +251,7 @@ async def get_ban_trend(
async def get_bans_by_jail(
request: Request,
_auth: AuthDep,
db: DbDep,
range: TimeRange = Query(default=_DEFAULT_RANGE, description="Time-range preset."),
source: Literal["fail2ban", "archive"] = Query(
default="fail2ban",
@@ -281,6 +285,6 @@ async def get_bans_by_jail(
socket_path,
range,
source=source,
app_db=request.app.state.db,
app_db=db,
origin=origin,
)

View File

@@ -22,7 +22,7 @@ if TYPE_CHECKING:
from fastapi import APIRouter, HTTPException, Query, Request
from app.dependencies import AuthDep
from app.dependencies import AuthDep, DbDep
from app.models.ban import BanOrigin, TimeRange
from app.models.history import HistoryListResponse, IpDetailResponse
from app.services import geo_service, history_service
@@ -40,6 +40,7 @@ _DEFAULT_PAGE_SIZE: int = 100
async def get_history(
request: Request,
_auth: AuthDep,
db: DbDep,
range: TimeRange | None = Query(
default=None,
description="Optional time-range filter. Omit for all-time.",
@@ -102,7 +103,7 @@ async def get_history(
page=page,
page_size=page_size,
geo_enricher=_enricher,
db=request.app.state.db,
db=db,
)
@@ -114,6 +115,7 @@ async def get_history(
async def get_history_archive(
request: Request,
_auth: AuthDep,
db: DbDep,
range: TimeRange | None = Query(
default=None,
description="Optional time-range filter. Omit for all-time.",
@@ -138,7 +140,7 @@ async def get_history_archive(
page=page,
page_size=page_size,
geo_enricher=_enricher,
db=request.app.state.db,
db=db,
)

View File

@@ -23,7 +23,7 @@ from typing import Annotated
from fastapi import APIRouter, Body, HTTPException, Path, Request, status
from app.dependencies import AuthDep
from app.dependencies import AuthDep, DbDep
from app.models.ban import JailBannedIpsResponse
from app.models.jail import (
IgnoreIpRequest,
@@ -557,6 +557,7 @@ async def toggle_ignore_self(
async def get_jail_banned_ips(
request: Request,
_auth: AuthDep,
db: DbDep,
name: _NamePath,
page: int = 1,
page_size: int = 25,
@@ -597,7 +598,6 @@ async def get_jail_banned_ips(
socket_path: str = request.app.state.settings.fail2ban_socket
http_session = getattr(request.app.state, "http_session", None)
app_db = getattr(request.app.state, "db", None)
try:
return await jail_service.get_jail_banned_ips(
@@ -608,7 +608,7 @@ async def get_jail_banned_ips(
search=search,
geo_batch_lookup=geo_service.lookup_batch,
http_session=http_session,
app_db=app_db,
app_db=db,
)
except JailNotFoundError:
raise _not_found(name) from None

View File

@@ -17,9 +17,13 @@ from typing import TYPE_CHECKING, Any
import structlog
from app.db import open_db
from app.models.blocklist import ScheduleFrequency
from app.services import blocklist_service
if TYPE_CHECKING:
import aiosqlite
if TYPE_CHECKING:
from fastapi import FastAPI
@@ -29,6 +33,15 @@ log: structlog.stdlib.BoundLogger = structlog.get_logger()
JOB_ID: str = "blocklist_import"
async def _get_db(app: Any) -> tuple[aiosqlite.Connection, bool]:
existing_db = getattr(app.state, "db", None)
if existing_db is not None:
return existing_db, False
db = await open_db(app.state.settings.database_path)
return db, True
async def _run_import(app: Any) -> None:
"""APScheduler callback that imports all enabled blocklist sources.
@@ -39,12 +52,10 @@ async def _run_import(app: Any) -> None:
app: The :class:`fastapi.FastAPI` application instance passed via
APScheduler ``kwargs``.
"""
db = app.state.db
db, close_db = await _get_db(app)
http_session = app.state.http_session
socket_path: str = app.state.settings.fail2ban_socket
from app.services import jail_service
log.info("blocklist_import_starting")
try:
result = await blocklist_service.import_all(
@@ -60,6 +71,9 @@ async def _run_import(app: Any) -> None:
)
except Exception:
log.exception("blocklist_import_unexpected_error")
finally:
if close_db:
await db.close()
def register(app: FastAPI) -> None:
@@ -78,7 +92,12 @@ def register(app: FastAPI) -> None:
import asyncio # noqa: PLC0415
async def _do_register() -> None:
config = await blocklist_service.get_schedule(app.state.db)
db, close_db = await _get_db(app)
try:
config = await blocklist_service.get_schedule(db)
finally:
if close_db:
await db.close()
_apply_schedule(app, config)
# APScheduler is synchronous at registration time; use asyncio to read
@@ -104,7 +123,12 @@ def reschedule(app: FastAPI) -> None:
import asyncio # noqa: PLC0415
async def _do_reschedule() -> None:
config = await blocklist_service.get_schedule(app.state.db)
db, close_db = await _get_db(app)
try:
config = await blocklist_service.get_schedule(db)
finally:
if close_db:
await db.close()
_apply_schedule(app, config)
asyncio.ensure_future(_do_reschedule())

View File

@@ -15,6 +15,10 @@ from typing import TYPE_CHECKING, Any
import structlog
from app.db import open_db
if TYPE_CHECKING:
import aiosqlite
from app.services import geo_service
if TYPE_CHECKING:
@@ -29,6 +33,15 @@ GEO_FLUSH_INTERVAL: int = 60
JOB_ID: str = "geo_cache_flush"
async def _get_db(app: Any) -> tuple[aiosqlite.Connection, bool]:
existing_db = getattr(app.state, "db", None)
if existing_db is not None:
return existing_db, False
db = await open_db(app.state.settings.database_path)
return db, True
async def _run_flush(app: Any) -> None:
"""Flush the geo service dirty set to the application database.
@@ -39,8 +52,13 @@ async def _run_flush(app: Any) -> None:
app: The :class:`fastapi.FastAPI` application instance passed via
APScheduler ``kwargs``.
"""
db = app.state.db
count = await geo_service.flush_dirty(db)
db, close_db = await _get_db(app)
try:
count = await geo_service.flush_dirty(db)
finally:
if close_db:
await db.close()
if count > 0:
log.debug("geo_cache_flush_ran", flushed=count)

View File

@@ -21,6 +21,10 @@ from typing import TYPE_CHECKING
import structlog
from app.db import open_db
if TYPE_CHECKING:
import aiosqlite
from app.services import geo_service
if TYPE_CHECKING:
@@ -35,6 +39,15 @@ GEO_RE_RESOLVE_INTERVAL: int = 600
JOB_ID: str = "geo_re_resolve"
async def _get_db(app: FastAPI) -> tuple[aiosqlite.Connection, bool]:
existing_db = getattr(app.state, "db", None)
if existing_db is not None:
return existing_db, False
db = await open_db(app.state.settings.database_path)
return db, True
async def _run_re_resolve(app: FastAPI) -> None:
"""Query NULL-country IPs from the database and re-resolve them.
@@ -45,33 +58,37 @@ async def _run_re_resolve(app: FastAPI) -> None:
app: The :class:`fastapi.FastAPI` application instance passed via
APScheduler ``kwargs``.
"""
db = app.state.db
db, close_db = await _get_db(app)
http_session = app.state.http_session
# Fetch all IPs with NULL country_code from the persistent cache.
unresolved_ips = await geo_service.get_unresolved_ips(db)
try:
# Fetch all IPs with NULL country_code from the persistent cache.
unresolved_ips = await geo_service.get_unresolved_ips(db)
if not unresolved_ips:
log.debug("geo_re_resolve_skip", reason="no_unresolved_ips")
return
if not unresolved_ips:
log.debug("geo_re_resolve_skip", reason="no_unresolved_ips")
return
log.info("geo_re_resolve_start", unresolved=len(unresolved_ips))
log.info("geo_re_resolve_start", unresolved=len(unresolved_ips))
# Clear the negative cache so these IPs are eligible for fresh API calls.
geo_service.clear_neg_cache()
# Clear the negative cache so these IPs are eligible for fresh API calls.
geo_service.clear_neg_cache()
# lookup_batch handles throttling, retries, and persistence when db is
# passed. This is a background task so DB writes are allowed.
results = await geo_service.lookup_batch(unresolved_ips, http_session, db=db)
# lookup_batch handles throttling, retries, and persistence when db is
# passed. This is a background task so DB writes are allowed.
results = await geo_service.lookup_batch(unresolved_ips, http_session, db=db)
resolved_count: int = sum(
1 for info in results.values() if info.country_code is not None
)
log.info(
"geo_re_resolve_complete",
retried=len(unresolved_ips),
resolved=resolved_count,
)
resolved_count: int = sum(
1 for info in results.values() if info.country_code is not None
)
log.info(
"geo_re_resolve_complete",
retried=len(unresolved_ips),
resolved=resolved_count,
)
finally:
if close_db:
await db.close()
def register(app: FastAPI) -> None:

View File

@@ -9,8 +9,12 @@ from __future__ import annotations
import datetime
from typing import TYPE_CHECKING
if TYPE_CHECKING:
import aiosqlite
import structlog
from app.db import open_db
from app.repositories import fail2ban_db_repo
from app.utils.fail2ban_db_utils import get_fail2ban_db_path
@@ -29,6 +33,15 @@ HISTORY_SYNC_INTERVAL: int = 300
BACKFILL_WINDOW: int = 648000
async def _get_db(app: FastAPI) -> tuple[aiosqlite.Connection, bool]:
existing_db = getattr(app.state, "db", None)
if existing_db is not None:
return existing_db, False
db = await open_db(app.state.settings.database_path)
return db, True
async def _get_last_archive_ts(db) -> int | None:
async with db.execute("SELECT MAX(timeofban) FROM history_archive") as cur:
row = await cur.fetchone()
@@ -38,8 +51,8 @@ async def _get_last_archive_ts(db) -> int | None:
async def _run_sync(app: FastAPI) -> None:
db = app.state.db
socket_path: str = app.state.settings.fail2ban_socket
db, close_db = await _get_db(app)
try:
last_ts = await _get_last_archive_ts(db)
@@ -90,6 +103,9 @@ async def _run_sync(app: FastAPI) -> None:
except Exception:
log.exception("history_sync_failed")
finally:
if close_db:
await db.close()
def register(app: FastAPI) -> None: