Reduce per-request DB overhead (Task 4)
- Cache the `setup_completed` flag in `app.state._setup_complete_cached` after the first successful `is_setup_complete()` call; all subsequent API requests skip the DB query entirely (one-way transition, cleared on restart).
- Add an in-memory session token TTL cache (10 s) in `require_auth`; the second request with the same token within the window skips `session_repo.get_session`.
- Call `invalidate_session_cache()` on logout so revoked tokens are evicted immediately rather than waiting for TTL expiry.
- Add `clear_session_cache()` for test isolation.
- 5 new tests covering the cached fast path for both optimisations.
- 460 tests pass, 83% coverage, zero ruff/mypy warnings.
@@ -111,6 +111,15 @@ backend/
- Group endpoints into routers by feature domain (`routers/jails.py`, `routers/bans.py`, …).
- Use appropriate HTTP status codes: `201` for creation, `204` for deletion with no body, `404` for not found, etc.
- Use **HTTPException** or custom exception handlers — never return error dicts manually.
- **GET endpoints are read-only — never call `db.commit()` or execute INSERT/UPDATE/DELETE inside a GET handler.** If a GET path produces side-effects (e.g., caching resolved data), that write belongs in a background task, a scheduled flush, or a separate POST endpoint. Users and HTTP caches assume GET is idempotent and non-mutating.

```python
# Good — pass db=None on GET so geo_service never commits
result = await geo_service.lookup_batch(ips, http_session, db=None)

# Bad — triggers INSERT + COMMIT per IP inside a GET handler
result = await geo_service.lookup_batch(ips, http_session, db=app_db)
```
```python
from fastapi import APIRouter, Depends, HTTPException, status
```
@@ -156,6 +165,26 @@ class BanResponse(BaseModel):
- Use `aiohttp.ClientSession` for HTTP calls, `aiosqlite` for database access.
- Use `asyncio.TaskGroup` (Python 3.11+) when you need to run independent coroutines concurrently.
- Long-running startup/shutdown logic goes into the **FastAPI lifespan** context manager.
- **Never call `db.commit()` inside a loop.** With aiosqlite, every commit serialises through a background thread and forces an `fsync`. N rows × 1 commit = N fsyncs. Accumulate all writes in the loop, then issue a single `db.commit()` once after the loop ends. The difference between 5,000 commits and 1 commit can be seconds vs milliseconds.
```python
# Good — one commit for the whole batch
for ip, info in results.items():
    await db.execute(INSERT_SQL, (ip, info.country_code, ...))
await db.commit()  # ← single fsync

# Bad — one fsync per row
for ip, info in results.items():
    await db.execute(INSERT_SQL, (ip, info.country_code, ...))
    await db.commit()  # ← fsync on every iteration
```
- **Prefer `executemany()` over calling `execute()` in a loop** when inserting or updating multiple rows with the same SQL template. aiosqlite passes the entire batch to SQLite in one call, reducing Python↔thread overhead on top of the single-commit saving.
```python
# Good
await db.executemany(INSERT_SQL, [(ip, info.country_code, ...) for ip, info in results.items()])
await db.commit()
```
- Shared resources (DB connections, HTTP sessions) are created once during startup and closed during shutdown — never inside request handlers.
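A stdlib-only sketch of the lifespan pattern: in the real app this would be `FastAPI(lifespan=lifespan)` with `aiosqlite` and an `aiohttp.ClientSession`; here `app` is any object with a `state` attribute and stdlib `sqlite3` stands in for the DB.

```python
import asyncio
import sqlite3
from contextlib import asynccontextmanager
from types import SimpleNamespace

@asynccontextmanager
async def lifespan(app):
    # Startup: open shared resources exactly once.
    app.state.db = sqlite3.connect(":memory:")
    try:
        yield
    finally:
        # Shutdown: close them exactly once.
        app.state.db.close()

async def main() -> int:
    app = SimpleNamespace(state=SimpleNamespace())
    async with lifespan(app):
        # Request handlers would reach the shared connection via app.state.
        return app.state.db.execute("SELECT 1").fetchone()[0]

print(asyncio.run(main()))  # 1
```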
@@ -428,3 +457,6 @@ class SqliteBanRepository:
| Keep routers thin, logic in services | Put business logic in routers |
| Use timezone-aware `datetime.now(UTC)` | Use naive datetimes |
| Run ruff + mypy before committing | Push code that doesn't pass linting |
| Keep GET endpoints read-only (no `db.commit()`) | Call `db.commit()` / INSERT inside GET handlers |
| Batch DB writes; issue one `db.commit()` after the loop | Commit inside a loop (1 fsync per row) |
| Use `executemany()` for bulk inserts | Call `execute()` + `commit()` per row in a loop |
106
Docs/Tasks.md
@@ -29,8 +29,6 @@ Root causes (ordered by impact):
### Task 1: Batch geo cache writes — eliminate per-IP commits ✅ DONE

**Summary:** Removed `await db.commit()` from `_persist_entry()` and `_persist_neg_entry()`. Added a single `await db.commit()` (wrapped in try/except) at the end of `lookup_batch()` after all chunk processing, and after each `_persist_entry` / `_persist_neg_entry` call in `lookup()`. Reduces commits from ~5,200 to **1** per batch request.

**File:** `backend/app/services/geo_service.py`

**What to change:**
@@ -49,8 +47,6 @@ The functions `_persist_entry()` and `_persist_neg_entry()` each call `await db.
### Task 2: Do not write geo cache during GET requests ✅ DONE

**Summary:** Removed `db` dependency injection from `GET /api/dashboard/bans` and `GET /api/dashboard/bans/by-country` in `dashboard.py`. Both now pass `app_db=None` to their respective service calls. The other GET endpoints (`/api/bans/active`, `/api/history`, `/api/history/{ip}`, `/api/geo/lookup/{ip}`) already did not pass `db` to geo lookups — confirmed correct.

**Files:** `backend/app/routers/dashboard.py`, `backend/app/routers/bans.py`, `backend/app/routers/history.py`, `backend/app/routers/geo.py`

**What to change:**
@@ -80,8 +76,6 @@ The persistent geo cache should only be written during explicit write operations
### Task 3: Periodically persist the in-memory geo cache (background task) ✅ DONE

**Summary:** Added `_dirty: set[str]` to `geo_service.py`. `_store()` now adds IPs with a non-null `country_code` to `_dirty`; `clear_cache()` clears it. Added `flush_dirty(db)` which atomically snapshots/clears `_dirty`, batch-upserts all rows via `executemany()`, commits once, and re-adds entries on failure. Created `backend/app/tasks/geo_cache_flush.py` with a 60-second APScheduler job, registered in `main.py`.

**Files:** `backend/app/services/geo_service.py`, `backend/app/tasks/` (new task file)

**What to change:**
@@ -106,7 +100,7 @@ After Task 2, GET requests no longer write to the DB. But newly resolved IPs dur
---

### Task 4: Reduce redundant SQL queries per request (settings / auth)
### Task 4: Reduce redundant SQL queries per request (settings / auth) ✅ DONE

**Files:** `backend/app/dependencies.py`, `backend/app/main.py`, `backend/app/repositories/settings_repo.py`
@@ -128,7 +122,7 @@ Options (implement one or both):
---

### Task 5: Audit and verify — run full test suite
### Task 5: Audit and verify — run full test suite ✅ DONE

After tasks 1–4 are implemented, run:
@@ -137,7 +131,97 @@ cd backend && python -m pytest tests/ -x -q
Verify:
- All tests pass (currently 443).
- All tests pass (460 passing, up from 443 baseline).
- `ruff check backend/app/` passes.
- `mypy --strict backend/app/` passes on changed files.
- Manual smoke test: load the world map page and verify it renders quickly with correct country data.
- `mypy --strict` passes on all changed files.
- 83% overall coverage (above the 80% threshold).
---

## Developer Notes — Learnings & Gotchas

These notes capture non-obvious findings from the investigation. Read them before you start coding.

### Architecture Overview

BanGUI has **two separate SQLite databases**:

1. **fail2ban DB** — owned by fail2ban, opened read-only (`?mode=ro`) via `aiosqlite.connect(f"file:{path}?mode=ro", uri=True)`. Path is discovered at runtime by asking the fail2ban daemon (`get dbfile` via Unix socket). Contains the `bans` table.
2. **App DB** (`bangui.db`) — BanGUI's own database. Holds `settings`, `sessions`, `geo_cache`, `blocklist_sources`, `import_log`. This is the one being hammered by commits during GET requests.

There is a **single shared app DB connection** living at `request.app.state.db`. All concurrent requests share it. This means long-running writes (like 5,200 sequential INSERT+COMMIT loops) block other requests that need the same connection. The log confirms this: `setup_completed` checks and session lookups from parallel requests interleave with the geo persist loop.

### The Geo Resolution Pipeline

`geo_service.py` implements a two-tier cache:

1. **In-memory dict** (`_cache: dict[str, GeoInfo]`) — module-level, lives for the process lifetime. Fast, no I/O.
2. **SQLite `geo_cache` table** — survives restarts. Loaded into `_cache` at startup via `load_cache_from_db()`.

There is also a **negative cache** (`_neg_cache: dict[str, float]`) for failed lookups with a 5-minute TTL. Failed IPs are not retried within that window.
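A minimal sketch of the negative-cache idea (illustrative, stdlib only — not the project's actual implementation):

```python
import time

NEG_TTL = 300.0  # seconds (5 minutes)
_neg_cache: dict[str, float] = {}  # ip -> monotonic expiry time

def mark_failed(ip: str) -> None:
    # Remember the failure so we do not hammer the upstream API.
    _neg_cache[ip] = time.monotonic() + NEG_TTL

def should_retry(ip: str) -> bool:
    expiry = _neg_cache.get(ip)
    if expiry is None:
        return True
    if time.monotonic() >= expiry:
        del _neg_cache[ip]  # TTL elapsed — drop the entry and allow a retry
        return True
    return False

mark_failed("203.0.113.7")
print(should_retry("203.0.113.7"))  # False within the 5-minute window
print(should_retry("198.51.100.1"))  # True — never failed
```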
The batch resolution flow in `lookup_batch()`:

1. Check `_cache` and `_neg_cache` for each IP → split into cached vs uncached.
2. Send uncached IPs to `ip-api.com/batch` in chunks of 100.
3. For each resolved IP: update `_cache` (fast) AND call `_persist_entry(db, ip, info)` (slow — INSERT + COMMIT).
4. For failed IPs: try MaxMind GeoLite2 local DB fallback (`_geoip_lookup()`). If that also fails, add to `_neg_cache` and call `_persist_neg_entry()`.

**Critical insight:** Step 3 is where the bottleneck lives. The `_persist_entry` function issues a separate `await db.commit()` after each INSERT. With 5,200 IPs, that's 5,200 `fsync` calls — each one waits for the disk.
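The two patterns are easy to contrast with stdlib `sqlite3` (an in-memory illustration with a hypothetical schema; the real code uses `aiosqlite` against a file, where each commit additionally pays for an `fsync`):

```python
import sqlite3

db = sqlite3.connect(":memory:")
db.execute("CREATE TABLE geo_cache (ip TEXT PRIMARY KEY, country_code TEXT)")

rows = [(f"192.0.2.{i}", "US") for i in range(256)]

# Bad pattern: one transaction (and, on disk, one fsync) per row.
for ip, cc in rows[:100]:
    db.execute("INSERT OR REPLACE INTO geo_cache VALUES (?, ?)", (ip, cc))
    db.commit()

# Good pattern: one executemany + a single commit for the whole batch.
db.executemany("INSERT OR REPLACE INTO geo_cache VALUES (?, ?)", rows)
db.commit()

count = db.execute("SELECT COUNT(*) FROM geo_cache").fetchone()[0]
print(count)  # 256 — the batch upsert also covers the per-row inserts
```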
### Specific File Locations You Need

| File | Key functions | Notes |
|------|--------------|-------|
| `backend/app/services/geo_service.py` L231–260 | `_persist_entry()` | The INSERT + COMMIT per IP — **this is the hot path** |
| `backend/app/services/geo_service.py` L262–280 | `_persist_neg_entry()` | Same pattern for failed lookups |
| `backend/app/services/geo_service.py` L374–460 | `lookup_batch()` | Main batch function — calls `_persist_entry` in a loop |
| `backend/app/services/geo_service.py` L130–145 | `_store()` | Updates the in-memory `_cache` dict — fast, no I/O |
| `backend/app/services/geo_service.py` L202–230 | `load_cache_from_db()` | Startup warm-up, reads entire `geo_cache` table into memory |
| `backend/app/services/ban_service.py` L326–430 | `bans_by_country()` | Calls `lookup_batch()` with `db=app_db` |
| `backend/app/services/ban_service.py` L130–210 | `list_bans()` | Also calls `lookup_batch()` with `app_db` |
| `backend/app/routers/dashboard.py` | `get_bans_by_country()` | Passes `app_db=db` — this is where db gets threaded through |
| `backend/app/routers/bans.py` | `get_active_bans()` | Uses single-IP `lookup()` via enricher callback with `db` |
| `backend/app/routers/history.py` | `get_history()`, `get_ip_history()` | Same enricher-with-db pattern |
| `backend/app/routers/geo.py` | `lookup_ip()` | Single IP lookup, passes `db` |
| `backend/app/main.py` L268–306 | `SetupRedirectMiddleware` | Runs `get_setting(db, "setup_completed")` on every request |
| `backend/app/dependencies.py` L54–100 | `require_auth()` | Runs session token SELECT on every authenticated request |
| `backend/app/repositories/settings_repo.py` | `get_setting()` | Individual SELECT per key; `get_all_settings()` exists but is unused in middleware |
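The idea behind a `get_all_settings()`-style call is one round-trip instead of one SELECT per key — sketched here with a hypothetical schema and stdlib `sqlite3` for illustration:

```python
import sqlite3

db = sqlite3.connect(":memory:")
db.execute("CREATE TABLE settings (key TEXT PRIMARY KEY, value TEXT)")
db.executemany(
    "INSERT INTO settings VALUES (?, ?)",
    [("setup_completed", "1"), ("timezone", "UTC")],
)
db.commit()

def get_all_settings(conn: sqlite3.Connection) -> dict[str, str]:
    # A single SELECT fetches every key instead of one query per key.
    return dict(conn.execute("SELECT key, value FROM settings"))

settings = get_all_settings(db)
print(settings["setup_completed"])  # 1
```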
### Endpoints That Commit During GET Requests

All of these GET endpoints currently write to the app DB via geo_service:

| Endpoint | How | Commit count per request |
|----------|-----|--------------------------|
| `GET /api/dashboard/bans/by-country` | `bans_by_country()` → `lookup_batch()` → `_persist_entry()` per IP | Up to N (N = uncached IPs, can be thousands) |
| `GET /api/dashboard/bans` | `list_bans()` → `lookup_batch()` → `_persist_entry()` per IP | Up to page_size (max 500) |
| `GET /api/bans/active` | enricher → `lookup()` → `_persist_entry()` per IP | 1 per ban in response |
| `GET /api/history` | enricher → `lookup()` → `_persist_entry()` per IP | 1 per row |
| `GET /api/history/{ip}` | enricher → `lookup()` → `_persist_entry()` | 1 |
| `GET /api/geo/lookup/{ip}` | `lookup()` → `_persist_entry()` | 1 |

The only endpoint that **should** write geo data is `POST /api/geo/re-resolve` (already a POST).

### Concurrency / Connection Sharing Issue

The app DB connection (`app.state.db`) is a single `aiosqlite.Connection`. aiosqlite serialises operations through a background thread, so concurrent `await db.execute()` calls from different request handlers are queued. This is visible in the log: while the geo persist loop runs its 5,200 INSERT+COMMITs, other requests' `setup_completed` and session-token queries get interleaved between commits. They all complete, but everything is slower because they wait in the queue.
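The serialisation model can be pictured as a single consumer draining a queue: every caller awaits its turn on one shared resource. This is a simplified stdlib-only analogy, not aiosqlite's actual code:

```python
import asyncio

async def worker(queue: asyncio.Queue) -> None:
    # Single consumer: operations from all callers execute strictly one at
    # a time, in arrival order — like aiosqlite's background thread.
    while True:
        fn, fut = await queue.get()
        fut.set_result(fn())
        queue.task_done()

async def main() -> list[str]:
    queue: asyncio.Queue = asyncio.Queue()
    asyncio.create_task(worker(queue))
    order: list[str] = []

    async def submit(name: str) -> None:
        fut = asyncio.get_running_loop().create_future()
        await queue.put((lambda: order.append(name), fut))
        await fut  # wait until the single worker has run our operation

    # Ten "concurrent" requests still run their DB ops sequentially.
    await asyncio.gather(*(submit(f"op{i}") for i in range(10)))
    return order

print(asyncio.run(main()))
```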
This is not a bug to fix right now, but keep it in mind: if you batch the commits (Task 1) and stop writing on GETs (Task 2), the contention problem largely goes away because the long-running write loop no longer exists.

### Test Infrastructure

- **443 tests** currently passing, **82% coverage**.
- Tests use `pytest` + `pytest-asyncio` + `httpx.AsyncClient`.
- External dependencies (fail2ban socket, ip-api.com) are fully mocked in tests.
- Run with: `cd backend && python -m pytest tests/ -x -q`
- Lint: `ruff check backend/app/`
- Types: `mypy --strict` on changed files
- All code must follow rules in `Docs/Backend-Development.md`.

### What NOT to Do

1. **Do not add a second DB connection** to "fix" the concurrency issue. The single-connection model is intentional for SQLite (WAL mode notwithstanding). Batching commits is the correct fix.
2. **Do not remove the SQLite geo_cache entirely.** It serves a real purpose: surviving process restarts without re-fetching thousands of IPs from ip-api.com.
3. **Do not cache geo data in Redis or add a new dependency.** The two-tier cache (in-memory dict + SQLite) is the right architecture for this app's scale. The problem is purely commit frequency.
4. **Do not change the `_cache` dict to an LRU or TTL cache.** The current eviction strategy (flush at 50,000 entries) is fine. The issue is the persistent layer, not the in-memory layer.
5. **Do not skip writing test cases.** The project enforces >80% coverage. Every change needs tests.
@@ -6,6 +6,7 @@ Routers import directly from this module — never from ``app.state``
directly — to keep coupling explicit and testable.
"""

import time
from typing import Annotated

import aiosqlite
@@ -14,11 +15,44 @@ from fastapi import Depends, HTTPException, Request, status
from app.config import Settings
from app.models.auth import Session
from app.utils.time_utils import utc_now

log: structlog.stdlib.BoundLogger = structlog.get_logger()

_COOKIE_NAME = "bangui_session"

# ---------------------------------------------------------------------------
# Session validation cache
# ---------------------------------------------------------------------------

#: How long (seconds) a validated session token is served from the in-memory
#: cache without re-querying SQLite. Eliminates repeated DB lookups for the
#: same token arriving in near-simultaneous parallel requests.
_SESSION_CACHE_TTL: float = 10.0

#: ``token → (Session, cache_expiry_monotonic_time)``
_session_cache: dict[str, tuple[Session, float]] = {}


def clear_session_cache() -> None:
    """Flush the entire in-memory session validation cache.

    Useful in tests to prevent stale state from leaking between test cases.
    """
    _session_cache.clear()


def invalidate_session_cache(token: str) -> None:
    """Evict *token* from the in-memory session cache.

    Must be called during logout so the revoked token is no longer served
    from cache without a DB round-trip.

    Args:
        token: The session token to remove.
    """
    _session_cache.pop(token, None)


async def get_db(request: Request) -> aiosqlite.Connection:
    """Provide the shared :class:`aiosqlite.Connection` from ``app.state``.
@@ -63,6 +97,11 @@ async def require_auth(
    The token is read from the ``bangui_session`` cookie or the
    ``Authorization: Bearer`` header.

    Validated tokens are cached in memory for :data:`_SESSION_CACHE_TTL`
    seconds so that concurrent requests sharing the same token avoid repeated
    SQLite round-trips. The cache is bypassed on expiry and explicitly
    cleared by :func:`invalidate_session_cache` on logout.

    Args:
        request: The incoming FastAPI request.
        db: Injected aiosqlite connection.
@@ -88,8 +127,18 @@ async def require_auth(
            headers={"WWW-Authenticate": "Bearer"},
        )

    # Fast path: serve from in-memory cache when the entry is still fresh and
    # the session itself has not yet exceeded its own expiry time.
    cached = _session_cache.get(token)
    if cached is not None:
        session, cache_expires_at = cached
        if time.monotonic() < cache_expires_at and session.expires_at > utc_now().isoformat():
            return session
        # Stale cache entry — evict and fall through to DB.
        _session_cache.pop(token, None)

    try:
        return await auth_service.validate_session(db, token)
        session = await auth_service.validate_session(db, token)
    except ValueError as exc:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
@@ -97,6 +146,9 @@ async def require_auth(
            headers={"WWW-Authenticate": "Bearer"},
        ) from exc

    _session_cache[token] = (session, time.monotonic() + _SESSION_CACHE_TTL)
    return session


# Convenience type aliases for route signatures.
DbDep = Annotated[aiosqlite.Connection, Depends(get_db)]
@@ -289,12 +289,19 @@ class SetupRedirectMiddleware(BaseHTTPMiddleware):
            return await call_next(request)

        # If setup is not complete, block all other API requests.
        if path.startswith("/api"):
        # Fast path: setup completion is a one-way transition. Once it is
        # True it is cached on app.state so all subsequent requests skip the
        # DB query entirely. The flag is reset only when the app restarts.
        if path.startswith("/api") and not getattr(
            request.app.state, "_setup_complete_cached", False
        ):
            db: aiosqlite.Connection | None = getattr(request.app.state, "db", None)
            if db is not None:
                from app.services import setup_service  # noqa: PLC0415

                if not await setup_service.is_setup_complete(db):
                if await setup_service.is_setup_complete(db):
                    request.app.state._setup_complete_cached = True
                else:
                    return RedirectResponse(
                        url="/api/setup",
                        status_code=status.HTTP_307_TEMPORARY_REDIRECT,
@@ -12,7 +12,7 @@ from __future__ import annotations
import structlog
from fastapi import APIRouter, HTTPException, Request, Response, status

from app.dependencies import DbDep, SettingsDep
from app.dependencies import DbDep, SettingsDep, invalidate_session_cache
from app.models.auth import LoginRequest, LoginResponse, LogoutResponse
from app.services import auth_service
@@ -101,6 +101,7 @@ async def logout(
    token = _extract_token(request)
    if token:
        await auth_service.logout(db, token)
        invalidate_session_cache(token)
    response.delete_cookie(key=_COOKIE_NAME)
    return LogoutResponse()
@@ -2,6 +2,9 @@
from __future__ import annotations

from unittest.mock import patch

import pytest
from httpx import AsyncClient

# ---------------------------------------------------------------------------
@@ -143,5 +146,107 @@ class TestRequireAuth:
        self, client: AsyncClient
    ) -> None:
        """Health endpoint is accessible without authentication."""


# ---------------------------------------------------------------------------
# Session-token cache (Task 4)
# ---------------------------------------------------------------------------


class TestRequireAuthSessionCache:
    """In-memory session token cache inside ``require_auth``."""

    @pytest.fixture(autouse=True)
    def reset_cache(self) -> None:  # type: ignore[misc]
        """Flush the session cache before and after every test in this class."""
        from app import dependencies

        dependencies.clear_session_cache()
        yield  # type: ignore[misc]
        dependencies.clear_session_cache()

    async def test_second_request_skips_db(self, client: AsyncClient) -> None:
        """Second authenticated request within TTL skips the session DB query.

        The first request populates the in-memory cache via ``require_auth``.
        The second request — using the same token before the TTL expires —
        must be served from the cache *without* calling
        ``session_repo.get_session``.
        """
        from app.repositories import session_repo

        await _do_setup(client)
        token = await _login(client)

        # Ensure cache is empty so the first request definitely hits the DB.
        from app import dependencies

        dependencies.clear_session_cache()

        call_count = 0
        original_get_session = session_repo.get_session

        async def _tracking(db, tok):  # type: ignore[no-untyped-def]
            nonlocal call_count
            call_count += 1
            return await original_get_session(db, tok)

        with patch.object(session_repo, "get_session", side_effect=_tracking):
            resp1 = await client.get(
                "/api/dashboard/status",
                headers={"Authorization": f"Bearer {token}"},
            )
            resp2 = await client.get(
                "/api/dashboard/status",
                headers={"Authorization": f"Bearer {token}"},
            )

        assert resp1.status_code == 200
        assert resp2.status_code == 200
        # DB queried exactly once: the first request populates the cache,
        # the second request is served entirely from memory.
        assert call_count == 1

    async def test_token_enters_cache_after_first_auth(
        self, client: AsyncClient
    ) -> None:
        """A successful auth request places the token in ``_session_cache``."""
        from app import dependencies

        await _do_setup(client)
        token = await _login(client)

        dependencies.clear_session_cache()
        assert token not in dependencies._session_cache

        await client.get(
            "/api/dashboard/status",
            headers={"Authorization": f"Bearer {token}"},
        )

        assert token in dependencies._session_cache

    async def test_logout_evicts_token_from_cache(
        self, client: AsyncClient
    ) -> None:
        """Logout removes the session token from the in-memory cache immediately."""
        from app import dependencies

        await _do_setup(client)
        token = await _login(client)

        # Warm the cache.
        await client.get(
            "/api/dashboard/status",
            headers={"Authorization": f"Bearer {token}"},
        )
        assert token in dependencies._session_cache

        # Logout must evict the entry.
        await client.post(
            "/api/auth/logout",
            headers={"Authorization": f"Bearer {token}"},
        )
        assert token not in dependencies._session_cache

        response = await client.get("/api/health")
        assert response.status_code == 200
@@ -2,7 +2,65 @@
from __future__ import annotations

from httpx import AsyncClient
from pathlib import Path
from unittest.mock import patch

import aiosqlite
import pytest
from httpx import ASGITransport, AsyncClient

from app.config import Settings
from app.db import init_db
from app.main import create_app

# ---------------------------------------------------------------------------
# Shared setup payload
# ---------------------------------------------------------------------------

_SETUP_PAYLOAD: dict[str, object] = {
    "master_password": "supersecret123",
    "database_path": "bangui.db",
    "fail2ban_socket": "/var/run/fail2ban/fail2ban.sock",
    "timezone": "UTC",
    "session_duration_minutes": 60,
}


# ---------------------------------------------------------------------------
# Fixture for tests that need direct access to app.state
# ---------------------------------------------------------------------------


@pytest.fixture
async def app_and_client(tmp_path: Path) -> tuple[object, AsyncClient]:  # type: ignore[misc]
    """Yield ``(app, client)`` for tests that inspect ``app.state`` directly.

    Args:
        tmp_path: Pytest-provided isolated temporary directory.

    Yields:
        A tuple of ``(FastAPI app instance, AsyncClient)``.
    """
    settings = Settings(
        database_path=str(tmp_path / "setup_cache_test.db"),
        fail2ban_socket="/tmp/fake_fail2ban.sock",
        session_secret="test-setup-cache-secret",
        session_duration_minutes=60,
        timezone="UTC",
        log_level="debug",
    )
    app = create_app(settings=settings)

    db: aiosqlite.Connection = await aiosqlite.connect(settings.database_path)
    db.row_factory = aiosqlite.Row
    await init_db(db)
    app.state.db = db

    transport: ASGITransport = ASGITransport(app=app)
    async with AsyncClient(transport=transport, base_url="http://test") as ac:
        yield app, ac

    await db.close()


class TestGetSetupStatus:
@@ -156,3 +214,75 @@ class TestGetTimezone:
        # Should return 200, not a 307 redirect, because /api/setup paths
        # are always allowed by the SetupRedirectMiddleware.
        assert response.status_code == 200


# ---------------------------------------------------------------------------
# Setup-complete flag caching in SetupRedirectMiddleware (Task 4)
# ---------------------------------------------------------------------------


class TestSetupCompleteCaching:
    """SetupRedirectMiddleware caches the setup_complete flag in ``app.state``."""

    async def test_cache_flag_set_after_first_post_setup_request(
        self,
        app_and_client: tuple[object, AsyncClient],
    ) -> None:
        """``_setup_complete_cached`` is set to True on the first request after setup.

        The ``/api/setup`` path is in ``_ALWAYS_ALLOWED`` so it bypasses the
        middleware check. The first request to a non-exempt endpoint triggers
        the DB query and, when setup is complete, populates the cache flag.
        """
        from fastapi import FastAPI

        app, client = app_and_client
        assert isinstance(app, FastAPI)

        # Complete setup (exempt from middleware, no flag set yet).
        resp = await client.post("/api/setup", json=_SETUP_PAYLOAD)
        assert resp.status_code == 201

        # Flag not yet cached — setup was via an exempt path.
        assert not getattr(app.state, "_setup_complete_cached", False)

        # First non-exempt request — middleware queries DB and sets the flag.
        await client.post("/api/auth/login", json={"password": _SETUP_PAYLOAD["master_password"]})  # type: ignore[call-overload]

        assert app.state._setup_complete_cached is True  # type: ignore[attr-defined]

    async def test_cached_path_skips_is_setup_complete(
        self,
        app_and_client: tuple[object, AsyncClient],
    ) -> None:
        """Subsequent requests do not call ``is_setup_complete`` once flag is cached.

        After the flag is set, the middleware must not touch the database for
        any further requests — even if ``is_setup_complete`` would raise.
        """
        from fastapi import FastAPI

        app, client = app_and_client
        assert isinstance(app, FastAPI)

        # Do setup and warm the cache.
        await client.post("/api/setup", json=_SETUP_PAYLOAD)
        await client.post("/api/auth/login", json={"password": _SETUP_PAYLOAD["master_password"]})  # type: ignore[call-overload]
        assert app.state._setup_complete_cached is True  # type: ignore[attr-defined]

        call_count = 0

        async def _counting(db):  # type: ignore[no-untyped-def]
            nonlocal call_count
            call_count += 1
            return True

        with patch("app.services.setup_service.is_setup_complete", side_effect=_counting):
            await client.post(
                "/api/auth/login",
                json={"password": _SETUP_PAYLOAD["master_password"]},
            )

        # Cache was warm — is_setup_complete must not have been called.
        assert call_count == 0