Fix geo cache write performance: batch commits, read-only GETs, dirty flush

- Remove per-IP db.commit() from _persist_entry() and _persist_neg_entry();
  add a single commit after the full lookup_batch() chunk loop instead.
  Reduces commits from ~5,200 to 1 per bans/by-country request.

- Remove db dependency from GET /api/dashboard/bans and
  GET /api/dashboard/bans/by-country; pass app_db=None so no SQLite
  writes occur during read-only requests.

- Add _dirty set to geo_service; _store() marks resolved IPs dirty.
  New flush_dirty(db) batch-upserts all dirty entries in one transaction.
  New geo_cache_flush APScheduler task flushes every 60 s so geo data
  is persisted without blocking requests.
This commit is contained in:
2026-03-10 18:45:58 +01:00
parent 0225f32901
commit 44a5a3d70e
6 changed files with 505 additions and 34 deletions

View File

@@ -34,7 +34,7 @@ from starlette.middleware.base import BaseHTTPMiddleware
from app.config import Settings, get_settings
from app.db import init_db
from app.routers import auth, bans, blocklist, config, dashboard, geo, health, history, jails, server, setup
from app.tasks import blocklist_import, health_check
from app.tasks import blocklist_import, geo_cache_flush, health_check
from app.utils.fail2ban_client import Fail2BanConnectionError, Fail2BanProtocolError
# ---------------------------------------------------------------------------
@@ -151,6 +151,9 @@ async def _lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
# --- Blocklist import scheduled task ---
blocklist_import.register(app)
# --- Periodic geo cache flush to SQLite ---
geo_cache_flush.register(app)
log.info("bangui_started")
try:

View File

