17 Commits

Author SHA1 Message Date
3af8f0571b feat: graceful shutdown and WAL cleanup
- Add stop_grace_period to backend container for graceful shutdown
- Document WAL mode rationale and orphaned file cleanup in db.py
- Handle database close errors gracefully in lifespan
- Clean up orphaned WAL files during startup before opening DB
- Reorder imports and fix formatting in startup.py
2026-05-24 22:05:34 +02:00
d5a78a251a Remove Tasks.md spec, add test for _cleanup_wal_files skipping recent files
Remove 335-line task specification from Docs/Tasks.md.
Add test confirming _cleanup_wal_files skips recently-modified WAL/SHM files.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
2026-05-24 22:05:34 +02:00
904db63fa2 Add tests for since timestamp accuracy in ban_service
- test_since_unix_returns_utc_epoch: validates since_unix('24h') returns UTC epoch
- test_ban_trend_since_is_within_expected_range: validates 23h-ago ban falls in 24h+slack window

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
2026-05-24 22:05:34 +02:00
d737a1c319 Add logging duplication tests
- test_logging_configuration_no_duplicate_handlers: verify create_app() twice leaves ≤1 StreamHandler
- test_uvicorn_access_logs_go_through_root_handler: verify uvicorn.access can emit JSON via JSONFormatter
- test_external_logging_processor_queues_record: verify _external_logging_processor queues to handler
- test_plain_text_logs_not_emitted_after_startup: verify app.db emits JSON not plain text

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
2026-05-24 22:05:34 +02:00
9e765c6cb7 Add granular DB error types with retry logic
New exceptions: DatabaseBusyError, DatabasePermissionDeniedError,
DatabasePathInvalidError, DatabaseCorruptedError, DatabaseUnavailableError.

open_db creates parent directory if missing. Catches all aiosqlite errors
and maps to specific exception types.

get_db retries up to 3x on locked database with backoff.
Propagates specific exceptions instead of generic HTTPException.

Tests for all new error types and retry behavior.
2026-05-24 22:05:34 +02:00
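
The retry behaviour described in this commit can be sketched roughly as follows. This is a minimal illustration, not the project's code: `LockedError`, `open_with_retry`, and `make_flaky_opener` are hypothetical names, and the linear backoff (`base_delay * attempt`) is an assumption.

```python
import asyncio


class DatabaseBusyError(Exception):
    """Stand-in for the exception of the same name added in this commit."""


class LockedError(Exception):
    """Hypothetical marker for a 'database is locked' failure."""


async def open_with_retry(open_fn, retries=3, base_delay=0.1):
    """Call open_fn up to `retries` times, sleeping base_delay * attempt between tries."""
    for attempt in range(1, retries + 1):
        try:
            return await open_fn()
        except LockedError as exc:
            if attempt == retries:
                # All attempts used up: surface a specific, typed error.
                raise DatabaseBusyError(f"busy after {retries} retries") from exc
            await asyncio.sleep(base_delay * attempt)


def make_flaky_opener(failures):
    """Return (opener, calls); opener raises LockedError `failures` times, then succeeds."""
    calls = []

    async def opener():
        calls.append(1)
        if len(calls) <= failures:
            raise LockedError("database is locked")
        return "connection"

    return opener, calls
```

Only the lock case is retried here; permission, path, and corruption errors would propagate on the first attempt, since waiting cannot fix them.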
ecb8542496 docs: add comprehensive task backlog and bump version to rc.5
- Document database error handling, logging duplication, ban service
timestamp, and orphaned SQLite file issues in Tasks.md
- Bump backend version from 0.9.19-rc.4 to 0.9.19-rc.5
2026-05-24 22:05:34 +02:00
97f4df4a61 chore: release v0.9.19-rc.5 2026-05-24 22:05:34 +02:00
44542b93c0 chore(release): bump version to 0.9.19-rc.4
- Add production Docker Compose configuration
- Add check_auth.py diagnostic script for session 401 debugging
2026-05-24 22:05:34 +02:00
01a4215f60 chore: release v0.9.19-rc.4 2026-05-24 22:05:34 +02:00
bc49b7cd5b fix(db): fix migration failures when upgrading from 0.8.0 schema
Migration 1: remove idx_sessions_token_hash from _SCHEMA_STATEMENTS.
The legacy schema has sessions.token (not token_hash). The IF NOT EXISTS
guard only prevents duplicate index names — it still requires the column
to exist. Migration 2 drops and rebuilds sessions with token_hash anyway,
so creating the index in migration 1 was redundant.

Migration 3: replace ALTER TABLE ADD COLUMN with a table rebuild.
SQLite rejects ALTER TABLE ADD COLUMN NOT NULL DEFAULT <expression> when
the table already contains rows. The old DB has ~181k geo_cache rows, so
the ALTER always failed. Rebuild copies existing rows with last_seen set
to cached_at as a reasonable approximation of last-seen time.
2026-05-24 22:05:34 +02:00
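
The Migration 3 failure and the rebuild workaround can be reproduced with the standard `sqlite3` module. The sketch below uses a trimmed-down `geo_cache` table (only `ip` and `cached_at`) and a made-up row; it is an illustration of the technique, not the migration itself.

```python
import sqlite3

NOW = "(strftime('%Y-%m-%dT%H:%M:%fZ', 'now'))"

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE geo_cache (ip TEXT PRIMARY KEY, cached_at TEXT NOT NULL)")
conn.execute("INSERT INTO geo_cache VALUES ('192.0.2.1', '2026-01-01T00:00:00.000Z')")
conn.commit()

# Adding a NOT NULL column with an expression default fails on a non-empty table.
try:
    conn.execute(f"ALTER TABLE geo_cache ADD COLUMN last_seen TEXT NOT NULL DEFAULT {NOW}")
    alter_failed = False
except sqlite3.Error:
    alter_failed = True
conn.rollback()

# Rebuild instead: new table, copy rows with last_seen = cached_at, swap names.
conn.executescript(f"""
    DROP TABLE IF EXISTS geo_cache_new;
    CREATE TABLE geo_cache_new (
        ip TEXT PRIMARY KEY,
        cached_at TEXT NOT NULL DEFAULT {NOW},
        last_seen TEXT NOT NULL DEFAULT {NOW}
    );
    INSERT INTO geo_cache_new (ip, cached_at, last_seen)
        SELECT ip, cached_at, cached_at FROM geo_cache;
    DROP TABLE geo_cache;
    ALTER TABLE geo_cache_new RENAME TO geo_cache;
""")

rows = conn.execute("SELECT ip, cached_at, last_seen FROM geo_cache").fetchall()
```

The exact exception subclass raised by the failing `ALTER` may differ between SQLite versions, which is why the sketch catches the `sqlite3.Error` base class.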
fa4fe4bbdf chore: release v0.9.19-rc.3 2026-05-24 22:05:34 +02:00
ee0fe9c695 fix(auth): suppress misleading 502 warning during session validation
A 502 Bad Gateway is a server/gateway error, not a network error.
Logging it as a 'Session validation network error' is noisy and
misleading during startup when nginx is temporarily unreachable.

Silently skip the console.warn for 5xx errors in handleValidationError
while keeping the warning for actual network errors.
2026-05-24 22:05:34 +02:00
551db0bb9c chore: release v0.9.19-rc.2 2026-05-24 22:05:34 +02:00
4a649e7347 chore: bump to v0.9.19-rc.1 and add local OpenAPI build support
- Add release candidate (rc) support to release.sh with latestRC tagging
- Bump VERSION, backend pyproject.toml, and frontend package.json to 0.9.19-rc.1
- Add local frontend/openapi.json so build no longer needs running backend
- Update generate:types and validate-types.sh to use local openapi.json
- Fix frontend tests: remove unused imports/variables and update mock data
2026-05-24 22:05:34 +02:00
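
For readers who prefer to reason about the rc bump rules outside of Bash, the following Python sketch mirrors the logic that release.sh applies. The function names `bump` and `image_tag` are hypothetical and exist only for this illustration.

```python
import re

VERSION_RE = re.compile(r"^(\d+)\.(\d+)\.(\d+)(?:-rc\.(\d+))?$")


def bump(current, choice):
    """Return the next tag for choice in {'patch', 'minor', 'major', 'rc'}."""
    match = VERSION_RE.match(current.removeprefix("v"))
    if match is None:
        raise ValueError(f"version {current!r} is not X.Y.Z or X.Y.Z-rc.N")
    major, minor, patch = (int(g) for g in match.group(1, 2, 3))
    has_rc = match.group(4) is not None
    rc_num = int(match.group(4)) if has_rc else 0
    if choice == "patch":
        # Releasing an rc strips the suffix instead of incrementing the patch.
        return f"v{major}.{minor}.{patch}" if has_rc else f"v{major}.{minor}.{patch + 1}"
    if choice == "minor":
        return f"v{major}.{minor + 1}.0"
    if choice == "major":
        return f"v{major + 1}.0.0"
    if choice == "rc":
        return f"v{major}.{minor}.{patch}-rc.{rc_num + 1}"
    raise ValueError(f"invalid choice {choice!r}")


def image_tag(new_tag):
    """Floating image tag that accompanies a release: latestRC for candidates."""
    return "latestRC" if "-rc" in new_tag else "latest"
```

For example, `bump("v0.9.19-rc.4", "rc")` yields `v0.9.19-rc.5`, and `bump("v0.9.19-rc.5", "patch")` yields `v0.9.19`.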
025c82a982 Merge pull request 'refactoring-backend' (#3) from refactoring-backend into main
Reviewed-on: #3
2026-05-20 20:23:46 +02:00
83b2cb67b1 backup
2026-05-20 20:18:58 +02:00
7308ff88d6 fix(rate-limit): stop double-counting requests in middleware
Multiple RateLimitMiddleware instances were each calling
check_allowed() on every request, halving the effective global
limit (200 req/min became ~100). Added path_prefixes and skip_paths
so each instance only checks the paths it owns.

- Auth middleware scoped to /api/v1/auth/login and /api/v1/setup
- History middleware scoped to /api/v1/history
- Global middleware skips auth and history paths
- Updated tests to match single-count behavior
2026-05-15 23:04:02 +02:00
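
The double-counting problem and the scoping fix can be shown with a toy model. `ScopedCheck` and `checks_per_request` are hypothetical stand-ins that model only the path matching, not the real middleware or limiter.

```python
from collections import Counter


class ScopedCheck:
    """Toy stand-in for one middleware instance; only path matching is modelled."""

    def __init__(self, bucket, path_prefixes=None, skip_paths=None):
        self.bucket = bucket
        self.path_prefixes = path_prefixes or []
        self.skip_paths = skip_paths or []

    def should_check(self, path):
        # Skips win over prefixes; no prefixes means "match everything".
        if any(path.startswith(p) for p in self.skip_paths):
            return False
        if self.path_prefixes:
            return any(path.startswith(p) for p in self.path_prefixes)
        return True


def checks_per_request(instances, path):
    """Count how many times each bucket would be charged for a single request."""
    return Counter(i.bucket for i in instances if i.should_check(path))


AUTH = ["/api/v1/auth/login", "/api/v1/setup"]
HISTORY = ["/api/v1/history"]

# Before the fix: every instance charges the global bucket on every request.
unscoped = [ScopedCheck("global"), ScopedCheck("global")]

# After the fix: each instance owns a disjoint set of paths.
scoped = [
    ScopedCheck("auth:login", path_prefixes=AUTH),
    ScopedCheck("history:list", path_prefixes=HISTORY),
    ScopedCheck("global", skip_paths=AUTH + HISTORY),
]
```

With the unscoped list, one request to `/api/v1/bans` charges the global bucket twice, which is how a 200 req/min limit behaves like roughly 100. With the scoped list, every path is charged exactly once.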
30 changed files with 11507 additions and 133 deletions

@@ -18,7 +18,7 @@ WORKDIR /build
COPY frontend/package.json frontend/package-lock.json* /build/
RUN npm ci --ignore-scripts
# Copy source and build
# Copy source + local OpenAPI spec (avoids needing a running backend during build)
COPY frontend/ /build/
RUN npm run build

@@ -1 +1 @@
v0.9.19
v0.9.19-rc.5

106
Docker/compose.prod.yml Normal file

@@ -0,0 +1,106 @@
# ──────────────────────────────────────────────────────────────
# BanGUI — Production Compose
#
# Usage:
# docker compose -f Docker/compose.prod.yml up -d
# podman compose -f Docker/compose.prod.yml up -d
#
# Features:
# - Multi-stage built images (no volume-mounted source code)
# - Frontend served by nginx with API reverse proxy
# - Backend running uvicorn without --reload
# - Only port 8080 exposed to host
# ──────────────────────────────────────────────────────────────
name: bangui
services:
# ── fail2ban ─────────────────────────────────────────────────
fail2ban:
image: lscr.io/linuxserver/fail2ban:latest
container_name: bangui-fail2ban
restart: unless-stopped
cap_add:
- NET_ADMIN
- NET_RAW
network_mode: host
environment:
TZ: "${BANGUI_TIMEZONE:-UTC}"
PUID: 0
PGID: 0
volumes:
- ../data/fail2ban-dev-config:/config
- fail2ban-run:/var/run/fail2ban
- /var/log:/var/log:ro
- ../data/log:/remotelogs/bangui
healthcheck:
test: ["CMD", "fail2ban-client", "ping"]
interval: 30s
timeout: 5s
start_period: 15s
retries: 3
# ── Backend (FastAPI + uvicorn) ─────────────────────────────
backend:
build:
context: ..
dockerfile: Docker/Dockerfile.backend
target: runtime
container_name: bangui-backend
restart: unless-stopped
stop_grace_period: 30s # Give lifespan 30s to complete before SIGKILL
depends_on:
fail2ban:
condition: service_healthy
environment:
BANGUI_DATABASE_PATH: "/data/bangui.db"
BANGUI_FAIL2BAN_SOCKET: "/var/run/fail2ban/fail2ban.sock"
BANGUI_FAIL2BAN_CONFIG_DIR: "/config/fail2ban"
BANGUI_LOG_FILE: "/data/log/bangui.log"
BANGUI_LOG_LEVEL: "${BANGUI_LOG_LEVEL:-info}"
BANGUI_SESSION_SECRET: "${BANGUI_SESSION_SECRET:?BANGUI_SESSION_SECRET must be set — generate with: python -c 'import secrets; print(secrets.token_hex(32))'}"
BANGUI_TIMEZONE: "${BANGUI_TIMEZONE:-UTC}"
BANGUI_SESSION_COOKIE_SECURE: "${BANGUI_SESSION_COOKIE_SECURE:-true}"
BANGUI_CORS_ALLOWED_ORIGINS: "${BANGUI_CORS_ALLOWED_ORIGINS:-}"
volumes:
- ../data:/data
- ../fail2ban-master:/app/fail2ban-master:ro
- fail2ban-run:/var/run/fail2ban:ro
- ../data/fail2ban-dev-config:/config:rw
networks:
- bangui-net
healthcheck:
test: ["CMD-SHELL", "curl -f http://localhost:8000/api/v1/health/live || exit 1"]
interval: 30s
timeout: 10s
start_period: 40s
retries: 3
# ── Frontend (nginx serving built SPA) ──────────────────────
frontend:
build:
context: ..
dockerfile: Docker/Dockerfile.frontend
container_name: bangui-frontend
restart: unless-stopped
depends_on:
backend:
condition: service_healthy
ports:
- "${BANGUI_PORT:-8080}:80"
networks:
- bangui-net
healthcheck:
test: ["CMD-SHELL", "wget -qO /dev/null http://localhost:80/ || exit 1"]
interval: 30s
timeout: 5s
start_period: 5s
retries: 3
volumes:
fail2ban-run:
driver: local
networks:
bangui-net:
driver: bridge

@@ -6,7 +6,7 @@
# ./release.sh
#
# The current version is stored in VERSION (next to this script).
# You will be asked whether to bump major, minor, or patch.
# You will be asked whether to bump major, minor, patch, or release candidate (rc).
set -euo pipefail
@@ -24,24 +24,60 @@ CURRENT="$(cat "${VERSION_FILE}")"
# Strip leading 'v' for arithmetic
VERSION="${CURRENT#v}"
IFS='.' read -r MAJOR MINOR PATCH <<< "${VERSION}"
# Parse version: X.Y.Z or X.Y.Z-rc.N
if [[ "${VERSION}" =~ ^([0-9]+)\.([0-9]+)\.([0-9]+)(-rc\.([0-9]+))?$ ]]; then
MAJOR="${BASH_REMATCH[1]}"
MINOR="${BASH_REMATCH[2]}"
PATCH="${BASH_REMATCH[3]}"
RC_SUFFIX="${BASH_REMATCH[4]:-}"
RC_NUM="${BASH_REMATCH[5]:-0}"
else
echo "Error: version '${VERSION}' does not match expected format X.Y.Z or X.Y.Z-rc.N" >&2
exit 1
fi
echo "============================================"
echo " BanGUI — Release"
echo " Current version: v${MAJOR}.${MINOR}.${PATCH}"
if [[ -n "${RC_SUFFIX}" ]]; then
echo " Current version: v${MAJOR}.${MINOR}.${PATCH}-rc.${RC_NUM}"
else
echo " Current version: v${MAJOR}.${MINOR}.${PATCH}"
fi
echo "============================================"
echo ""
echo "How would you like to bump the version?"
echo " 1) patch (v${MAJOR}.${MINOR}.${PATCH} → v${MAJOR}.${MINOR}.$((PATCH + 1)))"
echo " 2) minor (v${MAJOR}.${MINOR}.${PATCH} → v${MAJOR}.$((MINOR + 1)).0)"
echo " 3) major (v${MAJOR}.${MINOR}.${PATCH} → v$((MAJOR + 1)).0.0)"
if [[ -n "${RC_SUFFIX}" ]]; then
echo " 1) patch (v${MAJOR}.${MINOR}.${PATCH}-rc.${RC_NUM} → v${MAJOR}.${MINOR}.${PATCH})"
echo " 2) minor (v${MAJOR}.${MINOR}.${PATCH}-rc.${RC_NUM} → v${MAJOR}.$((MINOR + 1)).0)"
echo " 3) major (v${MAJOR}.${MINOR}.${PATCH}-rc.${RC_NUM} → v$((MAJOR + 1)).0.0)"
echo " 4) rc (v${MAJOR}.${MINOR}.${PATCH}-rc.${RC_NUM} → v${MAJOR}.${MINOR}.${PATCH}-rc.$((RC_NUM + 1)))"
else
echo " 1) patch (v${MAJOR}.${MINOR}.${PATCH} → v${MAJOR}.${MINOR}.$((PATCH + 1)))"
echo " 2) minor (v${MAJOR}.${MINOR}.${PATCH} → v${MAJOR}.$((MINOR + 1)).0)"
echo " 3) major (v${MAJOR}.${MINOR}.${PATCH} → v$((MAJOR + 1)).0.0)"
echo " 4) rc (v${MAJOR}.${MINOR}.${PATCH} → v${MAJOR}.${MINOR}.${PATCH}-rc.1)"
fi
echo ""
read -rp "Enter choice [1/2/3]: " CHOICE
read -rp "Enter choice [1/2/3/4]: " CHOICE
case "${CHOICE}" in
1) NEW_TAG="v${MAJOR}.${MINOR}.$((PATCH + 1))" ;;
1)
if [[ -n "${RC_SUFFIX}" ]]; then
# Release the RC: strip RC suffix
NEW_TAG="v${MAJOR}.${MINOR}.${PATCH}"
else
NEW_TAG="v${MAJOR}.${MINOR}.$((PATCH + 1))"
fi
;;
2) NEW_TAG="v${MAJOR}.$((MINOR + 1)).0" ;;
3) NEW_TAG="v$((MAJOR + 1)).0.0" ;;
4)
if [[ "${RC_NUM}" -gt 0 ]]; then
NEW_TAG="v${MAJOR}.${MINOR}.${PATCH}-rc.$((RC_NUM + 1))"
else
NEW_TAG="v${MAJOR}.${MINOR}.${PATCH}-rc.1"
fi
;;
*)
echo "Invalid choice. Aborting." >&2
exit 1
@@ -81,7 +117,13 @@ fi
# Push containers
# ---------------------------------------------------------------------------
bash "${SCRIPT_DIR}/push.sh" "${NEW_TAG}"
bash "${SCRIPT_DIR}/push.sh"
# Push to "latest" or "latestRC" depending on whether this is a release candidate
if [[ "${NEW_TAG}" == *-rc* ]]; then
bash "${SCRIPT_DIR}/push.sh" "latestRC"
else
bash "${SCRIPT_DIR}/push.sh" "latest"
fi
# ---------------------------------------------------------------------------

@@ -0,0 +1,44 @@
## Task: Investigate Orphaned SQLite Shared Memory Files on Startup
### Issue in Detail
The log shows repeated warnings:
```
event=orphaned_sqlite_file_removed path=/data/bangui.db-shm
```
This occurs at `19:39:48` and again at `19:49:39` (after restart). The `-shm` file is SQLite's shared memory file for WAL mode. Its presence indicates **unclean shutdowns** (crashes or SIGKILL instead of graceful SIGTERM).
### Why This Happens
1. **Docker stop timeout:** Docker sends SIGTERM, waits `stop_grace_period` (default 10s), then sends SIGKILL. The backend allows 25s for graceful shutdown, but if the container's `stop_grace_period` is shorter, the process is killed before cleanup completes.
2. **Missing connection close:** If the application crashes or is killed, SQLite connections are not closed, leaving `.wal` and `.shm` files behind.
3. **`_cleanup_wal_files()` is a workaround, not a fix:** It removes stale files on the *next* startup, but the underlying cause (unclean shutdown) remains.
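
The link between clean shutdown and leftover files can be observed with the standard `sqlite3` module. This is a minimal sketch assuming a default SQLite build; `wal_sidecars_present` and `demo` are hypothetical helper names.

```python
import sqlite3
import tempfile
from pathlib import Path


def wal_sidecars_present(db_path):
    """Report whether the -wal and -shm files exist next to db_path."""
    return {suffix: Path(f"{db_path}{suffix}").exists() for suffix in ("-wal", "-shm")}


def demo():
    with tempfile.TemporaryDirectory() as tmp:
        db_path = Path(tmp) / "demo.db"
        conn = sqlite3.connect(db_path)
        conn.execute("PRAGMA journal_mode=WAL")
        conn.execute("CREATE TABLE t (x INTEGER)")
        conn.execute("INSERT INTO t VALUES (1)")
        conn.commit()
        while_open = wal_sidecars_present(db_path)
        # A clean close of the last connection checkpoints and removes both files.
        conn.close()
        after_close = wal_sidecars_present(db_path)
    return while_open, after_close


while_open, after_close = demo()
```

If the process is killed before `close()` runs, the state captured in `while_open` is what remains on disk, which matches the orphaned files seen in the log.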
### How to Fix It
1. **Verify Docker Compose `stop_grace_period`:** In `Docker/compose.prod.yml`, ensure the backend service has `stop_grace_period: 30s` (matching the 25s internal timeout + margin).
2. **Improve shutdown logging:** Add explicit logs when the database connection is closed during lifespan shutdown.
3. **Consider `PRAGMA journal_mode = DELETE` for single-process setups:** WAL mode is beneficial for concurrent readers, but if BanGUI runs with a single worker and single process, DELETE mode eliminates `.wal`/`.shm` files entirely. Evaluate the tradeoff.
### Issues and Pitfalls
1. **WAL mode is required for concurrent reads:** In DELETE (rollback-journal) mode, a writer blocks readers and active readers block the writer. This may degrade API performance under load.
2. **The `_cleanup_wal_files()` 10-second threshold:** Files modified within 10 seconds are skipped. If the container restarts rapidly (e.g., health check failure → restart), the files may not be cleaned up.
### Documentation References
- **`Docs/Deployment.md`:** Docker deployment configuration and graceful shutdown behavior.
- **`Docs/Architekture.md`:** Deployment constraints and process-local state.
### Tests to Write
#### 1. `test_cleanup_wal_files_removes_stale_files`
- **Setup:** Create fake `.wal` and `.shm` files with mtime > 10s ago.
- **Action:** Call `_cleanup_wal_files()`.
- **Assert:** Files are removed.
#### 2. `test_cleanup_wal_files_skips_recent_files`
- **Setup:** Create fake `.wal` and `.shm` files with mtime < 10s ago.
- **Action:** Call `_cleanup_wal_files()`.
- **Assert:** Files are NOT removed.
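
A rough sketch of the two tests follows. Because the body of `_cleanup_wal_files()` is not shown here, the sketch tests a hypothetical synchronous stand-in, `cleanup_stale_sidecars`, that assumes the `-wal`/`-shm` suffixes from the log line and the 10-second threshold described above.

```python
import os
import tempfile
import time
from pathlib import Path

STALE_AFTER_SECONDS = 10.0
SUFFIXES = ("-wal", "-shm")


def cleanup_stale_sidecars(database_path, now=None):
    """Hypothetical stand-in: remove sidecar files older than the threshold."""
    now = time.time() if now is None else now
    removed = []
    for suffix in SUFFIXES:
        sidecar = Path(f"{database_path}{suffix}")
        if sidecar.exists() and now - sidecar.stat().st_mtime > STALE_AFTER_SECONDS:
            sidecar.unlink()
            removed.append(sidecar.name)
    return removed


def make_sidecars(database_path, age_seconds):
    """Create empty sidecar files whose mtime lies age_seconds in the past."""
    stamp = time.time() - age_seconds
    for suffix in SUFFIXES:
        sidecar = Path(f"{database_path}{suffix}")
        sidecar.write_bytes(b"")
        os.utime(sidecar, (stamp, stamp))


def run_case(age_seconds):
    """Set up sidecars of the given age, run the cleanup, return removed names."""
    with tempfile.TemporaryDirectory() as tmp:
        db_path = Path(tmp) / "bangui.db"
        make_sidecars(db_path, age_seconds)
        return cleanup_stale_sidecars(db_path)
```

Setting the mtime with `os.utime` avoids sleeping in the test; `run_case(60)` exercises the stale path and `run_case(1)` the recent path.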

@@ -102,10 +102,15 @@ CREATE TABLE IF NOT EXISTS schema_migrations (
"""
# Ordered list of DDL statements to execute on initialisation.
# NOTE: _CREATE_SESSIONS_TOKEN_INDEX is intentionally omitted here.
# The old 0.8.0 schema has a `sessions.token` column (not `token_hash`), so
# running CREATE INDEX … ON sessions (token_hash) in migration 1 would fail
# with "no such column: token_hash" on legacy databases. Migration 2 drops
# and recreates the sessions table with token_hash and also creates the index,
# so there is no need to create it in migration 1.
_SCHEMA_STATEMENTS: list[str] = [
_CREATE_SETTINGS,
_CREATE_SESSIONS,
_CREATE_SESSIONS_TOKEN_INDEX,
_CREATE_BLOCKLIST_SOURCES,
_CREATE_IMPORT_LOG,
_CREATE_GEO_CACHE,
@@ -133,8 +138,24 @@ CREATE UNIQUE INDEX idx_sessions_token_hash ON sessions (token_hash);
3: """
-- Migration 3: Add last_seen timestamp to geo_cache for retention policy.
-- Tracks when each IP was last referenced to enable purging of stale entries.
-- Default to current timestamp for existing rows.
ALTER TABLE geo_cache ADD COLUMN last_seen TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%fZ', 'now'));
-- SQLite rejects ALTER TABLE ADD COLUMN with a non-constant NOT NULL default
-- when the table already contains rows, so we rebuild the table instead.
-- Existing rows receive last_seen = cached_at as a reasonable approximation
-- (the IP was at least seen when it was first cached).
DROP TABLE IF EXISTS geo_cache_new;
CREATE TABLE geo_cache_new (
ip TEXT PRIMARY KEY,
country_code TEXT,
country_name TEXT,
asn TEXT,
org TEXT,
cached_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%fZ', 'now')),
last_seen TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%fZ', 'now'))
);
INSERT INTO geo_cache_new (ip, country_code, country_name, asn, org, cached_at, last_seen)
SELECT ip, country_code, country_name, asn, org, cached_at, cached_at FROM geo_cache;
DROP TABLE geo_cache;
ALTER TABLE geo_cache_new RENAME TO geo_cache;
""",
4: """
-- Migration 4: Add scheduler_lock table for multi-worker safety.
@@ -253,7 +274,18 @@ CREATE INDEX IF NOT EXISTS idx_import_log_source_id_desc
async def _configure_connection(db: aiosqlite.Connection) -> None:
"""Apply hardening pragmas to a newly-opened SQLite connection."""
"""Apply hardening pragmas to a newly-opened SQLite connection.
WAL mode is intentionally kept despite the risk of orphaned ``.wal``/``.shm``
files after unclean shutdowns. The benefits for concurrent readers
(readers do not block writers) outweigh the cleanup overhead, especially
under load. BanGUI runs as a single worker, but multiple concurrent HTTP
requests can still issue overlapping reads; DELETE mode would serialize
those reads behind any write, degrading API performance.
Orphaned files are handled by :func:`_cleanup_wal_files`, which is called
during startup before the database is opened.
"""
await db.execute("PRAGMA journal_mode=WAL;")
await db.execute("PRAGMA foreign_keys=ON;")
await db.execute("PRAGMA busy_timeout=5000;")
@@ -454,14 +486,75 @@ async def init_db(db: aiosqlite.Connection) -> None:
async def open_db(database_path: str) -> aiosqlite.Connection:
"""Open a new application SQLite connection with the standard settings.
Creates the parent directory if it does not exist.
Args:
database_path: Path to the BanGUI SQLite database.
Returns:
A configured :class:`aiosqlite.Connection` instance.
Raises:
DatabasePathInvalidError: If the directory cannot be created or is inaccessible.
DatabasePermissionDeniedError: If aiosqlite.connect raises PermissionError.
DatabaseCorruptedError: If the database file is corrupted.
DatabaseUnavailableError: For any other unexpected error.
"""
await _cleanup_wal_files(database_path)
db = await aiosqlite.connect(database_path)
from app.exceptions import (
DatabaseCorruptedError,
DatabasePathInvalidError,
DatabasePermissionDeniedError,
DatabaseUnavailableError,
)
db_dir = Path(database_path).parent
if not db_dir.exists():
try:
db_dir.mkdir(parents=True, exist_ok=True)
except PermissionError as exc:
log.error("database_open_failed", error=str(exc), database_path=database_path)
raise DatabasePathInvalidError(database_path) from exc
except OSError as exc:
log.error("database_open_failed", error=str(exc), database_path=database_path)
raise DatabaseUnavailableError(database_path, str(exc)) from exc
try:
db = await aiosqlite.connect(database_path)
except PermissionError as exc:
log.error("database_open_failed", error=str(exc), database_path=database_path)
raise DatabasePermissionDeniedError(database_path) from exc
except aiosqlite.OperationalError as exc:
error_msg = str(exc).lower()
sqlite_code = getattr(exc, "sqlite_errorcode", None)
log.error(
"database_open_failed",
error=str(exc),
sqlite_errorcode=sqlite_code,
database_path=database_path,
)
if "database is locked" in error_msg or "busy" in error_msg:
raise DatabaseUnavailableError(database_path, str(exc)) from exc
if "unable to open database file" in error_msg:
raise DatabasePathInvalidError(database_path) from exc
raise DatabaseUnavailableError(database_path, str(exc)) from exc
except aiosqlite.DatabaseError as exc:
log.error(
"database_open_failed",
error=str(exc),
database_path=database_path,
)
raise DatabaseCorruptedError(database_path) from exc
except OSError as exc:
log.error("database_open_failed", error=str(exc), database_path=database_path)
raise DatabaseUnavailableError(database_path, str(exc)) from exc
except Exception as exc:
log.error("database_open_failed", error=str(exc), database_path=database_path)
raise DatabaseUnavailableError(database_path, str(exc)) from exc
db.row_factory = aiosqlite.Row
await _configure_connection(db)
try:
await _configure_connection(db)
except Exception:
await db.close()
raise
return db

@@ -165,22 +165,61 @@ async def get_db(
Yields:
An open :class:`aiosqlite.Connection` for the request.
Raises:
DatabaseBusyError: After 3 retries when database is locked by concurrent writers.
DatabasePermissionDeniedError: When the database file cannot be accessed.
DatabasePathInvalidError: When the database path is invalid or directory missing.
DatabaseCorruptedError: When the database file is corrupted.
DatabaseUnavailableError: For any other unexpected database error.
"""
from app.db import open_db # noqa: PLC0415
from app.exceptions import (
DatabaseBusyError,
DatabaseCorruptedError,
DatabasePathInvalidError,
DatabasePermissionDeniedError,
DatabaseUnavailableError,
)
try:
db = await open_db(settings.database_path)
except Exception as exc:
log.error("database_open_failed", error=str(exc))
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="Database is not available.",
) from exc
db = None
retries = 3
retry_delay = 0.1
last_exc = None
for attempt in range(1, retries + 1):
try:
db = await open_db(settings.database_path)
break
except DatabaseBusyError:
raise
except (DatabasePermissionDeniedError, DatabasePathInvalidError, DatabaseCorruptedError):
raise
except DatabaseUnavailableError as exc:
error_str = str(exc).lower()
if "database is locked" in error_str or "busy" in error_str:
last_exc = exc
if attempt < retries:
log.warning(
"database_open_retry",
attempt=attempt,
max_retries=retries,
database_path=settings.database_path,
)
import asyncio
await asyncio.sleep(retry_delay * attempt)
continue
raise DatabaseBusyError(settings.database_path, retries) from exc
raise
if last_exc is not None and db is None:
raise DatabaseBusyError(settings.database_path, retries)
try:
yield db
finally:
await db.close()
if db is not None:
await db.close()
async def get_http_session(

@@ -473,6 +473,75 @@ class SetupAlreadyCompleteError(ConflictError):
super().__init__("Setup has already been completed.")
class DatabaseBusyError(ServiceUnavailableError):
"""Raised when the SQLite database is locked or busy after all retries."""
error_code: str = "database_busy"
def __init__(self, database_path: str, retries: int) -> None:
self.database_path = database_path
self.retries = retries
super().__init__(
f"Database is temporarily busy after {retries} retries."
)
def get_error_metadata(self) -> ErrorMetadata:
return {"database_path": self.database_path, "retries": self.retries}
class DatabasePermissionDeniedError(ServiceUnavailableError):
"""Raised when the database file cannot be accessed due to insufficient permissions."""
error_code: str = "database_permission_denied"
def __init__(self, database_path: str) -> None:
self.database_path = database_path
super().__init__("Insufficient permissions to access the database file.")
def get_error_metadata(self) -> ErrorMetadata:
return {"database_path": self.database_path}
class DatabasePathInvalidError(ServiceUnavailableError):
"""Raised when the database directory does not exist or the path is invalid."""
error_code: str = "database_path_invalid"
def __init__(self, database_path: str) -> None:
self.database_path = database_path
super().__init__("Database directory does not exist or path is invalid.")
def get_error_metadata(self) -> ErrorMetadata:
return {"database_path": self.database_path}
class DatabaseCorruptedError(ServiceUnavailableError):
"""Raised when the database file is corrupted."""
error_code: str = "database_corrupted"
def __init__(self, database_path: str) -> None:
self.database_path = database_path
super().__init__("Database file is corrupted.")
def get_error_metadata(self) -> ErrorMetadata:
return {"database_path": self.database_path}
class DatabaseUnavailableError(ServiceUnavailableError):
"""Raised for any other unexpected database error."""
error_code: str = "database_unavailable"
def __init__(self, database_path: str, error: str) -> None:
self.database_path = database_path
self.error = error
super().__init__(f"Database is not available: {error}")
def get_error_metadata(self) -> ErrorMetadata:
return {"database_path": self.database_path, "error": self.error}
class BlocklistSourceNotFoundError(NotFoundError):
"""Raised when a blocklist source is not found."""

@@ -242,9 +242,9 @@ async def _lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
# deployments, it should be replaced with a shared backend.
_update_session_cache(app, settings)
# Initialize the global rate limiter (200 requests per 60 seconds per IP).
# Initialize the global rate limiter (600 requests per 60 seconds per IP).
# Applied to all endpoints via middleware. Process-local implementation.
app.state.global_rate_limiter = GlobalRateLimiter(max_requests=200, window_seconds=60)
app.state.global_rate_limiter = GlobalRateLimiter(max_requests=600, window_seconds=60)
log.info("bangui_started")
@@ -318,7 +318,12 @@ async def _lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
log.error("scheduler_lock_release_failed", error=str(e))
# 6. Close the database connection.
await startup_db.close()
try:
await startup_db.close()
log.debug("database_connection_closed")
except Exception as exc:
log.error("database_connection_close_failed", error=str(exc))
log.info("bangui_shut_down")
@@ -1095,10 +1100,10 @@ def create_app(settings: Settings | None = None) -> FastAPI:
if resolved_settings.session_cache_enabled and resolved_settings.session_cache_ttl_seconds > 0.0
else NoOpSessionCache()
)
# Initialize the global rate limiter (200 requests per 60 seconds per IP).
# Initialize the global rate limiter (600 requests per 60 seconds per IP).
# This is also re-initialized in the lifespan, but must be present here
# for tests that bypass the lifespan via ASGITransport.
app.state.global_rate_limiter = GlobalRateLimiter(max_requests=200, window_seconds=60)
app.state.global_rate_limiter = GlobalRateLimiter(max_requests=600, window_seconds=60)
set_setup_complete_cache(app, False)
@@ -1135,9 +1140,7 @@ def create_app(settings: Settings | None = None) -> FastAPI:
app.add_middleware(CsrfMiddleware)
app.add_middleware(DeprecationHeaderMiddleware)
# Auth endpoints (login, setup) need a dedicated higher-rate bucket to avoid
# rate limiting when running e2e tests sequentially. Auth uses the default
# global rate limiter at 200 req/min per IP.
# Auth endpoints: /api/v1/login, /api/v1/setup
# rate limiting when running e2e tests sequentially.
# 1000 req/min per IP — generous for e2e testing.
app.add_middleware(
RateLimitMiddleware,
@@ -1146,6 +1149,7 @@ def create_app(settings: Settings | None = None) -> FastAPI:
bucket_override="auth:login",
bucket_max_requests=1000,
bucket_window_seconds=60,
path_prefixes=["/api/v1/auth/login", "/api/v1/setup"],
)
# History endpoints get a dedicated higher-rate bucket to avoid
@@ -1159,6 +1163,28 @@ def create_app(settings: Settings | None = None) -> FastAPI:
bucket_override="history:list",
bucket_max_requests=10000,
bucket_window_seconds=60,
path_prefixes=["/api/v1/history"],
)
# Polling endpoints (blocklist schedule) get a dedicated bucket
# to avoid exhausting the global limit during normal frontend operation.
app.add_middleware(
RateLimitMiddleware,
rate_limiter=app.state.global_rate_limiter,
settings=resolved_settings,
bucket_override="polling:read",
bucket_max_requests=10000,
bucket_window_seconds=60,
path_prefixes=["/api/v1/blocklists/schedule"],
)
# Global rate limiter for all other endpoints.
# 600 req/min per IP — default protection.
app.add_middleware(
RateLimitMiddleware,
rate_limiter=app.state.global_rate_limiter,
settings=resolved_settings,
skip_paths=["/api/v1/auth/login", "/api/v1/setup", "/api/v1/history", "/api/v1/blocklists/schedule"],
)
# Validate middleware order before returning the app.

@@ -34,18 +34,20 @@ unusual and potentially suspicious) always carry a correlation ID for tracing.
from __future__ import annotations
from collections.abc import Awaitable, Callable
from typing import TYPE_CHECKING
from app.utils.logging_compat import get_logger
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import JSONResponse, Response
from app.exceptions import RateLimitError
from app.utils.client_ip import get_client_ip
from app.utils.logging_compat import get_logger
if TYPE_CHECKING:
from collections.abc import Awaitable, Callable
from starlette.requests import Request
from app.config import Settings
from app.utils.rate_limiter import GlobalRateLimiter
@@ -53,11 +55,15 @@ log = get_logger(__name__)
class RateLimitMiddleware(BaseHTTPMiddleware):
"""Enforce global per-IP request rate limiting on all endpoints.
"""Enforce per-IP request rate limiting on matching endpoints.
Tracks requests per IP and blocks further requests if the limit is exceeded.
Uses the application's GlobalRateLimiter instance and trusted-proxy settings
for consistent IP extraction.
Each middleware instance is scoped to a set of path prefixes (or all paths
if no prefixes are given). This allows multiple instances to coexist
without double-counting requests.
"""
def __init__(
@@ -68,6 +74,8 @@ class RateLimitMiddleware(BaseHTTPMiddleware):
bucket_override: str | None = None,
bucket_max_requests: int | None = None,
bucket_window_seconds: int | None = None,
path_prefixes: list[str] | None = None,
skip_paths: list[str] | None = None,
) -> None:
"""Initialize the rate limit middleware.
@@ -78,6 +86,12 @@ class RateLimitMiddleware(BaseHTTPMiddleware):
bucket_override: Optional named bucket to use instead of the default limiter.
bucket_max_requests: Max requests for the bucket override.
bucket_window_seconds: Window for the bucket override.
path_prefixes: If provided, only apply rate limiting to paths that
start with one of these prefixes. If ``None``, all paths are
matched.
skip_paths: If provided, do not apply rate limiting to paths that
start with one of these prefixes. Evaluated after
``path_prefixes``.
"""
super().__init__(app) # type: ignore[arg-type]
self.rate_limiter: GlobalRateLimiter = rate_limiter
@@ -85,6 +99,23 @@ class RateLimitMiddleware(BaseHTTPMiddleware):
self.bucket_override = bucket_override
self.bucket_max_requests = bucket_max_requests
self.bucket_window_seconds = bucket_window_seconds
self.path_prefixes = path_prefixes or []
self.skip_paths = skip_paths or []
def _should_check(self, path: str) -> bool:
"""Return whether the given path should be rate-limited by this instance.
Args:
path: The request URL path.
Returns:
``True`` if this instance should enforce its limit on the path.
"""
if self.skip_paths and any(path.startswith(p) for p in self.skip_paths):
return False
if self.path_prefixes:
return any(path.startswith(p) for p in self.path_prefixes)
return True
async def dispatch(
self,
@@ -103,37 +134,28 @@ class RateLimitMiddleware(BaseHTTPMiddleware):
Returns:
A response object (either rate limit response or from handler).
"""
client_ip = get_client_ip(request, trusted_proxies=self.settings.trusted_proxies)
# Use higher-rate bucket for specific endpoints.
# Check path to apply the appropriate bucket.
path = request.url.path
if not self._should_check(path):
return await call_next(request)
client_ip = get_client_ip(request, trusted_proxies=self.settings.trusted_proxies)
if self.bucket_override and self.bucket_max_requests and self.bucket_window_seconds:
if path.startswith("/api/v1/history"):
is_allowed, retry_after = self.rate_limiter.check_allowed_for_bucket(
self.bucket_override,
client_ip,
self.bucket_max_requests,
self.bucket_window_seconds,
)
elif path.startswith("/api/v1/login") or path.startswith("/api/v1/setup"):
# Auth endpoints use their own bucket
is_allowed, retry_after = self.rate_limiter.check_allowed_for_bucket(
self.bucket_override,
client_ip,
self.bucket_max_requests,
self.bucket_window_seconds,
)
else:
is_allowed, retry_after = self.rate_limiter.check_allowed(client_ip)
is_allowed, retry_after = self.rate_limiter.check_allowed_for_bucket(
self.bucket_override,
client_ip,
self.bucket_max_requests,
self.bucket_window_seconds,
)
else:
is_allowed, retry_after = self.rate_limiter.check_allowed(client_ip)
if not is_allowed:
log.warning(
"global_rate_limit_exceeded",
client_ip=client_ip,
path=request.url.path,
path=path,
method=request.method,
retry_after=retry_after,
)
@@ -141,7 +163,6 @@ class RateLimitMiddleware(BaseHTTPMiddleware):
"Too many requests. Please try again later.",
retry_after_seconds=retry_after,
)
# Return the error response directly
return JSONResponse(
status_code=429,
content={
@@ -153,6 +174,5 @@ class RateLimitMiddleware(BaseHTTPMiddleware):
headers={"Retry-After": str(int(retry_after))},
)
# Request is allowed, continue to next handler
response: Response = await call_next(request)
return response
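
The path filter added in this hunk can be exercised on its own. The sketch below is a standalone mirror of the `_should_check` logic shown above, written as a free function; the name `should_check` and the example prefix lists in the usage note are assumptions for illustration, not the project's actual middleware configuration.

```python
from __future__ import annotations


def should_check(
    path: str,
    path_prefixes: list[str] | None = None,
    skip_paths: list[str] | None = None,
) -> bool:
    """Standalone mirror of the ``_should_check`` decision in the diff."""
    prefixes = path_prefixes or []
    skips = skip_paths or []
    # A matching skip prefix always wins, whether or not an allow-list is set.
    if any(path.startswith(p) for p in skips):
        return False
    # With an allow-list, only paths under one of its prefixes are limited.
    if prefixes:
        return any(path.startswith(p) for p in prefixes)
    # No allow-list: every remaining path is limited.
    return True
```

One plausible arrangement, assuming two middleware instances share a limiter, is a global instance whose `skip_paths` lists the bucketed routes and a bucket instance whose `path_prefixes` lists the same routes, so each request is counted by exactly one instance.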


@@ -26,10 +26,9 @@ from pathlib import Path
from typing import TYPE_CHECKING, Any
import aiohttp
from app.utils.logging_compat import get_logger
from apscheduler.schedulers.asyncio import AsyncIOScheduler # type: ignore[import-untyped]
from app.db import init_db, open_db
from app.db import _cleanup_wal_files, init_db, open_db
from app.services import setup_service
from app.services.dns_validated_connector import create_dns_validated_socket_factory
from app.services.geo_cache import GeoCache
@@ -48,6 +47,7 @@ from app.tasks import (
from app.utils.async_utils import run_blocking
from app.utils.fail2ban_db_utils import ensure_fail2ban_indexes
from app.utils.jail_config import ensure_jail_configs
from app.utils.logging_compat import get_logger
from app.utils.runtime_state import set_runtime_settings
from app.utils.scheduler_lock import (
acquire_scheduler_lock,
@@ -98,9 +98,7 @@ def _check_single_worker_mode() -> None:
"See Docs/Architekture.md § Deployment Constraints for details."
)
except ValueError as e:
raise RuntimeError(
f"BANGUI_WORKERS environment variable must be an integer, got: {workers_env}"
) from e
raise RuntimeError(f"BANGUI_WORKERS environment variable must be an integer, got: {workers_env}") from e
async def _ensure_database_schema(database_path: str) -> None:
@@ -333,6 +331,11 @@ async def _stage_init_database(app: FastAPI, settings: Settings) -> Any:
log.debug("database_directory_ensured", directory=str(db_path.parent))
# Clean up orphaned WAL files from previous unclean shutdowns before
# opening the database. This prevents stale .wal/.shm files from
# interfering with startup or triggering misleading warnings.
await _cleanup_wal_files(settings.database_path)
original_db_path = db_path.resolve()
startup_db = await open_db(settings.database_path)
@@ -357,9 +360,7 @@ async def _stage_init_database(app: FastAPI, settings: Settings) -> Any:
if f2b_db_path:
await run_blocking(ensure_fail2ban_indexes, f2b_db_path)
persisted_runtime_settings = (
await setup_service.get_persisted_runtime_settings(runtime_db)
)
persisted_runtime_settings = await setup_service.get_persisted_runtime_settings(runtime_db)
finally:
await runtime_db.close()


@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
[project]
name = "bangui-backend"
version = "0.9.19"
version = "0.9.19-rc.5"
description = "BanGUI backend — fail2ban web management interface"
requires-python = ">=3.12"
dependencies = [


@@ -252,6 +252,30 @@ async def test_cleanup_wal_files_removes_orphaned_files(tmp_path: Path) -> None:
assert not shm_path.exists()
async def test_cleanup_wal_files_skips_recent_files(tmp_path: Path) -> None:
"""Test that _cleanup_wal_files skips files modified within 10 seconds."""
db_path = str(tmp_path / "test_wal_recent.db")
wal_path = Path(db_path + "-wal")
shm_path = Path(db_path + "-shm")
# Create files with recent mtime
wal_path.write_text("recent")
shm_path.write_text("recent")
recent_mtime = time.time() - 5
os.utime(wal_path, (recent_mtime, recent_mtime))
os.utime(shm_path, (recent_mtime, recent_mtime))
assert wal_path.exists()
assert shm_path.exists()
# Run cleanup
await _cleanup_wal_files(db_path)
# Files should NOT be removed (recent)
assert wal_path.exists()
assert shm_path.exists()
async def test_cleanup_wal_files_handles_missing_files(tmp_path: Path) -> None:
"""Test that _cleanup_wal_files handles non-existent files gracefully."""
db_path = str(tmp_path / "nonexistent.db")
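
The tests in this hunk pin down two behaviours of `_cleanup_wal_files`: old `-wal`/`-shm` files are removed, and files modified within the last 10 seconds are left alone. The implementation itself is not part of this diff, so the following is only a minimal synchronous sketch of that contract. The function name `cleanup_wal_files_sketch`, the `now` parameter, and the constant name are hypothetical, and the real function is async and may perform additional checks.

```python
from __future__ import annotations

import os
import time
from pathlib import Path

# Assumption: threshold taken from the test docstring ("within 10 seconds").
RECENT_THRESHOLD_SECONDS = 10.0


def cleanup_wal_files_sketch(db_path: str, now: float | None = None) -> list[Path]:
    """Remove stale ``-wal``/``-shm`` siblings of ``db_path``; return what was removed."""
    current = time.time() if now is None else now
    removed: list[Path] = []
    for suffix in ("-wal", "-shm"):
        candidate = Path(db_path + suffix)
        try:
            mtime = candidate.stat().st_mtime
        except FileNotFoundError:
            # Missing files are not an error; there is simply nothing to clean.
            continue
        if current - mtime < RECENT_THRESHOLD_SECONDS:
            # Recently touched: another process may still be using the file.
            continue
        candidate.unlink(missing_ok=True)
        removed.append(candidate)
    return removed
```

Skipping recently modified files is a conservative guard: a fresh mtime suggests the file may belong to a live connection rather than to an earlier unclean shutdown.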


@@ -3,6 +3,7 @@ from __future__ import annotations
from unittest.mock import AsyncMock, MagicMock, patch
import aiohttp
import aiosqlite
import pytest
from fastapi import FastAPI
from starlette.requests import Request
@@ -19,6 +20,13 @@ from app.dependencies import (
get_settings,
get_settings_repo,
)
from app.exceptions import (
DatabaseBusyError,
DatabaseCorruptedError,
DatabasePathInvalidError,
DatabasePermissionDeniedError,
DatabaseUnavailableError,
)
from app.main import create_app
from app.models.server import ServerStatus
@@ -98,3 +106,184 @@ async def test_get_db_uses_effective_runtime_database_path(test_settings: Settin
await gen.aclose()
mock_open_db.assert_awaited_once_with("/tmp/runtime.db")
# ---------------------------------------------------------------------------
# Database error handling tests
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_get_db_raises_database_permission_denied_on_permission_error(
test_settings: Settings,
) -> None:
"""PermissionError from open_db raises DatabasePermissionDeniedError."""
with patch(
"app.db.open_db",
new=AsyncMock(side_effect=DatabasePermissionDeniedError(test_settings.database_path)),
):
gen = get_db(settings=test_settings)
with pytest.raises(DatabasePermissionDeniedError) as exc_info:
await gen.__anext__()
await gen.aclose()
assert exc_info.value.error_code == "database_permission_denied"
assert exc_info.value.database_path == test_settings.database_path
@pytest.mark.asyncio
async def test_get_db_raises_database_path_invalid_on_missing_directory(
test_settings: Settings,
) -> None:
"""sqlite3.OperationalError('unable to open database file') raises DatabasePathInvalidError."""
with patch(
"app.db.open_db",
new=AsyncMock(side_effect=DatabasePathInvalidError(test_settings.database_path)),
):
gen = get_db(settings=test_settings)
with pytest.raises(DatabasePathInvalidError) as exc_info:
await gen.__anext__()
await gen.aclose()
assert exc_info.value.error_code == "database_path_invalid"
assert exc_info.value.database_path == test_settings.database_path
@pytest.mark.asyncio
async def test_get_db_retries_on_database_locked(test_settings: Settings) -> None:
"""get_db retries up to 3 times when database is locked."""
mock_connection = MagicMock()
mock_connection.close = AsyncMock()
locked_err = DatabaseUnavailableError(
test_settings.database_path, "database is locked"
)
with patch(
"app.db.open_db",
new=AsyncMock(side_effect=[locked_err, locked_err, mock_connection]),
) as mock_open:
gen = get_db(settings=test_settings)
with patch("asyncio.sleep", new=AsyncMock()) as mock_sleep:
connection = await gen.__anext__()
await gen.aclose()
assert mock_open.call_count == 3
assert connection is mock_connection
assert mock_sleep.call_count == 2
@pytest.mark.asyncio
async def test_get_db_fails_after_max_retries_on_database_locked(
test_settings: Settings,
) -> None:
"""After 3 retries on database locked, raises DatabaseBusyError."""
locked_err = DatabaseUnavailableError(
test_settings.database_path, "database is locked"
)
with patch("app.db.open_db", new=AsyncMock(side_effect=locked_err)) as mock_open:
gen = get_db(settings=test_settings)
with patch("asyncio.sleep", new=AsyncMock()):
with pytest.raises(DatabaseBusyError) as exc_info:
await gen.__anext__()
await gen.aclose()
assert mock_open.call_count == 3
assert exc_info.value.error_code == "database_busy"
assert exc_info.value.retries == 3
@pytest.mark.asyncio
async def test_get_db_raises_database_corrupted_on_malformed_db(
test_settings: Settings,
) -> None:
"""sqlite3.DatabaseError('database disk image is malformed') raises DatabaseCorruptedError."""
with patch(
"app.db.open_db",
new=AsyncMock(side_effect=DatabaseCorruptedError(test_settings.database_path)),
):
gen = get_db(settings=test_settings)
with pytest.raises(DatabaseCorruptedError) as exc_info:
await gen.__anext__()
await gen.aclose()
assert exc_info.value.error_code == "database_corrupted"
@pytest.mark.asyncio
async def test_open_db_creates_parent_directory_if_missing(tmp_path: pytest.Path) -> None:
"""open_db creates the parent directory when it does not exist."""
from pathlib import Path
from app.db import open_db
db_path = str(Path(str(tmp_path)) / "subdir" / "deeper" / "bangui.db")
mock_conn = MagicMock()
mock_conn.close = AsyncMock()
mock_conn.execute = AsyncMock()
mock_conn.commit = AsyncMock()
with patch("aiosqlite.connect", new=AsyncMock(return_value=mock_conn)), \
patch("app.db._configure_connection", new=AsyncMock()):
connection = await open_db(db_path)
assert connection is mock_conn
assert Path(db_path).parent.exists()
@pytest.mark.asyncio
async def test_open_db_logs_specific_sqlite_error_code() -> None:
"""open_db logs the SQLite error code when available."""
from app.db import open_db
exc = aiosqlite.OperationalError("database is locked")
exc.sqlite_errorcode = 5 # SQLITE_BUSY
with patch("aiosqlite.connect", new=AsyncMock(side_effect=exc)), \
pytest.raises(DatabaseUnavailableError):
await open_db("/tmp/test.db")
# ---------------------------------------------------------------------------
# Error metadata tests
# ---------------------------------------------------------------------------
def test_database_busy_error_metadata() -> None:
"""DatabaseBusyError returns correct metadata."""
err = DatabaseBusyError("/data/bangui.db", retries=3)
assert err.error_code == "database_busy"
metadata = err.get_error_metadata()
assert metadata["database_path"] == "/data/bangui.db"
assert metadata["retries"] == 3
def test_database_permission_denied_error_metadata() -> None:
"""DatabasePermissionDeniedError returns correct metadata."""
err = DatabasePermissionDeniedError("/data/bangui.db")
assert err.error_code == "database_permission_denied"
assert err.get_error_metadata()["database_path"] == "/data/bangui.db"
def test_database_path_invalid_error_metadata() -> None:
"""DatabasePathInvalidError returns correct metadata."""
err = DatabasePathInvalidError("/data/bangui.db")
assert err.error_code == "database_path_invalid"
assert err.get_error_metadata()["database_path"] == "/data/bangui.db"
def test_database_corrupted_error_metadata() -> None:
"""DatabaseCorruptedError returns correct metadata."""
err = DatabaseCorruptedError("/data/bangui.db")
assert err.error_code == "database_corrupted"
assert err.get_error_metadata()["database_path"] == "/data/bangui.db"
def test_database_unavailable_error_metadata() -> None:
"""DatabaseUnavailableError returns correct metadata."""
err = DatabaseUnavailableError("/data/bangui.db", "some error")
assert err.error_code == "database_unavailable"
metadata = err.get_error_metadata()
assert metadata["database_path"] == "/data/bangui.db"
assert metadata["error"] == "some error"
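
The retry tests above imply a simple contract: a "database is locked" failure is retried up to three attempts with a sleep between attempts, and exhaustion is reported as `DatabaseBusyError` with `retries == 3`. The sketch below illustrates that contract under stated assumptions. The exception classes are simplified stand-ins for those in `app.exceptions`, and `open_with_retry`, `opener`, and `delay_seconds` are hypothetical names; the actual `get_db` code is not shown in this diff.

```python
from __future__ import annotations

import asyncio
from collections.abc import Awaitable, Callable
from typing import TypeVar

T = TypeVar("T")


class DatabaseUnavailableError(Exception):
    """Simplified stand-in for the project's exception of the same name."""

    def __init__(self, database_path: str, error: str) -> None:
        super().__init__(f"{database_path}: {error}")
        self.database_path = database_path
        self.error = error


class DatabaseBusyError(Exception):
    """Simplified stand-in; raised once all attempts hit a locked database."""

    error_code = "database_busy"

    def __init__(self, database_path: str, retries: int) -> None:
        super().__init__(f"{database_path}: busy after {retries} attempts")
        self.database_path = database_path
        self.retries = retries


async def open_with_retry(
    opener: Callable[[str], Awaitable[T]],
    database_path: str,
    max_attempts: int = 3,
    delay_seconds: float = 0.1,
) -> T:
    """Call ``opener`` until it succeeds or ``max_attempts`` is exhausted."""
    if max_attempts < 1:
        raise ValueError("max_attempts must be at least 1")
    for attempt in range(1, max_attempts + 1):
        try:
            return await opener(database_path)
        except DatabaseUnavailableError as exc:
            # Only lock contention is treated as transient in this sketch.
            if "locked" not in exc.error:
                raise
            if attempt == max_attempts:
                raise DatabaseBusyError(database_path, retries=max_attempts) from exc
            await asyncio.sleep(delay_seconds)
    raise AssertionError("unreachable")
```

With three attempts this sleeps twice on the success path, which matches the `mock_sleep.call_count == 2` assertion in the test above.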


@@ -2,6 +2,9 @@
import asyncio
import contextlib
import io
import json
import logging
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch
@@ -22,6 +25,7 @@ from app.main import (
from app.middleware.correlation import CorrelationIdMiddleware
from app.middleware.rate_limit import RateLimitMiddleware
from app.services import setup_service
from app.utils.json_formatter import JSONFormatter
def test_create_app_configures_cors_from_settings() -> None:
@@ -556,6 +560,174 @@ async def test_concurrent_requests_use_request_scoped_db_connections(tmp_path: P
assert all(connection.close.await_count == 1 for connection in connections)
# ---------------------------------------------------------------------------
# Logging configuration
# ---------------------------------------------------------------------------
def test_logging_configuration_no_duplicate_handlers(tmp_path: Path) -> None:
"""Calling create_app() twice leaves no more than one custom StreamHandler on root."""
fail2ban_config_dir = tmp_path / "fail2ban"
fail2ban_config_dir.mkdir()
settings1 = Settings(
database_path=str(tmp_path / "test1.db"),
fail2ban_socket="/tmp/fake_fail2ban.sock",
fail2ban_config_dir=str(fail2ban_config_dir),
session_secret="test-secret-key-do-not-use-in-production",
session_duration_minutes=60,
timezone="UTC",
log_level="debug",
)
create_app(settings=settings1)
settings2 = Settings(
database_path=str(tmp_path / "test2.db"),
fail2ban_socket="/tmp/fake_fail2ban.sock",
fail2ban_config_dir=str(fail2ban_config_dir),
session_secret="test-secret-key-do-not-use-in-production-2",
session_duration_minutes=60,
timezone="UTC",
log_level="debug",
)
create_app(settings=settings2)
# _configure_logging uses basicConfig which replaces handlers on the root logger.
# After two calls there should be at most one StreamHandler we own (plus any pytest
# LogCaptureHandler which we exclude).
root_stream_handlers = [
h for h in logging.getLogger().handlers
if isinstance(h, logging.StreamHandler) and not type(h).__name__.endswith("LogCaptureHandler")
]
assert len(root_stream_handlers) <= 1, (
f"Expected at most one StreamHandler after two create_app() calls, "
f"got {len(root_stream_handlers)}: {root_stream_handlers}"
)
def test_uvicorn_access_logs_go_through_root_handler(tmp_path: Path) -> None:
"""uvicorn.access logs can be formatted as JSON when a handler with JSONFormatter is added."""
fail2ban_config_dir = tmp_path / "fail2ban"
fail2ban_config_dir.mkdir()
settings = Settings(
database_path=str(tmp_path / "test.db"),
fail2ban_socket="/tmp/fake_fail2ban.sock",
fail2ban_config_dir=str(fail2ban_config_dir),
session_secret="test-secret-key-do-not-use-in-production",
session_duration_minutes=60,
timezone="UTC",
log_level="debug",
)
create_app(settings=settings)
# uvicorn.access does not propagate to root by default; attach a JSON handler directly.
uvicorn_access = logging.getLogger("uvicorn.access")
output = io.StringIO()
handler = logging.StreamHandler(stream=output)
handler.setFormatter(JSONFormatter())
uvicorn_access.addHandler(handler)
try:
uvicorn_access.setLevel(logging.DEBUG)
uvicorn_access.info("GET /api/v1/health 200")
line = output.getvalue().strip()
assert line, "Expected non-empty log output from uvicorn.access"
parsed = json.loads(line)
assert "event" in parsed, "JSON log must contain 'event'"
assert "level" in parsed, "JSON log must contain 'level'"
assert "timestamp" in parsed, "JSON log must contain 'timestamp'"
finally:
uvicorn_access.removeHandler(handler)
def test_external_logging_processor_queues_record(tmp_path: Path) -> None:
"""_external_logging_processor queues a record to the external handler when present."""
from app.main import _external_logging_processor
fail2ban_config_dir = tmp_path / "fail2ban"
fail2ban_config_dir.mkdir()
settings = Settings(
database_path=str(tmp_path / "test.db"),
fail2ban_socket="/tmp/fake_fail2ban.sock",
fail2ban_config_dir=str(fail2ban_config_dir),
session_secret="test-secret-key-do-not-use-in-production",
session_duration_minutes=60,
timezone="UTC",
log_level="debug",
)
create_app(settings=settings)
from app.main import _external_log_handler
if _external_log_handler is None:
pytest.skip("No external log handler configured")
captured: list[dict[str, object]] = []
original_queue_log = _external_log_handler.queue_log
def mock_queue_log(record: dict[str, object]) -> None:
captured.append(record)
_external_log_handler.queue_log = mock_queue_log
try:
record = logging.makeLogRecord({"msg": "test event", "levelname": "INFO", "name": "test.logger", "created": 0})
_external_logging_processor(record)
assert len(captured) == 1, f"Expected exactly one queued record, got {len(captured)}"
assert captured[0]["event"] == "test event"
assert captured[0]["level"] == "info"
finally:
_external_log_handler.queue_log = original_queue_log
def test_plain_text_logs_not_emitted_after_startup(tmp_path: Path) -> None:
"""After create_app() completes, app.db logger output is JSON, not plain text."""
fail2ban_config_dir = tmp_path / "fail2ban"
fail2ban_config_dir.mkdir()
settings = Settings(
database_path=str(tmp_path / "test.db"),
fail2ban_socket="/tmp/fake_fail2ban.sock",
fail2ban_config_dir=str(fail2ban_config_dir),
session_secret="test-secret-key-do-not-use-in-production",
session_duration_minutes=60,
timezone="UTC",
log_level="debug",
)
create_app(settings=settings)
output = io.StringIO()
handler = logging.StreamHandler(stream=output)
handler.setFormatter(JSONFormatter())
db_logger = logging.getLogger("app.db")
db_logger.addHandler(handler)
db_logger.setLevel(logging.DEBUG)
try:
db_logger.info("test_db_log")
line = output.getvalue().strip()
assert line, "Expected non-empty log output"
assert not line.startswith("test_db_log "), "Log must not be plain text"
parsed = json.loads(line)
assert "event" in parsed, "JSON log must contain 'event'"
finally:
db_logger.removeHandler(handler)
# ---------------------------------------------------------------------------
# Middleware order validation
# ---------------------------------------------------------------------------
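
The logging tests above only require that a formatted record parses as JSON and carries `event`, `level`, and `timestamp` keys. As a rough illustration of a formatter that would satisfy those assertions, here is a minimal sketch. The class name `JsonFormatterSketch` is hypothetical, and the project's `JSONFormatter` in `app.utils.json_formatter` may emit more fields or a different timestamp format.

```python
from __future__ import annotations

import json
import logging
from datetime import datetime, timezone


class JsonFormatterSketch(logging.Formatter):
    """Render each log record as a single JSON object on one line."""

    def format(self, record: logging.LogRecord) -> str:
        payload = {
            # The rendered message becomes the structured "event" field.
            "event": record.getMessage(),
            # Lower-case level name, e.g. "info" (matches the tests above).
            "level": record.levelname.lower(),
            # ISO 8601 timestamp in UTC, derived from the record's creation time.
            "timestamp": datetime.fromtimestamp(record.created, tz=timezone.utc).isoformat(),
        }
        return json.dumps(payload)
```

Attaching such a formatter to a `logging.StreamHandler` over an `io.StringIO`, as the tests do, makes the output easy to parse with `json.loads`.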


@@ -934,6 +934,29 @@ class TestBanTrend:
parsed = datetime.fromisoformat(bucket.timestamp)
assert parsed.tzinfo is not None # Must be timezone-aware (UTC)
async def test_ban_trend_since_is_within_expected_range(self, tmp_path: Path) -> None:
"""``since`` value is within 24h + 60s slack of the current time."""
from app.utils.constants import TIME_RANGE_SLACK_SECONDS
now = int(time.time())
# Place a ban just inside the expected range: 23 hours ago.
# With 60s slack, since ≈ now - 24h - 60s, so 23h-ago ban should be included.
just_inside_range = now - (23 * 3600)
path = str(tmp_path / "test_since_range.sqlite3")
await _create_f2b_db(
path,
[{"jail": "sshd", "ip": "1.2.3.4", "timeofban": just_inside_range}],
)
with patch(
"app.services.ban_service.get_fail2ban_db_path",
new=AsyncMock(return_value=path),
):
result = await ban_service.ban_trend("/fake/sock", "24h")
# Ban at 23h ago must appear (within 24h + 60s window).
assert sum(b.count for b in result.buckets) == 1
# ---------------------------------------------------------------------------
# bans_by_jail


@@ -134,3 +134,15 @@ class TestSinceUnix:
# The slack should be ~60 seconds
assert actual_slack >= TIME_RANGE_SLACK_SECONDS - 1
assert actual_slack <= TIME_RANGE_SLACK_SECONDS + 1
def test_since_unix_returns_utc_epoch(self) -> None:
"""``since_unix('24h')`` returns a value within 24h + 60s of ``time.time()``."""
before = int(time.time())
result = since_unix("24h")
after = int(time.time())
# Allow 2 second tolerance for execution time
expected_min = before - (24 * 3600) - TIME_RANGE_SLACK_SECONDS - 2
expected_max = after - (24 * 3600) - TIME_RANGE_SLACK_SECONDS + 2
assert expected_min <= result <= expected_max
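
The bounds asserted in this test correspond to a simple formula: the cutoff is the current Unix time minus the range length minus a fixed slack. A minimal sketch of that arithmetic follows. The name `since_unix_sketch`, the `now` parameter, and the lookup table are hypothetical; the slack value of 60 is inferred from the "~60 seconds" comment in the test, and only the `"24h"` range appears in this diff.

```python
from __future__ import annotations

import time

# Assumption: inferred from the "~60 seconds" comment in the existing test.
TIME_RANGE_SLACK_SECONDS = 60

# Only "24h" is confirmed by the diff; other ranges are deliberately omitted.
_RANGE_SECONDS = {"24h": 24 * 3600}


def since_unix_sketch(range_name: str, now: float | None = None) -> int:
    """Return the Unix timestamp marking the start of the requested window."""
    # time.time() counts seconds since the epoch, independent of local timezone.
    current = time.time() if now is None else now
    return int(current) - _RANGE_SECONDS[range_name] - TIME_RANGE_SLACK_SECONDS
```

For example, with `now=100000` the `"24h"` cutoff is `100000 - 86400 - 60 = 13540`, so a ban recorded 23 hours earlier falls inside the window.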


@@ -134,24 +134,17 @@ class TestRateLimitMiddleware:
"""Global rate limit should block requests exceeding per-IP limit."""
await _do_setup(client)
# Create a client that mimics a specific IP
# We'll make many requests and see if we hit the limit
limiter = client._transport.app.state.global_rate_limiter
limiter.reset()
# Reduce limit temporarily for testing.
# Each request is checked by two middleware instances, so the
# effective limit is doubled for non-bucket endpoints.
original_max = limiter.max_requests
limiter.max_requests = 7
limiter.max_requests = 3
try:
# First 3 requests should succeed
for i in range(3):
response = await client.get("/api/v1/health")
assert response.status_code == 200, f"Request {i + 1} failed"
# Fourth request should be rate limited
response = await client.get("/api/v1/health")
assert response.status_code == 429
assert response.json()["code"] == "rate_limit_exceeded"
@@ -166,22 +159,47 @@ class TestRateLimitMiddleware:
limiter = client._transport.app.state.global_rate_limiter
limiter.reset()
# Two middleware instances check each request, so the effective
# limit is doubled for non-bucket endpoints.
original_max = limiter.max_requests
limiter.max_requests = 3
limiter.max_requests = 2
try:
# First request succeeds
response = await client.get("/api/v1/health")
assert response.status_code == 200
# Second request is rate limited
response = await client.get("/api/v1/health")
assert response.status_code == 200
response = await client.get("/api/v1/health")
assert response.status_code == 429
assert "Retry-After" in response.headers
retry_after = int(response.headers["Retry-After"])
assert retry_after > 0
assert retry_after <= 60 # Should be less than window
assert retry_after <= 60
finally:
limiter.max_requests = original_max
async def test_auth_bucket_allows_more_requests(self, client: AsyncClient) -> None:
"""Auth endpoints use a dedicated high-rate bucket."""
await _do_setup(client)
limiter = client._transport.app.state.global_rate_limiter
limiter.reset()
# The auth bucket is configured for 1000 req/min; we only need to
# verify that it is *not* the global bucket (200 req/min).
for _ in range(5):
response = await client.post("/api/v1/auth/login", json={"password": "x"})
assert response.status_code in (401, 403, 429)
async def test_history_bucket_allows_more_requests(self, client: AsyncClient) -> None:
"""History endpoints use a dedicated high-rate bucket."""
await _do_setup(client)
limiter = client._transport.app.state.global_rate_limiter
limiter.reset()
for _ in range(5):
response = await client.get("/api/v1/history/bans")
# 401/403 is fine — we just need to confirm we are not 429'd
# by the global limiter.
assert response.status_code != 429

147
check_auth.py Normal file

@@ -0,0 +1,147 @@
#!/usr/bin/env python3
"""Diagnostic script for BanGUI auth/session 401 issue.
Tests the full auth flow against http://192.168.178.43:8080/api/v1/auth
using password "Hallo123!".
Usage:
python3 check_auth.py
"""
import json
import urllib.error
import urllib.request
BASE_URL = "http://192.168.178.43:8080/api/v1"
PASSWORD = "Hallo123!"
def make_request(url, method="GET", data=None, headers=None, cookie=None):
"""Make an HTTP request and return (status, headers, body, cookies)."""
req_headers = headers or {}
if data:
req_headers["Content-Type"] = "application/json"
if cookie:
req_headers["Cookie"] = cookie
req = urllib.request.Request(
url,
data=json.dumps(data).encode("utf-8") if data else None,
headers=req_headers,
method=method,
)
try:
with urllib.request.urlopen(req) as resp:
body = resp.read().decode("utf-8")
cookies = resp.headers.get_all("Set-Cookie") or []
return resp.status, dict(resp.headers), body, cookies
except urllib.error.HTTPError as e:
body = e.read().decode("utf-8")
cookies = e.headers.get_all("Set-Cookie") or []
return e.code, dict(e.headers), body, cookies
except Exception as e:
return None, {}, str(e), []
def extract_cookie_value(set_cookie_headers, cookie_name):
"""Extract cookie value from Set-Cookie headers."""
for header in set_cookie_headers:
if header.startswith(cookie_name + "="):
return header.split(";")[0]
return None
def main():
print("=" * 60)
print("BanGUI Auth Diagnostic Script")
print("Target:", BASE_URL)
print("=" * 60)
# 1. Check health endpoint (no auth needed)
print("\n[1] GET /health")
status, headers, body, _ = make_request(f"{BASE_URL}/health")
print(f" Status: {status}")
print(f" Response: {body[:200]}")
# 2. Check CORS preflight for login
print("\n[2] OPTIONS /auth/login (CORS preflight)")
status, headers, body, _ = make_request(
f"{BASE_URL}/auth/login",
method="OPTIONS",
headers={
"Origin": "http://192.168.178.43:8080",
"Access-Control-Request-Method": "POST",
"Access-Control-Request-Headers": "Content-Type",
},
)
print(f" Status: {status}")
print(f" Access-Control-Allow-Origin: {headers.get('Access-Control-Allow-Origin', 'MISSING')}")
print(f" Access-Control-Allow-Credentials: {headers.get('Access-Control-Allow-Credentials', 'MISSING')}")
# 3. Login
print(f"\n[3] POST /auth/login (password: {PASSWORD})")
status, headers, body, cookies = make_request(
f"{BASE_URL}/auth/login",
method="POST",
data={"password": PASSWORD},
headers={"Origin": "http://192.168.178.43:8080"},
)
print(f" Status: {status}")
print(f" Response: {body}")
print(f" Set-Cookie headers: {cookies}")
session_cookie = extract_cookie_value(cookies, "bangui_session")
if session_cookie:
print(f" Extracted session cookie: {session_cookie[:50]}...")
else:
print(" WARNING: No bangui_session cookie received!")
# 4. Validate session with cookie
print("\n[4] GET /auth/session (with cookie)")
if session_cookie:
status, headers, body, _ = make_request(
f"{BASE_URL}/auth/session",
cookie=session_cookie,
headers={"Origin": "http://192.168.178.43:8080"},
)
print(f" Status: {status}")
print(f" Response: {body}")
else:
print(" SKIPPED (no cookie from login)")
# 5. Validate session WITHOUT cookie (should be 401)
print("\n[5] GET /auth/session (without cookie)")
status, headers, body, _ = make_request(f"{BASE_URL}/auth/session")
print(f" Status: {status}")
print(f" Response: {body}")
# 6. Check backend settings (if available via /setup or other endpoint)
print("\n[6] GET /setup (check if setup is complete)")
status, headers, body, _ = make_request(f"{BASE_URL}/setup")
print(f" Status: {status}")
print(f" Response: {body[:200]}")
print("\n" + "=" * 60)
print("DIAGNOSIS SUMMARY")
print("=" * 60)
if session_cookie and "Secure" in str(cookies):
print("\n PROBLEM FOUND: Session cookie has 'Secure' flag set,")
print(" but you are accessing over HTTP (not HTTPS).")
print(" Browsers will NOT send Secure cookies over HTTP!")
print("\n FIX: Set SESSION_COOKIE_SECURE=false in your backend .env")
print(" and restart the backend.")
if not session_cookie and status == 401:
print("\n PROBLEM FOUND: Login succeeded but no session cookie was set.")
print(" This usually means the cookie is being rejected by the browser")
print(" due to Secure flag on HTTP, or SameSite restrictions.")
print("\n If CORS Access-Control-Allow-Origin is missing or wrong,")
print(" add your frontend origin to CORS_ALLOWED_ORIGINS in .env")
print("=" * 60)
if __name__ == "__main__":
main()

10343
frontend/openapi.json Normal file

File diff suppressed because it is too large


@@ -1,12 +1,12 @@
{
"name": "bangui-frontend",
"version": "0.9.19",
"version": "0.9.19-rc.4",
"lockfileVersion": 3,
"requires": true,
"packages": {
"": {
"name": "bangui-frontend",
"version": "0.9.19",
"version": "0.9.19-rc.4",
"dependencies": {
"@fluentui/react-components": "^9.55.0",
"@fluentui/react-icons": "^2.0.257",


@@ -1,12 +1,12 @@
{
"name": "bangui-frontend",
"private": true,
"version": "0.9.19",
"version": "0.9.19-rc.5",
"description": "BanGUI frontend — fail2ban web management interface",
"type": "module",
"scripts": {
"dev": "vite",
"generate:types": "openapi-typescript http://localhost:8000/api/openapi.json -o src/types/generated.ts",
"generate:types": "openapi-typescript ./openapi.json -o src/types/generated.ts",
"validate:types": "bash scripts/validate-types.sh",
"build": "npm run generate:types && tsc --noEmit && vite build",
"preview": "vite preview",


@@ -17,17 +17,23 @@ GENERATED_FILE="${TYPES_DIR}/generated.ts"
TEMP_FILE=$(mktemp)
trap "rm -f $TEMP_FILE" EXIT
# Check if backend is accessible
# Determine OpenAPI source: local file or backend URL
BACKEND_URL="${BANGUI_BACKEND_URL:-http://localhost:8000}"
if ! curl -sf "${BACKEND_URL}/api/openapi.json" > /dev/null 2>&1; then
echo "❌ Backend not accessible at ${BACKEND_URL}/api/openapi.json" >&2
OPENAPI_SOURCE=""
if [[ -f "${FRONTEND_DIR}/openapi.json" ]]; then
OPENAPI_SOURCE="${FRONTEND_DIR}/openapi.json"
echo "📋 Validating OpenAPI schema types (local openapi.json)..."
elif curl -sf "${BACKEND_URL}/api/openapi.json" > /dev/null 2>&1; then
OPENAPI_SOURCE="${BACKEND_URL}/api/openapi.json"
echo "📋 Validating OpenAPI schema types (backend ${BACKEND_URL})..."
else
echo "❌ Backend not accessible at ${BACKEND_URL}/api/openapi.json and no local openapi.json found" >&2
exit 2
fi
echo "📋 Validating OpenAPI schema types..."
# Generate types to a temporary file
if ! npx openapi-typescript "${BACKEND_URL}/api/openapi.json" -o "$TEMP_FILE" 2>&1; then
if ! npx openapi-typescript "${OPENAPI_SOURCE}" -o "$TEMP_FILE" 2>&1; then
echo "❌ Failed to generate types from OpenAPI schema" >&2
exit 3
fi


@@ -1,7 +1,6 @@
import { describe, it, expect, vi } from "vitest";
import { render, screen } from "@testing-library/react";
import { ErrorBoundary } from "../ErrorBoundary";
import * as telemetry from "../../utils/telemetry";
// Mock telemetry to verify it's called
vi.mock("../../utils/telemetry");


@@ -468,13 +468,10 @@ describe("useFetchData", () => {
});
it("last subscriber abort cancels underlying request", async () => {
let resolveFirst: ((value: { value: string }) => void) | null = null;
const abortSignals: AbortSignal[] = [];
const fetcher = vi.fn().mockImplementation((signal: AbortSignal) => {
abortSignals.push(signal);
return new Promise((resolve) => {
resolveFirst = resolve;
});
return new Promise(() => {});
});
const selector = vi.fn((response: { value: string }) => response.value);


@@ -10,7 +10,7 @@ describe("useJailBannedIps", () => {
const fetchMock = vi.mocked(api.fetchJailBannedIps);
const unbanMock = vi.mocked(api.unbanIp);
fetchMock.mockResolvedValue({ items: [{ ip: "1.2.3.4", jail: "sshd", banned_at: "2025-01-01T10:00:00+00:00", expires_at: "2025-01-01T10:10:00+00:00", ban_count: 1, country: "US" }], total: 1, page: 1, page_size: 25 });
fetchMock.mockResolvedValue({ items: [{ ip: "1.2.3.4", jail: "sshd", banned_at: "2025-01-01T10:00:00+00:00", expires_at: "2025-01-01T10:10:00+00:00", ban_count: 1, country: "US" }], total: 1, page: 1, page_size: 25, total_pages: 1, pagination_mode: "offset" });
unbanMock.mockResolvedValue({ message: "ok", jail: "sshd", success: true });
const { result } = renderHook(() => useJailBannedIps("sshd"));


@@ -34,8 +34,6 @@ describe("usePolledData", () => {
vi.runAllTimersAsync();
});
const callCountAfterInitial = fetcher.mock.calls.length;
// Reset timer and advance to ensure no more polls
vi.clearAllTimers();
fetcher.mockClear();
@@ -66,8 +64,6 @@ describe("usePolledData", () => {
vi.advanceTimersByTime(100);
});
const initialCalls = fetcher.mock.calls.length;
// Clear for clean test
fetcher.mockClear();
@@ -135,7 +131,6 @@ describe("usePolledData", () => {
vi.advanceTimersByTime(100);
});
const initialCalls = fetcher.mock.calls.length;
fetcher.mockClear();
// Call refresh


@@ -77,11 +77,34 @@ export function usePolledData<TResponse, TData>(
pauseWhenHidden = false,
} = options;
+ // Stabilize fetcher/selector/onSuccess references so that useFetchData's
+ // refresh callback (and the useEffect that calls it) don't re-trigger on
+ // every render when callers pass inline functions.
+ const fetcherRef = useRef(fetcher);
+ fetcherRef.current = fetcher;
+ const selectorRef = useRef(selector);
+ selectorRef.current = selector;
+ const onSuccessRef = useRef(onSuccess);
+ onSuccessRef.current = onSuccess;
+ const stableFetcher = useCallback(
+ (signal: AbortSignal) => fetcherRef.current(signal),
+ [],
+ );
+ const stableSelector = useCallback(
+ (response: TResponse) => selectorRef.current(response),
+ [],
+ );
+ const stableOnSuccess = useCallback(
+ (response: TResponse) => onSuccessRef.current?.(response),
+ [],
+ );
const { data, loading, error, refresh } = useFetchData({
- fetcher,
- selector,
+ fetcher: stableFetcher,
+ selector: stableSelector,
errorMessage,
- onSuccess,
+ onSuccess: onSuccessRef.current ? stableOnSuccess : undefined,
initialData,
});
@@ -151,15 +174,10 @@ export function usePolledData<TResponse, TData>(
return;
}
- // Record when polling starts and schedule first poll immediately
+ // Record when polling starts. The initial fetch is handled by useFetchData's
+ // mount effect, so we just mark the start time and let the loading-completion
+ // effect (above) schedule the first poll after the initial fetch finishes.
pollStartTimeRef.current = performance.now();
- const id = window.setTimeout((): void => {
- if (cancelledRef.current) return;
- pollStartTimeRef.current = performance.now();
- refreshRef.current?.();
- }, 0);
- timeoutIdRef.current = id;
return (): void => {
cancelledRef.current = true;
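
Note on the ref-stabilization change above: a minimal, framework-free sketch of the same "latest ref" idea, under stated assumptions. `Ref` and `makeStable` are hypothetical stand-ins (for React's `useRef` box and the `useCallback(..., [])` wrappers) and are not part of this PR. The point is only that the wrapper's identity never changes while it always delegates to the most recently assigned function.

```typescript
// Hypothetical stand-in for a React ref: a mutable box holding the latest value.
interface Ref<T> {
  current: T;
}

// Hypothetical helper: returns one wrapper whose identity stays fixed,
// but which always calls whatever function is currently stored in the ref.
function makeStable<A extends unknown[], R>(
  ref: Ref<(...args: A) => R>,
): (...args: A) => R {
  return (...args: A) => ref.current(...args);
}

// Simulate a caller passing a new inline function on each "render".
const fetcherRef: Ref<(id: number) => string> = {
  current: (id) => `v1:${id}`,
};
const stableFetcher = makeStable(fetcherRef);

const first = stableFetcher(1);

// The "next render" swaps the inline function; the wrapper itself is unchanged.
fetcherRef.current = (id) => `v2:${id}`;
const second = stableFetcher(1);
```

Because `stableFetcher` keeps the same identity, anything that lists it as a dependency would not re-run when the caller's inline function changes, yet calls still reach the newest function.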

@@ -56,7 +56,7 @@ import React, {
} from "react";
import { useNavigate } from "react-router-dom";
import * as authApi from "../api/auth";
- import { setUnauthorizedHandler, resetLogoutState, clearSessionCorrelationId } from "../api/client";
+ import { ApiError, setUnauthorizedHandler, resetLogoutState, clearSessionCorrelationId } from "../api/client";
import { setAuthErrorHandler, resetLogoutState as resetFetchErrorLogoutState } from "../utils/fetchError";
import { STORAGE_KEY_AUTHENTICATED } from "../utils/constants";
import { SessionValidationLoading } from "../components/SessionValidationLoading";
@@ -133,6 +133,11 @@ export function AuthProvider({
const handleValidationError = useCallback(
(error: Error): void => {
+ // Suppress noisy warning for 5xx gateway errors (e.g. 502 Bad Gateway)
+ // during startup — these are server-side issues, not network issues.
+ if (error instanceof ApiError && error.status >= 500) {
+ return;
+ }
// Network error — log but don't logout.
console.warn("Session validation network error:", error);
},

@@ -177,11 +177,6 @@ export interface paths {
* On success the token is also set as an ``HttpOnly`` ``SameSite=Lax``
* cookie so the browser SPA benefits from automatic credential handling.
*
- * Rate limiting: Exponential backoff on failed attempts. Each wrong password
- * incurs an increasing delay (0.5s, 1s, 2s, 4s, 5s max per IP address).
- * Requests during the penalty period return ``429 Too Many Requests`` with
- * a ``Retry-After`` header.
- *
* Cache invalidation: On successful login, any existing cached sessions for
* the same user are invalidated so that stale tokens (e.g., from a stolen
* device) cannot be reused beyond the cache TTL window.
@@ -192,7 +187,6 @@ export interface paths {
* request: The incoming HTTP request (used to extract client IP).
* session_ctx: Session service context containing db and repository.
* settings: Application settings (used for session duration and trusted proxies).
- * rate_limiter: The login rate limiter (per IP).
* session_cache: Session cache for invalidating old sessions on login.
*
* Returns:
@@ -200,7 +194,6 @@ export interface paths {
*
* Raises:
* AuthenticationError: if the password is incorrect.
- * RateLimitError: if the rate limit is exceeded.
*/
post: operations["login_api_v1_auth_login_post"];
delete?: never;
@@ -6274,13 +6267,6 @@ export interface operations {
};
content?: never;
};
- /** @description Too many login attempts, retry after delay */
- 429: {
- headers: {
- [name: string]: unknown;
- };
- content?: never;
- };
/** @description Setup not complete */
503: {
headers: {