@@ -9,16 +9,14 @@ Also provides ``GET /api/dashboard/bans`` for the dashboard ban-list table.
from __future__ import annotations
from typing import TYPE_CHECKING, Annotated
import aiosqlite
from typing import TYPE_CHECKING
if TYPE_CHECKING:
import aiohttp
from fastapi import APIRouter, Depends, Query, Request
from fastapi import APIRouter, Query, Request
from app.dependencies import AuthDep, get_db
from app.dependencies import AuthDep
from app.models.ban import (
BanOrigin,
BansByCountryResponse,
@@ -77,7 +75,6 @@ async def get_server_status(
async def get_dashboard_bans(
request: Request,
_auth: AuthDep,
db: Annotated[aiosqlite.Connection, Depends(get_db)],
range: TimeRange = Query(default=_DEFAULT_RANGE, description="Time-range preset."),
page: int = Query(default=1, ge=1, description="1-based page number."),
page_size: int = Query(default=_DEFAULT_PAGE_SIZE, ge=1, le=500, description="Items per page."),
@@ -90,12 +87,13 @@ async def get_dashboard_bans(
Reads from the fail2ban database and enriches each entry with
geolocation data (country, ASN, organisation) from the ip-api.com
free API. Results are sorted newest-first.
free API. Results are sorted newest-first. Geo lookups are served
from the in-memory cache only; no database writes occur during this
GET request.
Args:
request: The incoming request (used to access ``app.state``).
_auth: Validated session dependency.
db: BanGUI application database (for persistent geo cache writes).
range: Time-range preset — ``"24h"``, ``"7d"``, ``"30d"``, or
``"365d"``.
page: 1-based page number.
@@ -115,7 +113,7 @@ async def get_dashboard_bans(
page=page,
page_size=page_size,
http_session=http_session,
app_db=db,
app_db=None,
origin=origin,
)
@@ -128,7 +126,6 @@ async def get_dashboard_bans(
async def get_bans_by_country(
request: Request,
_auth: AuthDep,
db: Annotated[aiosqlite.Connection, Depends(get_db)],
range: TimeRange = Query(default=_DEFAULT_RANGE, description="Time-range preset."),
origin: BanOrigin | None = Query(
default=None,
@@ -139,12 +136,13 @@ async def get_bans_by_country(
Uses SQL aggregation (``GROUP BY ip``) and batch geo-resolution to handle
10 000+ banned IPs efficiently. Returns a ``{country_code: count}`` map
and the 200 most recent raw ban rows for the companion access table.
and the 200 most recent raw ban rows for the companion access table. Geo
lookups are served from the in-memory cache only; no database writes occur
during this GET request.
Args:
request: The incoming request.
_auth: Validated session dependency.
db: BanGUI application database (for persistent geo cache writes).
range: Time-range preset.
origin: Optional filter by ban origin.
@@ -159,7 +157,7 @@ async def get_bans_by_country(
socket_path,
range,
http_session=http_session,
app_db=db,
app_db=None,
origin=origin,
)

View File

@@ -118,6 +118,10 @@ _cache: dict[str, GeoInfo] = {}
#: Entries within :data:`_NEG_CACHE_TTL` seconds are not re-queried.
_neg_cache: dict[str, float] = {}
#: IPs added to :data:`_cache` but not yet persisted to the database.
#: Consumed and cleared atomically by :func:`flush_dirty`.
_dirty: set[str] = set()
#: Optional MaxMind GeoLite2 reader initialised by :func:`init_geoip`.
_geoip_reader: geoip2.database.Reader | None = None
@@ -125,10 +129,12 @@ _geoip_reader: geoip2.database.Reader | None = None
def clear_cache() -> None:
"""Flush both the positive and negative lookup caches.
Useful in tests and when the operator suspects stale data.
Also clears the dirty set so any pending-but-unpersisted entries are
discarded. Useful in tests and when the operator suspects stale data.
"""
_cache.clear()
_neg_cache.clear()
_dirty.clear()
def clear_neg_cache() -> None:
@@ -256,7 +262,6 @@ async def _persist_entry(
""",
(ip, info.country_code, info.country_name, info.asn, info.org),
)
await db.commit()
async def _persist_neg_entry(db: aiosqlite.Connection, ip: str) -> None:
@@ -273,7 +278,6 @@ async def _persist_neg_entry(db: aiosqlite.Connection, ip: str) -> None:
"INSERT OR IGNORE INTO geo_cache (ip) VALUES (?)",
(ip,),
)
await db.commit()
# ---------------------------------------------------------------------------
@@ -330,6 +334,7 @@ async def lookup(
if result.country_code is not None and db is not None:
try:
await _persist_entry(db, ip, result)
await db.commit()
except Exception as exc: # noqa: BLE001
log.warning("geo_persist_failed", ip=ip, error=str(exc))
log.debug("geo_lookup_success", ip=ip, country=result.country_code, asn=result.asn)
@@ -350,6 +355,7 @@ async def lookup(
if fallback.country_code is not None and db is not None:
try:
await _persist_entry(db, ip, fallback)
await db.commit()
except Exception as exc: # noqa: BLE001
log.warning("geo_persist_failed", ip=ip, error=str(exc))
log.debug("geo_geoip_fallback_success", ip=ip, country=fallback.country_code)
@@ -360,6 +366,7 @@ async def lookup(
if db is not None:
try:
await _persist_neg_entry(db, ip)
await db.commit()
except Exception as exc: # noqa: BLE001
log.warning("geo_persist_neg_failed", ip=ip, error=str(exc))
@@ -449,6 +456,12 @@ async def lookup_batch(
except Exception as exc: # noqa: BLE001
log.warning("geo_persist_neg_failed", ip=ip, error=str(exc))
if db is not None:
try:
await db.commit()
except Exception as exc: # noqa: BLE001
log.warning("geo_batch_commit_failed", error=str(exc))
log.info(
"geo_batch_lookup_complete",
requested=len(uncached),
@@ -561,11 +574,77 @@ def _str_or_none(value: object) -> str | None:
def _store(ip: str, info: GeoInfo) -> None:
    """Cache *info* under *ip*, wiping everything first when at capacity.

    Successfully resolved entries (``country_code is not None``) are also
    recorded in the :data:`_dirty` set so that :func:`flush_dirty` can
    persist them to the database on the next scheduled flush.

    Args:
        ip: The IP address key.
        info: The :class:`GeoInfo` to store.
    """
    at_capacity = len(_cache) >= _MAX_CACHE_SIZE
    if at_capacity:
        # Whole-cache eviction: pending dirty entries would dangle, so drop them too.
        _cache.clear()
        _dirty.clear()
        log.info("geo_cache_flushed", reason="capacity")

    _cache[ip] = info

    resolved = info.country_code is not None
    if resolved:
        _dirty.add(ip)
async def flush_dirty(db: aiosqlite.Connection) -> int:
    """Batch-upsert every pending in-memory geo entry into ``geo_cache``.

    Snapshots and clears :data:`_dirty` with no intervening ``await``
    (atomic under asyncio's single-threaded scheduling), then writes all
    snapshot entries still present in :data:`_cache` using one
    ``executemany`` call and a single ``COMMIT``. After startup this is the
    only writer of the persistent cache during normal operation.

    If the database write fails, the snapshot is merged back into
    :data:`_dirty` so the entries are retried on the next flush cycle.

    Args:
        db: Open :class:`aiosqlite.Connection` to the BanGUI application
            database.

    Returns:
        The number of rows successfully upserted (0 when nothing was
        pending or the write failed).
    """
    if not _dirty:
        return 0

    # Snapshot, then clear. No await between the two statements, so no
    # other task can interleave and mutate the set mid-snapshot.
    to_flush = _dirty.copy()
    _dirty.clear()

    rows = []
    for ip in to_flush:
        info = _cache.get(ip)
        if info is not None:  # entry may have been evicted since it was marked dirty
            rows.append((ip, info.country_code, info.country_name, info.asn, info.org))

    if not rows:
        return 0

    try:
        await db.executemany(
            """
            INSERT INTO geo_cache (ip, country_code, country_name, asn, org)
            VALUES (?, ?, ?, ?, ?)
            ON CONFLICT(ip) DO UPDATE SET
                country_code = excluded.country_code,
                country_name = excluded.country_name,
                asn = excluded.asn,
                org = excluded.org,
                cached_at = strftime('%Y-%m-%dT%H:%M:%fZ', 'now')
            """,
            rows,
        )
        await db.commit()
    except Exception as exc:  # noqa: BLE001
        log.warning("geo_flush_dirty_failed", error=str(exc))
        # Merge the snapshot back so these IPs are retried on the next cycle.
        _dirty.update(to_flush)
        return 0

    log.info("geo_flush_dirty_complete", count=len(rows))
    return len(rows)

View File

@@ -0,0 +1,66 @@
"""Geo cache flush background task.
Registers an APScheduler job that periodically persists newly resolved IP
geo entries from the in-memory ``_dirty`` set to the ``geo_cache`` table.
After Task 2 removed geo cache writes from GET requests, newly resolved IPs
are only held in the in-memory cache until this task flushes them. With the
default 60-second interval, at most one minute of new resolution results is
at risk on an unexpected process restart.
"""
from __future__ import annotations
from typing import TYPE_CHECKING, Any
import structlog
from app.services import geo_service
if TYPE_CHECKING:
from fastapi import FastAPI
log: structlog.stdlib.BoundLogger = structlog.get_logger()
#: How often the flush job fires (seconds). Configurable tuning constant.
GEO_FLUSH_INTERVAL: int = 60
#: Stable APScheduler job ID — ensures re-registration replaces, not duplicates.
JOB_ID: str = "geo_cache_flush"
async def _run_flush(app: Any) -> None:
"""Flush the geo service dirty set to the application database.
Reads shared resources from ``app.state`` and delegates to
:func:`~app.services.geo_service.flush_dirty`.
Args:
app: The :class:`fastapi.FastAPI` application instance passed via
APScheduler ``kwargs``.
"""
db = app.state.db
count = await geo_service.flush_dirty(db)
if count > 0:
log.debug("geo_cache_flush_ran", flushed=count)
def register(app: FastAPI) -> None:
"""Add (or replace) the geo cache flush job in the application scheduler.
Must be called after the scheduler has been started (i.e., inside the
lifespan handler, after ``scheduler.start()``).
Args:
app: The :class:`fastapi.FastAPI` application instance whose
``app.state.scheduler`` will receive the job.
"""
app.state.scheduler.add_job(
_run_flush,
trigger="interval",
seconds=GEO_FLUSH_INTERVAL,
kwargs={"app": app},
id=JOB_ID,
replace_existing=True,
)
log.info("geo_cache_flush_scheduled", interval_seconds=GEO_FLUSH_INTERVAL)

View File

@@ -356,3 +356,212 @@ class TestGeoipFallback:
assert result is not None
assert result.country_code is None
# ---------------------------------------------------------------------------
# Batch single-commit behaviour (Task 1)
# ---------------------------------------------------------------------------
def _make_batch_session(batch_response: list[dict[str, object]]) -> MagicMock:
"""Build a mock aiohttp.ClientSession for batch POST calls.
Args:
batch_response: The list that the mock response's ``json()`` returns.
Returns:
A :class:`MagicMock` with a ``post`` method wired as an async context.
"""
mock_resp = AsyncMock()
mock_resp.status = 200
mock_resp.json = AsyncMock(return_value=batch_response)
mock_ctx = AsyncMock()
mock_ctx.__aenter__ = AsyncMock(return_value=mock_resp)
mock_ctx.__aexit__ = AsyncMock(return_value=False)
session = MagicMock()
session.post = MagicMock(return_value=mock_ctx)
return session
def _make_async_db() -> MagicMock:
"""Build a minimal mock :class:`aiosqlite.Connection`.
Returns:
MagicMock with ``execute``, ``executemany``, and ``commit`` wired as
async coroutines.
"""
db = MagicMock()
db.execute = AsyncMock()
db.executemany = AsyncMock()
db.commit = AsyncMock()
return db
class TestLookupBatchSingleCommit:
    """lookup_batch() issues exactly one commit per call, not one per IP."""

    async def test_single_commit_for_multiple_ips(self) -> None:
        """A batch of N IPs produces exactly one db.commit(), not N."""
        ips = ["1.1.1.1", "2.2.2.2", "3.3.3.3"]
        # One successful ip-api.com-shaped record per requested IP.
        batch_response = [
            {"query": ip, "status": "success", "countryCode": "DE", "country": "Germany", "as": "AS1", "org": "Org"}
            for ip in ips
        ]
        session = _make_batch_session(batch_response)
        db = _make_async_db()
        await geo_service.lookup_batch(ips, session, db=db)  # type: ignore[arg-type]
        # The single chunk-level commit, not one per persisted entry.
        db.commit.assert_awaited_once()

    async def test_commit_called_even_on_failed_lookups(self) -> None:
        """A batch with all-failed lookups still triggers one commit."""
        ips = ["10.0.0.1", "10.0.0.2"]
        # "fail" status records exercise the negative-cache persistence path.
        batch_response = [
            {"query": ip, "status": "fail", "message": "private range"}
            for ip in ips
        ]
        session = _make_batch_session(batch_response)
        db = _make_async_db()
        await geo_service.lookup_batch(ips, session, db=db)  # type: ignore[arg-type]
        db.commit.assert_awaited_once()

    async def test_no_commit_when_db_is_none(self) -> None:
        """When db=None, no commit is attempted."""
        ips = ["1.1.1.1"]
        batch_response = [
            {"query": "1.1.1.1", "status": "success", "countryCode": "US", "country": "United States", "as": "AS15169", "org": "Google LLC"},
        ]
        session = _make_batch_session(batch_response)
        # Should not raise; without db there is nothing to commit.
        result = await geo_service.lookup_batch(ips, session, db=None)
        assert result["1.1.1.1"].country_code == "US"

    async def test_no_commit_for_all_cached_ips(self) -> None:
        """When all IPs are already cached, no HTTP call and no commit occur."""
        # Pre-seed the in-memory cache so the batch path short-circuits.
        geo_service._cache["5.5.5.5"] = GeoInfo(  # type: ignore[attr-defined]
            country_code="FR", country_name="France", asn="AS1", org="ISP"
        )
        db = _make_async_db()
        session = _make_batch_session([])
        result = await geo_service.lookup_batch(["5.5.5.5"], session, db=db)  # type: ignore[arg-type]
        assert result["5.5.5.5"].country_code == "FR"
        # Fully-cached batch: nothing to persist and nothing to fetch.
        db.commit.assert_not_awaited()
        session.post.assert_not_called()
# ---------------------------------------------------------------------------
# Dirty-set tracking and flush_dirty (Task 3)
# ---------------------------------------------------------------------------
class TestDirtySetTracking:
    """_store() marks successfully resolved IPs as dirty."""

    def test_successful_resolution_adds_to_dirty(self) -> None:
        """Storing a GeoInfo with a country_code adds the IP to _dirty."""
        info = GeoInfo(country_code="DE", country_name="Germany", asn="AS1", org="ISP")
        geo_service._store("1.2.3.4", info)  # type: ignore[attr-defined]
        assert "1.2.3.4" in geo_service._dirty  # type: ignore[attr-defined]

    def test_null_country_does_not_add_to_dirty(self) -> None:
        """Storing a GeoInfo with country_code=None must not pollute _dirty."""
        # Negative results are cached in memory but never persisted via flush.
        info = GeoInfo(country_code=None, country_name=None, asn=None, org=None)
        geo_service._store("10.0.0.1", info)  # type: ignore[attr-defined]
        assert "10.0.0.1" not in geo_service._dirty  # type: ignore[attr-defined]

    def test_clear_cache_also_clears_dirty(self) -> None:
        """clear_cache() must discard any pending dirty entries."""
        info = GeoInfo(country_code="US", country_name="United States", asn="AS1", org="ISP")
        geo_service._store("8.8.8.8", info)  # type: ignore[attr-defined]
        assert geo_service._dirty  # type: ignore[attr-defined]
        geo_service.clear_cache()
        assert not geo_service._dirty  # type: ignore[attr-defined]

    async def test_lookup_batch_populates_dirty(self) -> None:
        """After lookup_batch() with db=None, resolved IPs appear in _dirty."""
        ips = ["1.1.1.1", "2.2.2.2"]
        batch_response = [
            {"query": ip, "status": "success", "countryCode": "JP", "country": "Japan", "as": "AS7500", "org": "IIJ"}
            for ip in ips
        ]
        session = _make_batch_session(batch_response)
        # db=None: resolution must still mark entries for the scheduled flush.
        await geo_service.lookup_batch(ips, session, db=None)
        for ip in ips:
            assert ip in geo_service._dirty  # type: ignore[attr-defined]
class TestFlushDirty:
    """flush_dirty() persists dirty entries and clears the set."""

    async def test_flush_writes_and_clears_dirty(self) -> None:
        """flush_dirty() inserts all dirty IPs and clears _dirty afterwards."""
        info = GeoInfo(country_code="GB", country_name="United Kingdom", asn="AS2856", org="BT")
        geo_service._store("100.0.0.1", info)  # type: ignore[attr-defined]
        assert "100.0.0.1" in geo_service._dirty  # type: ignore[attr-defined]
        db = _make_async_db()
        count = await geo_service.flush_dirty(db)
        assert count == 1
        # One executemany + one commit for the whole batch.
        db.executemany.assert_awaited_once()
        db.commit.assert_awaited_once()
        assert "100.0.0.1" not in geo_service._dirty  # type: ignore[attr-defined]

    async def test_flush_returns_zero_when_nothing_dirty(self) -> None:
        """flush_dirty() returns 0 and makes no DB calls when _dirty is empty."""
        db = _make_async_db()
        count = await geo_service.flush_dirty(db)
        assert count == 0
        db.executemany.assert_not_awaited()
        db.commit.assert_not_awaited()

    async def test_flush_re_adds_to_dirty_on_db_error(self) -> None:
        """When the DB write fails, entries are re-added to _dirty for retry."""
        info = GeoInfo(country_code="AU", country_name="Australia", asn="AS1", org="ISP")
        geo_service._store("200.0.0.1", info)  # type: ignore[attr-defined]
        db = _make_async_db()
        # Simulate a write failure so flush_dirty's retry path is exercised.
        db.executemany = AsyncMock(side_effect=OSError("disk full"))
        count = await geo_service.flush_dirty(db)
        assert count == 0
        assert "200.0.0.1" in geo_service._dirty  # type: ignore[attr-defined]

    async def test_flush_batch_and_lookup_batch_integration(self) -> None:
        """lookup_batch() populates _dirty; flush_dirty() then persists them."""
        ips = ["10.1.2.3", "10.1.2.4"]
        batch_response = [
            {"query": ip, "status": "success", "countryCode": "CA", "country": "Canada", "as": "AS812", "org": "Bell"}
            for ip in ips
        ]
        session = _make_batch_session(batch_response)
        # Resolve without DB to populate only in-memory cache and _dirty.
        await geo_service.lookup_batch(ips, session, db=None)
        assert geo_service._dirty == set(ips)  # type: ignore[attr-defined]
        # Now flush to the DB.
        db = _make_async_db()
        count = await geo_service.flush_dirty(db)
        assert count == 2
        assert not geo_service._dirty  # type: ignore[attr-defined]
        db.commit.assert_awaited_once()