Compare commits
14 Commits
48ef85bec5
...
v0.9.19-rc
| Author | SHA1 | Date | |
|---|---|---|---|
| dcee222a41 | |||
| 12fe70d768 | |||
| 83b2cb67b1 | |||
| 7308ff88d6 | |||
| 77df5d5d65 | |||
| 96ce516ecf | |||
| 7ec80fdeec | |||
| 7790736918 | |||
| 79df1aa493 | |||
| cc9d3220c9 | |||
| 8fc1989cc4 | |||
| aa717a28f8 | |||
| e4c3ae718c | |||
| d4bab89cf3 |
15
.gitignore
vendored
15
.gitignore
vendored
@@ -95,17 +95,7 @@ Thumbs.db
|
||||
# ── Docker dev config ─────────────────────────
|
||||
# Ignore auto-generated linuxserver/fail2ban config files,
|
||||
# but track our custom filter, jail, and documentation.
|
||||
Docker/fail2ban-dev-config/**
|
||||
!Docker/fail2ban-dev-config/README.md
|
||||
!Docker/fail2ban-dev-config/fail2ban/
|
||||
!Docker/fail2ban-dev-config/fail2ban/filter.d/
|
||||
!Docker/fail2ban-dev-config/fail2ban/filter.d/bangui-sim.conf
|
||||
!Docker/fail2ban-dev-config/fail2ban/filter.d/bangui-access.conf
|
||||
!Docker/fail2ban-dev-config/fail2ban/jail.d/
|
||||
!Docker/fail2ban-dev-config/fail2ban/jail.d/bangui-sim.conf
|
||||
!Docker/fail2ban-dev-config/fail2ban/jail.d/bangui-access.conf
|
||||
!Docker/fail2ban-dev-config/fail2ban/jail.d/blocklist-import.conf
|
||||
!Docker/fail2ban-dev-config/fail2ban/jail.local
|
||||
data/*
|
||||
|
||||
# ── Misc ──────────────────────────────────────
|
||||
*.log
|
||||
@@ -115,3 +105,6 @@ Docker/fail2ban-dev-config/**
|
||||
|
||||
# ── E2E test results ───────────────────────────
|
||||
e2e/results/
|
||||
e2e/Instructions.md
|
||||
|
||||
playwright-log.txt
|
||||
|
||||
@@ -18,7 +18,7 @@ WORKDIR /build
|
||||
COPY frontend/package.json frontend/package-lock.json* /build/
|
||||
RUN npm ci --ignore-scripts
|
||||
|
||||
# Copy source and build
|
||||
# Copy source + local OpenAPI spec (avoids needing a running backend during build)
|
||||
COPY frontend/ /build/
|
||||
RUN npm run build
|
||||
|
||||
|
||||
@@ -1 +1 @@
|
||||
v0.9.19
|
||||
v0.9.19-rc.2
|
||||
|
||||
@@ -34,7 +34,7 @@ services:
|
||||
- ../data/fail2ban-dev-config:/config
|
||||
- fail2ban-dev-run:/var/run/fail2ban
|
||||
- /var/log:/var/log:ro
|
||||
- ./logs:/remotelogs/bangui
|
||||
- ../data/log:/remotelogs/bangui
|
||||
healthcheck:
|
||||
test: ["CMD", "fail2ban-client", "ping"]
|
||||
interval: 15s
|
||||
@@ -58,6 +58,7 @@ services:
|
||||
BANGUI_DATABASE_PATH: "/data/bangui.db"
|
||||
BANGUI_FAIL2BAN_SOCKET: "/var/run/fail2ban/fail2ban.sock"
|
||||
BANGUI_FAIL2BAN_CONFIG_DIR: "/config/fail2ban"
|
||||
BANGUI_LOG_FILE: "/data/log/bangui.log"
|
||||
BANGUI_LOG_LEVEL: "debug"
|
||||
BANGUI_ENABLE_DOCS: "true"
|
||||
BANGUI_SESSION_SECRET: "${BANGUI_SESSION_SECRET:?BANGUI_SESSION_SECRET must be set — generate with: python -c 'import secrets; print(secrets.token_hex(32))'}"
|
||||
@@ -70,11 +71,9 @@ services:
|
||||
volumes:
|
||||
- ../backend/app:/app/app:z
|
||||
- ../fail2ban-master:/app/fail2ban-master:ro,z
|
||||
- ../data/data:/data
|
||||
- ../data:/data
|
||||
- fail2ban-dev-run:/var/run/fail2ban:ro
|
||||
- ../data/fail2ban-dev-config:/config:rw
|
||||
ports:
|
||||
- "${BANGUI_BACKEND_PORT:-8000}:8000"
|
||||
command:
|
||||
[
|
||||
"uvicorn", "app.main:create_app", "--factory",
|
||||
@@ -87,8 +86,7 @@ services:
|
||||
timeout: 5s
|
||||
start_period: 45s
|
||||
retries: 5
|
||||
networks:
|
||||
- bangui-dev-net
|
||||
network_mode: host
|
||||
|
||||
# ── Frontend (Vite dev server with HMR) ─────────────────────
|
||||
frontend:
|
||||
@@ -98,23 +96,15 @@ services:
|
||||
working_dir: /app
|
||||
environment:
|
||||
NODE_ENV: development
|
||||
VITE_BACKEND_URL: "http://localhost:8000"
|
||||
volumes:
|
||||
- ../frontend:/app:z
|
||||
- frontend-node-modules:/app/node_modules
|
||||
ports:
|
||||
- "${BANGUI_FRONTEND_PORT:-5173}:5173"
|
||||
command: ["sh", "-c", "npm install && npm run dev -- --host 0.0.0.0"]
|
||||
depends_on:
|
||||
backend:
|
||||
condition: service_healthy
|
||||
healthcheck:
|
||||
test: ["CMD", "wget", "-qO", "/dev/null", "http://localhost:5173/"]
|
||||
interval: 15s
|
||||
timeout: 5s
|
||||
start_period: 30s
|
||||
retries: 5
|
||||
networks:
|
||||
- bangui-dev-net
|
||||
network_mode: host
|
||||
|
||||
volumes:
|
||||
bangui-dev-data:
|
||||
|
||||
@@ -1,115 +0,0 @@
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
# BanGUI — Production Compose
|
||||
#
|
||||
# Compatible with:
|
||||
# docker compose -f Docker/compose.prod.yml up -d
|
||||
# podman compose -f Docker/compose.prod.yml up -d
|
||||
# podman-compose -f Docker/compose.prod.yml up -d
|
||||
#
|
||||
# Prerequisites:
|
||||
# Create a .env file at the project root (or pass --env-file):
|
||||
# BANGUI_SESSION_SECRET=<random-secret>
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
|
||||
name: bangui
|
||||
|
||||
services:
|
||||
# ── fail2ban ─────────────────────────────────────────────────
|
||||
fail2ban:
|
||||
image: lscr.io/linuxserver/fail2ban:latest
|
||||
container_name: bangui-fail2ban
|
||||
restart: unless-stopped
|
||||
cap_add:
|
||||
- NET_ADMIN
|
||||
- NET_RAW
|
||||
network_mode: host
|
||||
environment:
|
||||
TZ: "${BANGUI_TIMEZONE:-UTC}"
|
||||
PUID: 0
|
||||
PGID: 0
|
||||
volumes:
|
||||
- fail2ban-config:/config
|
||||
- fail2ban-run:/var/run/fail2ban
|
||||
- /var/log:/var/log:ro
|
||||
healthcheck:
|
||||
test: ["CMD", "fail2ban-client", "ping"]
|
||||
interval: 30s
|
||||
timeout: 5s
|
||||
start_period: 15s
|
||||
retries: 3
|
||||
# NOTE: The fail2ban-config volume must be pre-populated with the following files:
|
||||
# • fail2ban/jail.conf (or jail.d/*.conf) with the DEFAULT section containing:
|
||||
# banaction = iptables-allports[lockingopt="-w 5"]
|
||||
# This prevents xtables lock contention errors when multiple jails start in parallel.
|
||||
# See https://fail2ban.readthedocs.io/en/latest/development/environment.html
|
||||
|
||||
# ── Backend (FastAPI + uvicorn) ─────────────────────────────
|
||||
backend:
|
||||
build:
|
||||
context: ..
|
||||
dockerfile: Docker/Dockerfile.backend
|
||||
container_name: bangui-backend
|
||||
restart: unless-stopped
|
||||
stop_grace_period: 30s
|
||||
depends_on:
|
||||
fail2ban:
|
||||
condition: service_healthy
|
||||
environment:
|
||||
BANGUI_DATABASE_PATH: "/data/bangui.db"
|
||||
BANGUI_FAIL2BAN_SOCKET: "/var/run/fail2ban/fail2ban.sock"
|
||||
BANGUI_FAIL2BAN_CONFIG_DIR: "/config/fail2ban"
|
||||
BANGUI_LOG_LEVEL: "info"
|
||||
# ⚠️ BANGUI_WORKERS MUST be 1 — see session_cache.py docstring for details
|
||||
# BanGUI uses a process-local session cache. Multiple workers in a single process
|
||||
# would cause users to be randomly logged out as sessions wouldn't be shared.
|
||||
# For HA, run multiple BanGUI instances (each with --workers 1) via orchestration.
|
||||
BANGUI_WORKERS: "1"
|
||||
BANGUI_SESSION_SECRET: "${BANGUI_SESSION_SECRET:?Set BANGUI_SESSION_SECRET}"
|
||||
BANGUI_TIMEZONE: "${BANGUI_TIMEZONE:-UTC}"
|
||||
volumes:
|
||||
- bangui-data:/data
|
||||
- fail2ban-run:/var/run/fail2ban:ro
|
||||
- fail2ban-config:/config:rw
|
||||
expose:
|
||||
- "8000"
|
||||
healthcheck:
|
||||
test: ["CMD", "curl", "-f", "http://localhost:8000/api/v1/health"]
|
||||
interval: 30s
|
||||
timeout: 10s
|
||||
retries: 3
|
||||
start_period: 40s
|
||||
networks:
|
||||
- bangui-net
|
||||
|
||||
# ── Frontend (nginx serving built SPA + API proxy) ──────────
|
||||
frontend:
|
||||
build:
|
||||
context: ..
|
||||
dockerfile: Docker/Dockerfile.frontend
|
||||
container_name: bangui-frontend
|
||||
restart: unless-stopped
|
||||
ports:
|
||||
- "${BANGUI_PORT:-8080}:80"
|
||||
depends_on:
|
||||
backend:
|
||||
condition: service_healthy
|
||||
healthcheck:
|
||||
test: ["CMD", "curl", "-f", "http://localhost:80/"]
|
||||
interval: 30s
|
||||
timeout: 10s
|
||||
retries: 3
|
||||
start_period: 5s
|
||||
networks:
|
||||
- bangui-net
|
||||
|
||||
volumes:
|
||||
bangui-data:
|
||||
driver: local
|
||||
fail2ban-config:
|
||||
driver: local
|
||||
fail2ban-run:
|
||||
driver: local
|
||||
|
||||
networks:
|
||||
bangui-net:
|
||||
driver: bridge
|
||||
@@ -1,97 +0,0 @@
|
||||
version: '3.8'
|
||||
services:
|
||||
fail2ban:
|
||||
image: lscr.io/linuxserver/fail2ban:latest
|
||||
container_name: fail2ban
|
||||
cap_add:
|
||||
- NET_ADMIN
|
||||
- NET_RAW
|
||||
network_mode: host
|
||||
environment:
|
||||
- PUID=1011
|
||||
- PGID=1001
|
||||
- TZ=Etc/UTC
|
||||
- VERBOSITY=-vv #optional
|
||||
|
||||
volumes:
|
||||
- /server/server_fail2ban/config:/config
|
||||
- /server/server_fail2ban/fail2ban-run:/var/run/fail2ban
|
||||
- /var/log:/var/log
|
||||
- /server/server_nextcloud/config/nextcloud.log:/remotelogs/nextcloud/nextcloud.log:ro #optional
|
||||
- /server/server_nginx/data/logs:/remotelogs/nginx:ro #optional
|
||||
- /server/server_gitea/log/gitea.log:/remotelogs/gitea/gitea.log:ro #optional
|
||||
|
||||
|
||||
#- /path/to/homeassistant/log:/remotelogs/homeassistant:ro #optional
|
||||
#- /path/to/unificontroller/log:/remotelogs/unificontroller:ro #optional
|
||||
#- /path/to/vaultwarden/log:/remotelogs/vaultwarden:ro #optional
|
||||
restart: unless-stopped
|
||||
deploy:
|
||||
limits:
|
||||
cpus: '0.5'
|
||||
memory: 128M
|
||||
reservations:
|
||||
cpus: '0.1'
|
||||
memory: 64M
|
||||
|
||||
backend:
|
||||
image: git.lpl-mind.de/lukas.pupkalipinski/bangui/backend:latest
|
||||
container_name: bangui-backend
|
||||
restart: unless-stopped
|
||||
stop_grace_period: 30s
|
||||
depends_on:
|
||||
fail2ban:
|
||||
condition: service_started
|
||||
environment:
|
||||
- PUID=1011
|
||||
- PGID=1001
|
||||
- BANGUI_DATABASE_PATH=/data/bangui.db
|
||||
- BANGUI_FAIL2BAN_SOCKET=/var/run/fail2ban/fail2ban.sock
|
||||
- BANGUI_FAIL2BAN_CONFIG_DIR=/config/fail2ban
|
||||
- BANGUI_LOG_LEVEL=info
|
||||
# ⚠️ BANGUI_WORKERS MUST be 1 — the session cache is process-local
|
||||
# Multiple workers would cause random logouts and duplicate background jobs
|
||||
- BANGUI_SESSION_SECRET=${BANGUI_SESSION_SECRET:?Set BANGUI_SESSION_SECRET}
|
||||
- BANGUI_TIMEZONE=${BANGUI_TIMEZONE:-UTC}
|
||||
volumes:
|
||||
- /server/server_fail2ban/bangui-data:/data
|
||||
- /server/server_fail2ban/fail2ban-run:/var/run/fail2ban:ro
|
||||
- /server/server_fail2ban/config:/config:rw
|
||||
expose:
|
||||
- "8000"
|
||||
networks:
|
||||
- bangui-net
|
||||
deploy:
|
||||
limits:
|
||||
cpus: '2'
|
||||
memory: 512M
|
||||
reservations:
|
||||
cpus: '1'
|
||||
memory: 256M
|
||||
|
||||
# ── Frontend (nginx serving built SPA + API proxy) ──────────
|
||||
frontend:
|
||||
image: git.lpl-mind.de/lukas.pupkalipinski/bangui/frontend:latest
|
||||
container_name: bangui-frontend
|
||||
restart: unless-stopped
|
||||
environment:
|
||||
- PUID=1011
|
||||
- PGID=1001
|
||||
ports:
|
||||
- "${BANGUI_PORT:-8080}:80"
|
||||
depends_on:
|
||||
backend:
|
||||
condition: service_started
|
||||
networks:
|
||||
- bangui-net
|
||||
deploy:
|
||||
limits:
|
||||
cpus: '0.5'
|
||||
memory: 128M
|
||||
reservations:
|
||||
cpus: '0.25'
|
||||
memory: 64M
|
||||
|
||||
networks:
|
||||
bangui-net:
|
||||
name: bangui-net
|
||||
@@ -6,7 +6,7 @@
|
||||
# ./release.sh
|
||||
#
|
||||
# The current version is stored in VERSION (next to this script).
|
||||
# You will be asked whether to bump major, minor, or patch.
|
||||
# You will be asked whether to bump major, minor, patch, or release candidate (rc).
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
@@ -24,24 +24,60 @@ CURRENT="$(cat "${VERSION_FILE}")"
|
||||
# Strip leading 'v' for arithmetic
|
||||
VERSION="${CURRENT#v}"
|
||||
|
||||
IFS='.' read -r MAJOR MINOR PATCH <<< "${VERSION}"
|
||||
# Parse version: X.Y.Z or X.Y.Z-rc.N
|
||||
if [[ "${VERSION}" =~ ^([0-9]+)\.([0-9]+)\.([0-9]+)(-rc\.([0-9]+))?$ ]]; then
|
||||
MAJOR="${BASH_REMATCH[1]}"
|
||||
MINOR="${BASH_REMATCH[2]}"
|
||||
PATCH="${BASH_REMATCH[3]}"
|
||||
RC_SUFFIX="${BASH_REMATCH[4]:-}"
|
||||
RC_NUM="${BASH_REMATCH[5]:-0}"
|
||||
else
|
||||
echo "Error: version '${VERSION}' does not match expected format X.Y.Z or X.Y.Z-rc.N" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "============================================"
|
||||
echo " BanGUI — Release"
|
||||
echo " Current version: v${MAJOR}.${MINOR}.${PATCH}"
|
||||
if [[ -n "${RC_SUFFIX}" ]]; then
|
||||
echo " Current version: v${MAJOR}.${MINOR}.${PATCH}-rc.${RC_NUM}"
|
||||
else
|
||||
echo " Current version: v${MAJOR}.${MINOR}.${PATCH}"
|
||||
fi
|
||||
echo "============================================"
|
||||
echo ""
|
||||
echo "How would you like to bump the version?"
|
||||
echo " 1) patch (v${MAJOR}.${MINOR}.${PATCH} → v${MAJOR}.${MINOR}.$((PATCH + 1)))"
|
||||
echo " 2) minor (v${MAJOR}.${MINOR}.${PATCH} → v${MAJOR}.$((MINOR + 1)).0)"
|
||||
echo " 3) major (v${MAJOR}.${MINOR}.${PATCH} → v$((MAJOR + 1)).0.0)"
|
||||
if [[ -n "${RC_SUFFIX}" ]]; then
|
||||
echo " 1) patch (v${MAJOR}.${MINOR}.${PATCH}-rc.${RC_NUM} → v${MAJOR}.${MINOR}.${PATCH})"
|
||||
echo " 2) minor (v${MAJOR}.${MINOR}.${PATCH}-rc.${RC_NUM} → v${MAJOR}.$((MINOR + 1)).0)"
|
||||
echo " 3) major (v${MAJOR}.${MINOR}.${PATCH}-rc.${RC_NUM} → v$((MAJOR + 1)).0.0)"
|
||||
echo " 4) rc (v${MAJOR}.${MINOR}.${PATCH}-rc.${RC_NUM} → v${MAJOR}.${MINOR}.${PATCH}-rc.$((RC_NUM + 1)))"
|
||||
else
|
||||
echo " 1) patch (v${MAJOR}.${MINOR}.${PATCH} → v${MAJOR}.${MINOR}.$((PATCH + 1)))"
|
||||
echo " 2) minor (v${MAJOR}.${MINOR}.${PATCH} → v${MAJOR}.$((MINOR + 1)).0)"
|
||||
echo " 3) major (v${MAJOR}.${MINOR}.${PATCH} → v$((MAJOR + 1)).0.0)"
|
||||
echo " 4) rc (v${MAJOR}.${MINOR}.${PATCH} → v${MAJOR}.${MINOR}.${PATCH}-rc.1)"
|
||||
fi
|
||||
echo ""
|
||||
read -rp "Enter choice [1/2/3]: " CHOICE
|
||||
read -rp "Enter choice [1/2/3/4]: " CHOICE
|
||||
|
||||
case "${CHOICE}" in
|
||||
1) NEW_TAG="v${MAJOR}.${MINOR}.$((PATCH + 1))" ;;
|
||||
1)
|
||||
if [[ -n "${RC_SUFFIX}" ]]; then
|
||||
# Release the RC: strip RC suffix
|
||||
NEW_TAG="v${MAJOR}.${MINOR}.${PATCH}"
|
||||
else
|
||||
NEW_TAG="v${MAJOR}.${MINOR}.$((PATCH + 1))"
|
||||
fi
|
||||
;;
|
||||
2) NEW_TAG="v${MAJOR}.$((MINOR + 1)).0" ;;
|
||||
3) NEW_TAG="v$((MAJOR + 1)).0.0" ;;
|
||||
4)
|
||||
if [[ "${RC_NUM}" -gt 0 ]]; then
|
||||
NEW_TAG="v${MAJOR}.${MINOR}.${PATCH}-rc.$((RC_NUM + 1))"
|
||||
else
|
||||
NEW_TAG="v${MAJOR}.${MINOR}.${PATCH}-rc.1"
|
||||
fi
|
||||
;;
|
||||
*)
|
||||
echo "Invalid choice. Aborting." >&2
|
||||
exit 1
|
||||
@@ -81,7 +117,13 @@ fi
|
||||
# Push containers
|
||||
# ---------------------------------------------------------------------------
|
||||
bash "${SCRIPT_DIR}/push.sh" "${NEW_TAG}"
|
||||
bash "${SCRIPT_DIR}/push.sh"
|
||||
|
||||
# Push to "latest" or "latestRC" depending on whether this is a release candidate
|
||||
if [[ "${NEW_TAG}" == *-rc* ]]; then
|
||||
bash "${SCRIPT_DIR}/push.sh" "latestRC"
|
||||
else
|
||||
bash "${SCRIPT_DIR}/push.sh" "latest"
|
||||
fi
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -11,7 +11,7 @@
|
||||
# Defaults:
|
||||
# COUNT : 5
|
||||
# SOURCE_IP: 192.168.100.99
|
||||
# LOG_FILE : Docker/logs/auth.log (relative to repo root)
|
||||
# LOG_FILE : data/log/auth.log (relative to repo root)
|
||||
#
|
||||
# Log line format (must match manual-Jail failregex exactly):
|
||||
# YYYY-MM-DD HH:MM:SS bangui-auth: authentication failure from <IP>
|
||||
@@ -25,7 +25,7 @@ readonly DEFAULT_IP="192.168.100.99"
|
||||
|
||||
# Resolve script location so defaults work regardless of cwd.
|
||||
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||
readonly DEFAULT_LOG_FILE="${SCRIPT_DIR}/logs/auth.log"
|
||||
readonly DEFAULT_LOG_FILE="${SCRIPT_DIR}/../data/log/auth.log"
|
||||
|
||||
# ── Arguments ─────────────────────────────────────────────────
|
||||
COUNT="${1:-${DEFAULT_COUNT}}"
|
||||
|
||||
@@ -1238,8 +1238,6 @@ The `setup_completed = "1"` key is still written for backward compatibility with
|
||||
- **GeoCache** — `GeoCache` instance is created at startup with a configurable `allow_http_fallback` flag and stored on `app.state.geo_cache`. It implements a primary + fallback resolution strategy: (1) try local MaxMind GeoLite2-Country MMDB database (primary, encrypted, no network traffic), (2) if unavailable/no result and allowed, fall back to ip-api.com HTTP API (unencrypted, disabled by default for security). Encapsulates in-memory lookup cache, negative cache for unresolvable IPs (5-minute TTL), dirty set for persistence, and thread-safe async locking. Cache is loaded from the `geo_cache` SQLite table on startup. New resolutions are accumulated in memory and periodically flushed to the database by the `geo_cache_flush` background task. Stale entries are re-resolved by the `geo_re_resolve` task. Injected into routes and tasks via FastAPI's dependency system. See Backend-Development.md § IP Geolocation Resolution for setup and security details.
|
||||
- **Runtime state** (`RuntimeState` in `app.utils.runtime_state`) — stores mutable application state: `server_status` (fail2ban online/offline), `last_activation` (jail activation tracking), `pending_recovery` (crash detection), `runtime_settings` (effective configuration), and service-specific state holders like `jail_service_state` (`JailServiceState` for jail capability detection cache). RuntimeState fields are managed through dedicated functions (e.g., `record_activation()`, `clear_pending_recovery()`) and via dependency injection to services. Service-specific state (like `JailServiceState`) is nested within `RuntimeState` to keep all mutable state in one controlled location. **⚠️ RuntimeState is process-local and only safe when BanGUI runs as a single asyncio worker.** Mutations must not span `await` points (cooperative scheduling within a single event loop is safe). In multi-worker deployments, each process has its own copy — logouts from worker A don't affect worker B's cache, health status updates are per-worker, and activation tracking is unreliable. BanGUI enforces single-worker mode (TASK-002) to prevent this issue. For future multi-worker support, replace RuntimeState with a shared coordination backend (Redis, shared memory, database). See `app/utils/runtime_state.py` module docstring for details.
|
||||
- **Setup-completion flag** — once `is_setup_complete()` returns `True`, the result is stored in `app.state._setup_complete_cached`. The `SetupRedirectMiddleware` skips the DB query on all subsequent requests, removing 1 SQL query per request for the common post-setup case. The completion flag is only written after the runtime database is successfully initialized and all initial setup settings are persisted, preventing a failed setup from permanently bypassing the setup wizard.
|
||||
- **Login Rate Limiting** — the `/api/auth/login` endpoint employs exponential backoff to defend against brute-force attacks. Each failed login attempt is recorded per client IP, and subsequent attempts within the backoff window return HTTP 429 Too Many Requests. The penalty grows exponentially with each consecutive failure (2s, 4s, 8s, up to 10s max), ensuring attackers face rapidly increasing delays. This is complemented by bcrypt password hashing (≈100ms per attempt), which adds computational resistance without blocking legitimate users. The backoff counter resets after 60 seconds without additional failures. The rate limiter is process-local and tracks failures in memory via `app.utils.rate_limiter.RateLimiter`, stored on `app.state.login_rate_limiter`. Client IP detection respects proxy headers (`X-Forwarded-For`, `X-Real-IP`) only from configured trusted proxies, preventing header spoofing attacks. In multi-worker deployments, each worker has independent rate limit counters; BanGUI enforces single-worker mode (TASK-002) to prevent attackers from bypassing limits by distributing requests across workers.
|
||||
|
||||
### 8.1 CSRF Protection
|
||||
|
||||
State-mutating endpoints (POST, PUT, DELETE, PATCH) that use cookie-based authentication are protected against Cross-Site Request Forgery (CSRF) attacks via a **custom header check middleware**.
|
||||
|
||||
@@ -1665,6 +1665,37 @@ async def get_jail(...) -> JailDetailResponse:
|
||||
|
||||
---
|
||||
|
||||
### 7.7 Third-Party Library Log Levels
|
||||
|
||||
Application code must use **structlog** for all logging. Third-party libraries that emit logs through Python's standard `logging` module are configured centrally in `backend/app/main.py::_configure_logging()`.
|
||||
|
||||
**Current overrides:**
|
||||
|
||||
| Library | Logger | Level | Reason |
|
||||
|---------|--------|-------|--------|
|
||||
| APScheduler | `apscheduler` | `WARNING` | Routine scheduler polling is too verbose at DEBUG. |
|
||||
| aiosqlite | `aiosqlite` | `WARNING` | Database operation traces clutter logs. |
|
||||
|
||||
**Adding a new override:**
|
||||
|
||||
```python
|
||||
# In backend/app/main.py, inside _configure_logging()
|
||||
logging.getLogger("new_library").setLevel(logging.WARNING)
|
||||
```
|
||||
|
||||
- Prefer `WARNING` over `ERROR` so legitimate warnings (e.g., connection retries) are still visible.
|
||||
- Place the override immediately after `logging.basicConfig()` so it takes effect before any library initializes its own loggers.
|
||||
|
||||
**Disabling suppression:**
|
||||
|
||||
Set `BANGUI_SUPPRESS_THIRD_PARTY_LOGS=false` to allow APScheduler and aiosqlite to emit their normal DEBUG/INFO logs. This is useful when troubleshooting scheduler or database issues in development.
|
||||
|
||||
**Stdlib interception:**
|
||||
|
||||
All stdlib logs are intercepted by `structlog.stdlib.ProcessorFormatter` and rendered as JSON. Even third-party library logs therefore appear as structured JSON in `bangui.log`, not plain text.
|
||||
|
||||
---
|
||||
|
||||
## 8. Error Handling
|
||||
|
||||
- Define **custom exception classes** for domain errors (e.g., `JailNotFoundError`, `BanFailedError`).
|
||||
@@ -2771,41 +2802,6 @@ update = GlobalConfigUpdate(log_target="/etc/passwd") # Raises ValidationError
|
||||
await config_service.update_global_config(socket_path, update) # Validates again before sending to fail2ban
|
||||
```
|
||||
|
||||
### Login Rate Limiting
|
||||
|
||||
The login endpoint (`POST /api/auth/login`) is protected against brute-force attacks using an in-memory exponential backoff rate limiter.
|
||||
|
||||
**Design:**
|
||||
- Uses a `dict[str, deque[float]]` keyed by client IP, storing failed login timestamps within a time window.
|
||||
- Old failures outside the time window are automatically pruned during validation checks.
|
||||
- Expired IP entries are cleaned up to prevent unbounded memory growth.
|
||||
|
||||
**Rate Limit Rules:**
|
||||
- **Exponential backoff:** Each failed login attempt incurs a progressively longer delay before the next attempt is allowed:
|
||||
- 1st failure: 1 × 2¹ = 2 seconds
|
||||
- 2nd failure: 1 × 2² = 4 seconds
|
||||
- 3rd failure: 1 × 2³ = 8 seconds
|
||||
- 4th+ failures: capped at 10 seconds (max)
|
||||
- Failed attempts that arrive during the backoff period return **HTTP 429 Too Many Requests** with a `Retry-After` header indicating the remaining wait time.
|
||||
- Each failed login is also accompanied by bcrypt password hashing (~100ms), providing additional computational resistance.
|
||||
- The backoff counter resets after the rate-limit window (60 seconds by default) expires with no new failures.
|
||||
|
||||
**IP Extraction (Proxy Safety):**
|
||||
- When behind nginx, the rate limiter reads the real client IP from `X-Forwarded-For` or `X-Real-IP` headers.
|
||||
- Only trusts these headers when the immediate connection source is in a configured trusted proxy list.
|
||||
- Prevents attackers from spoofing these headers to bypass rate limits.
|
||||
- Falls back to the direct connection IP when proxy headers cannot be trusted.
|
||||
|
||||
**Process-Local Limitation:**
|
||||
- The rate limiter is process-local (in-memory). In multi-worker deployments (e.g., Gunicorn with 4 workers), each worker maintains its own rate limit counter.
|
||||
- This is acceptable because the single-worker constraint is enforced elsewhere. See [TASK-002/003 notes](Instructions.md) for details.
|
||||
|
||||
**Implementation:**
|
||||
- Rate limiter: `app.utils.rate_limiter.RateLimiter`
|
||||
- IP extraction: `app.utils.client_ip.get_client_ip()`
|
||||
- Dependency: `LoginRateLimiterDep` in `app.dependencies`
|
||||
|
||||
|
||||
### Global Rate Limiting
|
||||
|
||||
In addition to login-specific rate limiting, all API endpoints are protected by global per-IP rate limiting to prevent resource exhaustion, CPU spikes, and network bandwidth attacks from malicious or misconfigured clients.
|
||||
|
||||
@@ -98,6 +98,44 @@ log.error("fail2ban_start_failed", stdout=stdout_raw, stderr=stderr_raw) # Neve
|
||||
|
||||
---
|
||||
|
||||
## Third-Party Library Logs
|
||||
|
||||
BanGUI uses **structlog** for all application logs, but third-party libraries often emit plain text through Python's standard `logging` module. To maintain uniform JSON output and reduce noise, the following libraries have their log levels overridden to `WARNING`:
|
||||
|
||||
| Library | Logger Name | Level | Rationale |
|
||||
|---------|-------------|-------|-----------|
|
||||
| APScheduler | `apscheduler` | `WARNING` | Suppresses routine scheduler polling ("Looking for jobs to run", "Next wakeup is due at...") while preserving job failure warnings. |
|
||||
| aiosqlite | `aiosqlite` | `WARNING` | Suppresses database operation traces and connection details while preserving connection errors. |
|
||||
|
||||
These overrides are applied in `backend/app/main.py::_configure_logging()` immediately after `logging.basicConfig()`.
|
||||
|
||||
### Disabling Suppression
|
||||
|
||||
Set the environment variable `BANGUI_SUPPRESS_THIRD_PARTY_LOGS=false` to allow APScheduler and aiosqlite to emit their normal DEBUG/INFO logs. This is useful when troubleshooting scheduler or database issues in development.
|
||||
|
||||
```bash
|
||||
BANGUI_SUPPRESS_THIRD_PARTY_LOGS=false python -m uvicorn app.main:create_app
|
||||
```
|
||||
|
||||
When suppression is disabled, the loggers inherit the application's `BANGUI_LOG_LEVEL` (e.g., `debug`).
|
||||
|
||||
### Uniform JSON Formatting
|
||||
|
||||
All stdlib logs — including those from third-party libraries — are intercepted by `structlog.stdlib.ProcessorFormatter` and rendered as JSON. This ensures every log line in `bangui.log` is machine-readable, regardless of its source.
|
||||
|
||||
### Adding New Overrides
|
||||
|
||||
When integrating a new library that emits verbose DEBUG logs:
|
||||
|
||||
```python
|
||||
# In backend/app/main.py, inside _configure_logging()
|
||||
logging.getLogger("new_library").setLevel(logging.WARNING)
|
||||
```
|
||||
|
||||
Use `WARNING` as the default to still capture errors and warnings. Only use `ERROR` if the library is exceptionally noisy and its warnings are not actionable.
|
||||
|
||||
---
|
||||
|
||||
## Structured Logging Best Practices
|
||||
|
||||
### Log Levels
|
||||
|
||||
@@ -418,6 +418,65 @@ Then set it in your `.env` file or environment variables.
|
||||
|
||||
---
|
||||
|
||||
## Enabling Debug Logs for Third-Party Libraries
|
||||
|
||||
BanGUI suppresses verbose DEBUG logs from APScheduler and aiosqlite by default (see `Docs/Observability.md`). When troubleshooting scheduler or database issues, you can temporarily re-enable these logs.
|
||||
|
||||
### Quick method (environment variable)
|
||||
|
||||
Set `BANGUI_SUPPRESS_THIRD_PARTY_LOGS=false` and ensure `BANGUI_LOG_LEVEL=debug`:
|
||||
|
||||
```bash
|
||||
BANGUI_SUPPRESS_THIRD_PARTY_LOGS=false \
|
||||
BANGUI_LOG_LEVEL=debug \
|
||||
python -m uvicorn app.main:create_app
|
||||
```
|
||||
|
||||
This allows APScheduler and aiosqlite to inherit the application log level without editing code.
|
||||
|
||||
### Code method (for permanent changes)
|
||||
|
||||
If you need to change the level for a specific library only, edit `backend/app/main.py` inside `_configure_logging()`:
|
||||
|
||||
```python
|
||||
logging.getLogger("apscheduler").setLevel(logging.DEBUG)
|
||||
```
|
||||
|
||||
Restart the application. You will see scheduler polling messages such as:
|
||||
- `Looking for jobs to run`
|
||||
- `Next wakeup is due at ...`
|
||||
- `Running job ...`
|
||||
|
||||
### Reverting
|
||||
|
||||
Remove the environment variable or code change and restart. When suppression is re-enabled, the loggers return to `WARNING` level.
|
||||
|
||||
---
|
||||
|
||||
## Plain Text Logs Still Appearing
|
||||
|
||||
If `bangui.log` contains plain text lines that are not JSON, a library is bypassing structlog's `ProcessorFormatter`.
|
||||
|
||||
**Diagnosis:**
|
||||
|
||||
1. Identify the logger name in the plain text line (usually at the start of the line).
|
||||
2. Check whether the logger is listed in `backend/app/main.py::_configure_logging()` under the third-party overrides.
|
||||
3. Verify that `structlog.stdlib.ProcessorFormatter` is attached to all handlers:
|
||||
```python
|
||||
for handler in handlers:
|
||||
handler.setFormatter(formatter)
|
||||
```
|
||||
|
||||
**Common causes:**
|
||||
|
||||
| Cause | Fix |
|
||||
|-------|-----|
|
||||
| Library initializes its own handler after startup | Add `logging.getLogger("library_name").setLevel(logging.WARNING)` in `_configure_logging()`. |
|
||||
| Custom handler added outside `_configure_logging()` | Ensure all handlers use `structlog.stdlib.ProcessorFormatter`. |
|
||||
| Log emitted before `_configure_logging()` is called | Move logging configuration earlier in the lifespan or app factory. |
|
||||
|
||||
---
|
||||
|
||||
## Getting Help
|
||||
|
||||
If issues persist after following this guide:
|
||||
|
||||
@@ -102,7 +102,7 @@ for (int i = 0; i < items.Count; i++)
|
||||
|
||||
// Step 1 — run the task prompt
|
||||
await RunCopilot(Enumerable.Empty<string>(), $"/caveman full");
|
||||
await RunCopilot(new[] { "--continue" }, $"read ./Docs/Instructions.md. {item}");
|
||||
await RunCopilot(new[] { "--continue" }, $"read ./Docs/Instructions.md. fix the following test and only that one. Keep in mind that i did many refactorings and test may is obsolet or need to be changed. {item}");
|
||||
if (cts.IsCancellationRequested) break;
|
||||
|
||||
// Step 2 — confirm completion in the same chat session
|
||||
|
||||
19
Makefile
19
Makefile
@@ -64,11 +64,11 @@ print('Created .env with a generated BANGUI_SESSION_SECRET.')"; \
|
||||
|
||||
## Start the debug stack (detached).
|
||||
## Ensures log stub files exist so fail2ban can open them on first start.
|
||||
## All output is logged to Docker/logs/make-up.log.
|
||||
## All output is logged to /data/log/make-up.log.
|
||||
up: ensure-env
|
||||
@mkdir -p Docker/logs
|
||||
@touch Docker/logs/auth.log
|
||||
$(COMPOSE) $(COMPOSE_OPTS) up -d 2>&1 | tee Docker/logs/make-up.log
|
||||
@mkdir -p data/log
|
||||
@touch data/log/auth.log
|
||||
$(COMPOSE) $(COMPOSE_OPTS) up -d 2>&1 | tee data/log/make-up.log
|
||||
|
||||
## Stop the debug stack.
|
||||
down: ensure-env
|
||||
@@ -91,20 +91,23 @@ clean: ensure-env
|
||||
$(COMPOSE) $(COMPOSE_OPTS) down --remove-orphans
|
||||
$(RUNTIME) volume rm $(DEV_VOLUMES) 2>/dev/null || true
|
||||
$(RUNTIME) rmi $(DEV_IMAGES) 2>/dev/null || true
|
||||
@echo "All debug volumes and local images removed. Run 'make up' to rebuild and start fresh."
|
||||
rm -rf ./data
|
||||
@echo "All debug volumes, local images, and ./data removed. Run 'make up' to rebuild and start fresh."
|
||||
|
||||
## Run the Robot Framework E2E test suite.
|
||||
## Requires: stack up (make up), BANGUI_SESSION_SECRET env var set.
|
||||
## Installs: pip install -r e2e/requirements.txt && rfbrowser init
|
||||
e2e: up
|
||||
e2e: down clean up
|
||||
@echo "Waiting 2 minutes for services to initialize..."
|
||||
@sleep 120
|
||||
@echo "Waiting for stack to be healthy..."
|
||||
@timeout=120; \
|
||||
until curl -sf http://localhost:8000/api/health > /dev/null 2>&1; do \
|
||||
until curl -sf http://localhost:8000/api/v1/health > /dev/null 2>&1; do \
|
||||
sleep 5; timeout=$$((timeout-5)); \
|
||||
if [ $$timeout -le 0 ]; then echo "Backend not healthy after 120s"; exit 1; fi; \
|
||||
done
|
||||
pip install -r e2e/requirements.txt -q
|
||||
rfbrowser init --quiet
|
||||
rfbrowser init
|
||||
robot --outputdir e2e/results e2e/tests/
|
||||
|
||||
## One-command smoke test for the ban pipeline:
|
||||
|
||||
@@ -285,6 +285,17 @@ class Settings(BaseSettings):
|
||||
default="info",
|
||||
description="Application log level: debug | info | warning | error | critical.",
|
||||
)
|
||||
log_file: str | None = Field(
|
||||
default="/data/log/bangui.log",
|
||||
description="Optional file path for writing application logs. Set to null to disable file logging.",
|
||||
)
|
||||
suppress_third_party_logs: bool = Field(
|
||||
default=True,
|
||||
description=(
|
||||
"When true, sets APScheduler and aiosqlite loggers to WARNING level. "
|
||||
"Set to false to allow third-party libraries to emit DEBUG/INFO logs."
|
||||
),
|
||||
)
|
||||
geoip_db_path: str | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
@@ -596,6 +607,7 @@ class Settings(BaseSettings):
|
||||
env_file=".env",
|
||||
env_file_encoding="utf-8",
|
||||
case_sensitive=False,
|
||||
extra="ignore",
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -14,9 +14,10 @@ from __future__ import annotations
|
||||
from pathlib import Path
|
||||
|
||||
import aiosqlite
|
||||
import structlog
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
log = get_logger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# DDL statements
|
||||
@@ -246,7 +247,6 @@ CREATE INDEX IF NOT EXISTS idx_import_log_source_id_desc
|
||||
}
|
||||
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public API
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -254,6 +254,7 @@ CREATE INDEX IF NOT EXISTS idx_import_log_source_id_desc
|
||||
|
||||
async def _configure_connection(db: aiosqlite.Connection) -> None:
|
||||
"""Apply hardening pragmas to a newly-opened SQLite connection."""
|
||||
await db.execute("PRAGMA journal_mode=WAL;")
|
||||
await db.execute("PRAGMA foreign_keys=ON;")
|
||||
await db.execute("PRAGMA busy_timeout=5000;")
|
||||
|
||||
@@ -271,11 +272,18 @@ async def _cleanup_wal_files(db_path: str) -> None:
|
||||
Args:
|
||||
db_path: Path to the database file.
|
||||
"""
|
||||
import time
|
||||
|
||||
wal_path = Path(db_path + "-wal")
|
||||
shm_path = Path(db_path + "-shm")
|
||||
|
||||
for path in (wal_path, shm_path):
|
||||
if path.exists():
|
||||
# Skip files that were modified recently — they likely belong to an
|
||||
# active connection. Only remove stale files left by crashes.
|
||||
mtime = path.stat().st_mtime
|
||||
if time.time() - mtime < 10:
|
||||
continue
|
||||
try:
|
||||
path.unlink()
|
||||
log.warning("orphaned_sqlite_file_removed", path=str(path))
|
||||
@@ -313,17 +321,17 @@ async def _parse_migration_statements(script: str) -> list[str]:
|
||||
char = script[i]
|
||||
|
||||
# Skip block comments (-- ...)
|
||||
if i < len(script) - 1 and script[i:i+2] == "--":
|
||||
if i < len(script) - 1 and script[i : i + 2] == "--":
|
||||
while i < len(script) and script[i] != "\n":
|
||||
i += 1
|
||||
i += 1
|
||||
continue
|
||||
|
||||
# Skip line comments (/* ... */)
|
||||
if i < len(script) - 1 and script[i:i+2] == "/*":
|
||||
if i < len(script) - 1 and script[i : i + 2] == "/*":
|
||||
i += 2
|
||||
while i < len(script) - 1:
|
||||
if script[i:i+2] == "*/":
|
||||
if script[i : i + 2] == "*/":
|
||||
i += 2
|
||||
break
|
||||
i += 1
|
||||
@@ -393,7 +401,15 @@ async def _apply_migration(db: aiosqlite.Connection, version: int) -> None:
|
||||
await db.execute("BEGIN IMMEDIATE;")
|
||||
|
||||
for statement in statements:
|
||||
await db.execute(statement)
|
||||
try:
|
||||
await db.execute(statement)
|
||||
except aiosqlite.OperationalError as exc:
|
||||
# Ignore duplicate column / table errors so migrations remain
|
||||
# idempotent when a legacy database already has the object.
|
||||
msg = str(exc).lower()
|
||||
if "duplicate column name" in msg or "table" in msg and "already exists" in msg:
|
||||
continue
|
||||
raise
|
||||
|
||||
await db.execute("INSERT INTO schema_migrations (version) VALUES (?);", (version,))
|
||||
|
||||
@@ -411,8 +427,7 @@ async def _migrate_schema(db: aiosqlite.Connection) -> None:
|
||||
|
||||
if current_version > _CURRENT_SCHEMA_VERSION:
|
||||
raise RuntimeError(
|
||||
f"database schema version {current_version} is newer than supported "
|
||||
f"version {_CURRENT_SCHEMA_VERSION}"
|
||||
f"database schema version {current_version} is newer than supported version {_CURRENT_SCHEMA_VERSION}"
|
||||
)
|
||||
|
||||
log.info("migrating_database_schema", from_version=current_version, to_version=_CURRENT_SCHEMA_VERSION)
|
||||
|
||||
@@ -36,7 +36,6 @@ from typing import Annotated, cast
|
||||
|
||||
import aiohttp
|
||||
import aiosqlite
|
||||
import structlog
|
||||
from apscheduler.schedulers.asyncio import AsyncIOScheduler # type: ignore[import-untyped]
|
||||
from fastapi import Depends, FastAPI, HTTPException, Request, status
|
||||
|
||||
@@ -45,22 +44,6 @@ from app.exceptions import RateLimitError
|
||||
from app.models.auth import Session
|
||||
from app.models.config import PendingRecovery
|
||||
from app.models.server import ServerStatus
|
||||
from app.repositories.protocols import (
|
||||
BlocklistRepository,
|
||||
Fail2BanDbRepository,
|
||||
GeoCacheRepository,
|
||||
HistoryArchiveRepository,
|
||||
ImportLogRepository,
|
||||
ImportRunRepository,
|
||||
SessionRepository,
|
||||
SettingsRepository,
|
||||
)
|
||||
from app.services.geo_cache import GeoCache
|
||||
from app.services.protocols import Fail2BanMetadataService
|
||||
from app.utils.constants import SESSION_COOKIE_NAME
|
||||
from app.utils.rate_limiter import GlobalRateLimiter, RateLimiter
|
||||
from app.utils.runtime_state import ApplicationState, JailServiceState, RuntimeState
|
||||
from app.utils.session_cache import NoOpSessionCache, SessionCache
|
||||
|
||||
# Module-level imports for repositories and services
|
||||
# These are safe at module level since no circular dependencies exist
|
||||
@@ -74,10 +57,27 @@ from app.repositories import (
|
||||
session_repo,
|
||||
settings_repo,
|
||||
)
|
||||
from app.repositories.protocols import (
|
||||
BlocklistRepository,
|
||||
Fail2BanDbRepository,
|
||||
GeoCacheRepository,
|
||||
HistoryArchiveRepository,
|
||||
ImportLogRepository,
|
||||
ImportRunRepository,
|
||||
SessionRepository,
|
||||
SettingsRepository,
|
||||
)
|
||||
from app.services import auth_service, health_service
|
||||
from app.services.fail2ban_metadata_service import default_fail2ban_metadata_service
|
||||
from app.services.geo_cache import GeoCache
|
||||
from app.services.protocols import Fail2BanMetadataService
|
||||
from app.utils.constants import SESSION_COOKIE_NAME
|
||||
from app.utils.logging_compat import get_logger
|
||||
from app.utils.rate_limiter import GlobalRateLimiter
|
||||
from app.utils.runtime_state import ApplicationState, JailServiceState, RuntimeState
|
||||
from app.utils.session_cache import NoOpSessionCache, SessionCache
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -93,7 +93,6 @@ class ApplicationContext:
|
||||
runtime_settings: Settings | None
|
||||
runtime_state: RuntimeState
|
||||
session_cache: SessionCache | None
|
||||
login_rate_limiter: RateLimiter
|
||||
global_rate_limiter: GlobalRateLimiter
|
||||
|
||||
|
||||
@@ -109,6 +108,7 @@ class ApplicationContext:
|
||||
#: or distributed deployments, the configured cache backend should provide
|
||||
#: invalidation semantics appropriate for the deployment.
|
||||
|
||||
|
||||
def _session_cache_enabled(settings: Settings) -> bool:
|
||||
"""Return whether the session validation cache should be used."""
|
||||
return settings.session_cache_enabled and settings.session_cache_ttl_seconds > 0.0
|
||||
@@ -120,10 +120,6 @@ def _build_app_context(request: Request) -> ApplicationContext:
|
||||
if session_cache is None:
|
||||
session_cache = NoOpSessionCache()
|
||||
|
||||
login_rate_limiter: RateLimiter = getattr(state, "login_rate_limiter", None)
|
||||
if login_rate_limiter is None:
|
||||
login_rate_limiter = RateLimiter()
|
||||
|
||||
global_rate_limiter: GlobalRateLimiter = getattr(state, "global_rate_limiter", None)
|
||||
if global_rate_limiter is None:
|
||||
global_rate_limiter = GlobalRateLimiter()
|
||||
@@ -138,7 +134,6 @@ def _build_app_context(request: Request) -> ApplicationContext:
|
||||
runtime_settings=getattr(state, "runtime_settings", None),
|
||||
runtime_state=state.runtime_state,
|
||||
session_cache=session_cache,
|
||||
login_rate_limiter=login_rate_limiter,
|
||||
global_rate_limiter=global_rate_limiter,
|
||||
)
|
||||
|
||||
@@ -264,13 +259,6 @@ async def get_session_cache(app_context: Annotated[ApplicationContext, Depends(g
|
||||
return app_context.session_cache
|
||||
|
||||
|
||||
async def get_login_rate_limiter(
|
||||
app_context: Annotated[ApplicationContext, Depends(get_app_context)],
|
||||
) -> RateLimiter:
|
||||
"""Provide the login endpoint rate limiter from application context."""
|
||||
return app_context.login_rate_limiter
|
||||
|
||||
|
||||
async def get_global_rate_limiter(
|
||||
app_context: Annotated[ApplicationContext, Depends(get_app_context)],
|
||||
) -> GlobalRateLimiter:
|
||||
@@ -297,6 +285,7 @@ def rate_limit_dependency(
|
||||
Returns:
|
||||
A callable that can be used as a FastAPI Depends() dependency.
|
||||
"""
|
||||
|
||||
async def check_rate_limit(
|
||||
request: Request,
|
||||
rate_limiter: GlobalRateLimiterDep,
|
||||
@@ -306,9 +295,7 @@ def rate_limit_dependency(
|
||||
settings: Settings = request.app.state.settings
|
||||
client_ip = get_client_ip(request, trusted_proxies=settings.trusted_proxies)
|
||||
|
||||
is_allowed, retry_after = rate_limiter.check_allowed_for_bucket(
|
||||
bucket, client_ip, max_requests, window_seconds
|
||||
)
|
||||
is_allowed, retry_after = rate_limiter.check_allowed_for_bucket(bucket, client_ip, max_requests, window_seconds)
|
||||
|
||||
if not is_allowed:
|
||||
log.warning(
|
||||
@@ -420,6 +407,8 @@ async def get_app(request: Request) -> FastAPI:
|
||||
|
||||
async def get_server_status(app_context: Annotated[ApplicationContext, Depends(get_app_context)]) -> ServerStatus:
|
||||
"""Return the cached fail2ban server status snapshot from application context."""
|
||||
if app_context.server_status is None:
|
||||
return ServerStatus(online=False)
|
||||
return app_context.server_status
|
||||
|
||||
|
||||
@@ -667,7 +656,7 @@ async def require_auth(
|
||||
if not token:
|
||||
auth_header: str = request.headers.get("Authorization", "")
|
||||
if auth_header.startswith("Bearer "):
|
||||
token = auth_header[len("Bearer "):]
|
||||
token = auth_header[len("Bearer ") :]
|
||||
|
||||
if not token:
|
||||
raise HTTPException(
|
||||
@@ -730,7 +719,6 @@ Fail2BanDbRepositoryDep = Annotated[Fail2BanDbRepository, Depends(get_fail2ban_d
|
||||
AppStateDep = Annotated[ApplicationContext, Depends(get_app_state)]
|
||||
AppDep = Annotated[FastAPI, Depends(get_app)]
|
||||
AuthDep = Annotated[Session, Depends(require_auth)]
|
||||
LoginRateLimiterDep = Annotated[RateLimiter, Depends(get_login_rate_limiter)]
|
||||
GlobalRateLimiterDep = Annotated[GlobalRateLimiter, Depends(get_global_rate_limiter)]
|
||||
Fail2BanMetadataServiceDep = Annotated[Fail2BanMetadataService, Depends(get_fail2ban_metadata_service)]
|
||||
|
||||
|
||||
@@ -25,7 +25,6 @@ if TYPE_CHECKING:
|
||||
|
||||
from app.models.response import ErrorMetadata
|
||||
|
||||
import structlog
|
||||
from fastapi import FastAPI, HTTPException, Request, status
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
@@ -73,13 +72,15 @@ from app.utils.external_logging import (
|
||||
ExternalLogHandler,
|
||||
create_external_log_handler,
|
||||
)
|
||||
from app.utils.rate_limiter import GlobalRateLimiter, RateLimiter
|
||||
from app.utils.json_formatter import JSONFormatter
|
||||
from app.utils.logging_compat import get_logger
|
||||
from app.utils.rate_limiter import GlobalRateLimiter
|
||||
from app.utils.runtime_state import ApplicationState, RuntimeState
|
||||
from app.utils.scheduler_lock import release_scheduler_lock
|
||||
from app.utils.session_cache import InMemorySessionCache, NoOpSessionCache
|
||||
from app.utils.setup_state import is_setup_complete_cached, set_setup_complete_cache
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
log = get_logger("bangui")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -89,58 +90,67 @@ log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
_external_log_handler: ExternalLogHandler | None = None
|
||||
|
||||
|
||||
def _external_logging_processor(
|
||||
logger: logging.Logger, method_name: str, event_dict: dict[str, Any]
|
||||
) -> dict[str, Any]:
|
||||
"""Structlog processor that queues logs to external logging handler.
|
||||
def _external_logging_processor(record: logging.LogRecord) -> None:
|
||||
"""Queue log record to external logging handler.
|
||||
|
||||
Args:
|
||||
logger: The logger instance.
|
||||
method_name: The name of the method called on the logger.
|
||||
event_dict: The event dictionary from structlog.
|
||||
|
||||
Returns:
|
||||
The event dictionary unchanged (other processors handle rendering).
|
||||
record: The log record to queue.
|
||||
"""
|
||||
if _external_log_handler is not None:
|
||||
_external_log_handler.queue_log(event_dict.copy())
|
||||
return event_dict
|
||||
_external_log_handler.queue_log(
|
||||
{
|
||||
"event": record.getMessage(),
|
||||
"level": record.levelname.lower(),
|
||||
"logger": record.name,
|
||||
"timestamp": record.created,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def _configure_logging(log_level: str, settings: Settings | None = None) -> None:
|
||||
"""Configure structlog for production JSON output.
|
||||
class _ExternalLoggingHandler(logging.Handler):
|
||||
"""Handler that forwards log records to the external log handler."""
|
||||
|
||||
def emit(self, record: logging.LogRecord) -> None:
|
||||
_external_logging_processor(record)
|
||||
|
||||
|
||||
def _configure_logging(log_level: str, log_file: str | None, settings: Settings | None = None) -> None:
|
||||
"""Configure stdlib logging for production JSON output.
|
||||
|
||||
Args:
|
||||
log_level: One of ``debug``, ``info``, ``warning``, ``error``, ``critical``.
|
||||
log_file: Optional file path to write logs to (in addition to stdout).
|
||||
settings: Optional Settings object to configure external logging.
|
||||
"""
|
||||
level: int = logging.getLevelName(log_level.upper())
|
||||
logging.basicConfig(level=level, stream=sys.stdout, format="%(message)s")
|
||||
handlers: list[logging.Handler] = [logging.StreamHandler(sys.stdout)]
|
||||
if log_file:
|
||||
try:
|
||||
os.makedirs(os.path.dirname(log_file), exist_ok=True)
|
||||
handlers.append(logging.FileHandler(log_file))
|
||||
except (PermissionError, OSError) as exc:
|
||||
log.warning(
|
||||
"log_file_directory_not_created",
|
||||
log_file=log_file,
|
||||
error=str(exc),
|
||||
)
|
||||
|
||||
processors = [
|
||||
structlog.contextvars.merge_contextvars,
|
||||
structlog.stdlib.filter_by_level,
|
||||
structlog.processors.TimeStamper(fmt="iso"),
|
||||
structlog.stdlib.add_logger_name,
|
||||
structlog.stdlib.add_log_level,
|
||||
structlog.stdlib.PositionalArgumentsFormatter(),
|
||||
structlog.processors.StackInfoRenderer(),
|
||||
structlog.processors.format_exc_info,
|
||||
structlog.processors.UnicodeDecoder(),
|
||||
]
|
||||
# Suppress verbose third-party library logs that emit plain text
|
||||
# through the standard library logging module.
|
||||
if settings is None or settings.suppress_third_party_logs:
|
||||
logging.getLogger("apscheduler").setLevel(logging.WARNING)
|
||||
logging.getLogger("aiosqlite").setLevel(logging.WARNING)
|
||||
|
||||
formatter = JSONFormatter()
|
||||
for handler in handlers:
|
||||
handler.setFormatter(formatter)
|
||||
|
||||
logging.basicConfig(level=level, handlers=handlers)
|
||||
|
||||
if settings and settings.external_logging_enabled and settings.external_logging_provider:
|
||||
processors.append(_external_logging_processor)
|
||||
|
||||
processors.append(structlog.processors.JSONRenderer())
|
||||
|
||||
structlog.configure(
|
||||
processors=processors,
|
||||
wrapper_class=structlog.stdlib.BoundLogger,
|
||||
context_class=dict,
|
||||
logger_factory=structlog.stdlib.LoggerFactory(),
|
||||
cache_logger_on_first_use=True,
|
||||
)
|
||||
external_handler = _ExternalLoggingHandler()
|
||||
external_handler.setLevel(logging.DEBUG)
|
||||
logging.getLogger().addHandler(external_handler)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -160,9 +170,7 @@ def _update_session_cache(app: FastAPI, settings: Settings) -> None:
|
||||
settings: The effective application settings.
|
||||
"""
|
||||
cache_enabled = settings.session_cache_enabled and settings.session_cache_ttl_seconds > 0.0
|
||||
app.state.session_cache = (
|
||||
InMemorySessionCache() if cache_enabled else NoOpSessionCache()
|
||||
)
|
||||
app.state.session_cache = InMemorySessionCache() if cache_enabled else NoOpSessionCache()
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
@@ -225,7 +233,7 @@ async def _lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||
raise RuntimeError(msg) from exc
|
||||
|
||||
# Now configure logging with the handler in place
|
||||
_configure_logging(settings.log_level, settings)
|
||||
_configure_logging(settings.log_level, settings.log_file, settings)
|
||||
|
||||
log.info("bangui_starting_up", database_path=settings.database_path)
|
||||
|
||||
@@ -234,14 +242,9 @@ async def _lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||
# deployments, it should be replaced with a shared backend.
|
||||
_update_session_cache(app, settings)
|
||||
|
||||
# Initialize the login rate limiter (5 attempts per 60 seconds per IP).
|
||||
# This is process-local and not cluster-safe. In multi-worker deployments,
|
||||
# each worker has independent counters, limiting the blast radius of attacks.
|
||||
app.state.login_rate_limiter = RateLimiter(max_attempts=5, window_seconds=60)
|
||||
|
||||
# Initialize the global rate limiter (200 requests per 60 seconds per IP).
|
||||
# Initialize the global rate limiter (600 requests per 60 seconds per IP).
|
||||
# Applied to all endpoints via middleware. Process-local implementation.
|
||||
app.state.global_rate_limiter = GlobalRateLimiter(max_requests=200, window_seconds=60)
|
||||
app.state.global_rate_limiter = GlobalRateLimiter(max_requests=600, window_seconds=60)
|
||||
|
||||
log.info("bangui_started")
|
||||
|
||||
@@ -813,12 +816,12 @@ async def _request_validation_error_handler(
|
||||
# the guard without being explicitly allowed.
|
||||
_EXACT_ALLOWED: frozenset[str] = frozenset(
|
||||
{
|
||||
"/api/v1/setup", # GET/POST /api/v1/setup
|
||||
"/api/v1/health", # Health check endpoint (combined)
|
||||
"/api/v1/health/live", # Kubernetes liveness probe
|
||||
"/api/v1/setup", # GET/POST /api/v1/setup
|
||||
"/api/v1/health", # Health check endpoint (combined)
|
||||
"/api/v1/health/live", # Kubernetes liveness probe
|
||||
"/api/v1/health/ready", # Kubernetes readiness probe
|
||||
"/api/docs", # Swagger UI
|
||||
"/api/redoc", # ReDoc
|
||||
"/api/docs", # Swagger UI
|
||||
"/api/redoc", # ReDoc
|
||||
"/api/openapi.json", # OpenAPI schema
|
||||
},
|
||||
)
|
||||
@@ -973,9 +976,7 @@ def _enforce_single_worker() -> None:
|
||||
"See Docs/Deployment.md § Single-Worker Requirement."
|
||||
)
|
||||
except ValueError as e:
|
||||
raise RuntimeError(
|
||||
f"WEB_CONCURRENCY must be an integer, got: {web_concurrency}"
|
||||
) from e
|
||||
raise RuntimeError(f"WEB_CONCURRENCY must be an integer, got: {web_concurrency}") from e
|
||||
|
||||
# Check explicit BANGUI_WORKERS override (discouraged, still enforced)
|
||||
bangui_workers = os.environ.get("BANGUI_WORKERS")
|
||||
@@ -992,9 +993,7 @@ def _enforce_single_worker() -> None:
|
||||
"See Docs/Deployment.md § Single-Worker Requirement."
|
||||
)
|
||||
except ValueError as e:
|
||||
raise RuntimeError(
|
||||
f"BANGUI_WORKERS must be an integer, got: {bangui_workers}"
|
||||
) from e
|
||||
raise RuntimeError(f"BANGUI_WORKERS must be an integer, got: {bangui_workers}") from e
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -1096,15 +1095,10 @@ def create_app(settings: Settings | None = None) -> FastAPI:
|
||||
if resolved_settings.session_cache_enabled and resolved_settings.session_cache_ttl_seconds > 0.0
|
||||
else NoOpSessionCache()
|
||||
)
|
||||
# Initialize the login rate limiter (5 attempts per 60 seconds per IP).
|
||||
# Initialize the global rate limiter (600 requests per 60 seconds per IP).
|
||||
# This is also re-initialized in the lifespan, but must be present here
|
||||
# for tests that bypass the lifespan via ASGITransport.
|
||||
app.state.login_rate_limiter = RateLimiter(max_attempts=5, window_seconds=60)
|
||||
|
||||
# Initialize the global rate limiter (200 requests per 60 seconds per IP).
|
||||
# This is also re-initialized in the lifespan, but must be present here
|
||||
# for tests that bypass the lifespan via ASGITransport.
|
||||
app.state.global_rate_limiter = GlobalRateLimiter(max_requests=200, window_seconds=60)
|
||||
app.state.global_rate_limiter = GlobalRateLimiter(max_requests=600, window_seconds=60)
|
||||
|
||||
set_setup_complete_cache(app, False)
|
||||
|
||||
@@ -1140,10 +1134,52 @@ def create_app(settings: Settings | None = None) -> FastAPI:
|
||||
app.add_middleware(MetricsMiddleware)
|
||||
app.add_middleware(CsrfMiddleware)
|
||||
app.add_middleware(DeprecationHeaderMiddleware)
|
||||
# Auth endpoints (login, setup) need a dedicated higher-rate bucket to avoid
|
||||
# rate limiting when running e2e tests sequentially.
|
||||
# 1000 req/min per IP — generous for e2e testing.
|
||||
app.add_middleware(
|
||||
RateLimitMiddleware,
|
||||
rate_limiter=app.state.global_rate_limiter,
|
||||
settings=resolved_settings,
|
||||
bucket_override="auth:login",
|
||||
bucket_max_requests=1000,
|
||||
bucket_window_seconds=60,
|
||||
path_prefixes=["/api/v1/auth/login", "/api/v1/setup"],
|
||||
)
|
||||
|
||||
# History endpoints get a dedicated higher-rate bucket to avoid
|
||||
# triggering rate limits when the UI page makes multiple simultaneous
|
||||
# API calls (session validation + history + dashboard stats).
|
||||
# 10000 req/min per IP — generous for normal browsing + e2e testing.
|
||||
app.add_middleware(
|
||||
RateLimitMiddleware,
|
||||
rate_limiter=app.state.global_rate_limiter,
|
||||
settings=resolved_settings,
|
||||
bucket_override="history:list",
|
||||
bucket_max_requests=10000,
|
||||
bucket_window_seconds=60,
|
||||
path_prefixes=["/api/v1/history"],
|
||||
)
|
||||
|
||||
# Polling endpoints (blocklist schedule) get a dedicated bucket
|
||||
# to avoid exhausting the global limit during normal frontend operation.
|
||||
app.add_middleware(
|
||||
RateLimitMiddleware,
|
||||
rate_limiter=app.state.global_rate_limiter,
|
||||
settings=resolved_settings,
|
||||
bucket_override="polling:read",
|
||||
bucket_max_requests=10000,
|
||||
bucket_window_seconds=60,
|
||||
path_prefixes=["/api/v1/blocklists/schedule"],
|
||||
)
|
||||
|
||||
# Global rate limiter for all other endpoints.
|
||||
# 600 req/min per IP — default protection.
|
||||
app.add_middleware(
|
||||
RateLimitMiddleware,
|
||||
rate_limiter=app.state.global_rate_limiter,
|
||||
settings=resolved_settings,
|
||||
skip_paths=["/api/v1/auth/login", "/api/v1/setup", "/api/v1/history", "/api/v1/blocklists/schedule"],
|
||||
)
|
||||
|
||||
# Validate middleware order before returning the app.
|
||||
@@ -1151,7 +1187,6 @@ def create_app(settings: Settings | None = None) -> FastAPI:
|
||||
# stack is a security-critical defect that must not slip through silently.
|
||||
_assert_middleware_order(app)
|
||||
|
||||
|
||||
# --- Exception handlers ---
|
||||
#
|
||||
# Exception handlers are registered from most specific to least specific. FastAPI evaluates
|
||||
|
||||
@@ -10,13 +10,11 @@ from __future__ import annotations
|
||||
|
||||
from app.models.config import (
|
||||
BantimeEscalation,
|
||||
Fail2BanLogResponse,
|
||||
FilterConfig,
|
||||
FilterListResponse,
|
||||
GlobalConfigResponse,
|
||||
JailConfig,
|
||||
JailConfigListResponse,
|
||||
LogPreviewResponse,
|
||||
MapColorThresholdsResponse,
|
||||
RegexTestResponse,
|
||||
ServiceStatusResponse,
|
||||
@@ -32,7 +30,6 @@ from app.models.config_domain import (
|
||||
DomainRegexTest,
|
||||
DomainServiceStatus,
|
||||
)
|
||||
from app.utils.pagination import create_pagination_metadata
|
||||
|
||||
|
||||
def _map_domain_bantime_escalation(domain: DomainBantimeEscalation) -> BantimeEscalation:
|
||||
@@ -65,9 +62,7 @@ def map_domain_jail_config_to_response(domain: DomainJailConfig) -> JailConfig:
|
||||
prefregex=domain.prefregex,
|
||||
actions=domain.actions,
|
||||
bantime_escalation=(
|
||||
_map_domain_bantime_escalation(domain.bantime_escalation)
|
||||
if domain.bantime_escalation
|
||||
else None
|
||||
_map_domain_bantime_escalation(domain.bantime_escalation) if domain.bantime_escalation else None
|
||||
),
|
||||
)
|
||||
|
||||
@@ -151,6 +146,6 @@ def map_domain_filter_config_to_response(domain: DomainFilterConfig) -> FilterCo
|
||||
def map_domain_filter_list_to_response(domain_list: DomainFilterList) -> FilterListResponse:
|
||||
"""Convert domain filter list to response model."""
|
||||
return FilterListResponse(
|
||||
items=[map_domain_filter_config_to_response(f) for f in domain_list.items],
|
||||
filters=[map_domain_filter_config_to_response(f) for f in domain_list.items],
|
||||
total=domain_list.total,
|
||||
)
|
||||
|
||||
@@ -1,16 +1,15 @@
|
||||
"""Correlation ID middleware for distributed tracing.
|
||||
|
||||
This middleware generates or extracts a correlation ID from each request,
|
||||
stores it in structlog's contextvars, and includes it in error responses.
|
||||
stores it in request state, and includes it in error responses.
|
||||
This enables correlating logs across frontend and backend for a single
|
||||
user action or request flow.
|
||||
|
||||
Correlation IDs flow through the request lifecycle:
|
||||
1. Frontend generates/passes via `X-Correlation-ID` header
|
||||
2. Middleware extracts or generates a UUID4
|
||||
3. Middleware stores in structlog.contextvars
|
||||
4. All log entries include the correlation ID automatically
|
||||
5. Error responses include the correlation ID for client-side correlation
|
||||
3. Stores on request.state for use by error handlers and log filters
|
||||
4. Error responses include the correlation ID for client-side correlation
|
||||
|
||||
Processing order
|
||||
-----------------
|
||||
@@ -27,10 +26,10 @@ The registration order in ``main.py`` must be:
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.utils.logging_compat import get_logger
|
||||
import uuid
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import structlog
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -39,23 +38,22 @@ if TYPE_CHECKING:
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response as StarletteResponse
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
|
||||
# Standard header name for correlation IDs (follows W3C Trace Context conventions)
|
||||
_CORRELATION_ID_HEADER: str = "X-Correlation-ID"
|
||||
|
||||
# Key name for storing correlation ID in structlog context
|
||||
# Key name for storing correlation ID in request state
|
||||
CORRELATION_ID_CONTEXT_KEY: str = "correlation_id"
|
||||
|
||||
|
||||
class CorrelationIdMiddleware(BaseHTTPMiddleware):
|
||||
"""Extract or generate correlation ID and inject into structlog context.
|
||||
"""Extract or generate correlation ID and store on request state.
|
||||
|
||||
For each request, this middleware:
|
||||
1. Checks for `X-Correlation-ID` header (trusted from frontend)
|
||||
2. Generates a new UUID4 if header not present
|
||||
3. Stores in structlog.contextvars so all logs for this request include it
|
||||
4. Makes available via request.state for error handlers
|
||||
3. Stores on request.state for use by error handlers and log filters
|
||||
|
||||
The correlation ID enables tracing a single user action or request flow
|
||||
across both frontend and backend systems using structured logs.
|
||||
@@ -82,19 +80,12 @@ class CorrelationIdMiddleware(BaseHTTPMiddleware):
|
||||
str(uuid.uuid4()),
|
||||
)
|
||||
|
||||
# Store in structlog context so all logs for this request include it
|
||||
structlog.contextvars.clear_contextvars()
|
||||
structlog.contextvars.bind_contextvars(
|
||||
**{CORRELATION_ID_CONTEXT_KEY: correlation_id}
|
||||
)
|
||||
|
||||
# Also store on request.state for use by exception handlers
|
||||
# Store on request.state for use by exception handlers
|
||||
request.state.correlation_id = correlation_id
|
||||
|
||||
log.debug(
|
||||
"request_received",
|
||||
method=request.method,
|
||||
path=request.url.path,
|
||||
extra={"method": request.method, "path": request.url.path},
|
||||
)
|
||||
|
||||
response: StarletteResponse = await call_next(request)
|
||||
|
||||
@@ -25,7 +25,7 @@ from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
from fastapi import status
|
||||
from fastapi.responses import JSONResponse
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
@@ -38,7 +38,7 @@ if TYPE_CHECKING:
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response as StarletteResponse
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
|
||||
# HTTP methods that require CSRF protection.
|
||||
_CSRF_PROTECTED_METHODS: frozenset[str] = frozenset({"POST", "PUT", "DELETE", "PATCH"})
|
||||
|
||||
@@ -10,7 +10,7 @@ import re
|
||||
import time
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
|
||||
from app.utils.metrics import http_active_requests, http_request_count, http_request_latency
|
||||
@@ -21,7 +21,7 @@ if TYPE_CHECKING:
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response
|
||||
|
||||
log = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
|
||||
# Paths excluded from detailed metrics (to avoid cardinality explosion)
|
||||
EXCLUDED_PATHS = {"/metrics", "/health", "/api/health"}
|
||||
|
||||
@@ -34,30 +34,36 @@ unusual and potentially suspicious) always carry a correlation ID for tracing.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import structlog
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import JSONResponse, Response
|
||||
|
||||
from app.exceptions import RateLimitError
|
||||
from app.utils.client_ip import get_client_ip
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Awaitable, Callable
|
||||
|
||||
from starlette.requests import Request
|
||||
|
||||
from app.config import Settings
|
||||
from app.utils.rate_limiter import GlobalRateLimiter
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
|
||||
|
||||
class RateLimitMiddleware(BaseHTTPMiddleware):
|
||||
"""Enforce global per-IP request rate limiting on all endpoints.
|
||||
"""Enforce per-IP request rate limiting on matching endpoints.
|
||||
|
||||
Tracks requests per IP and blocks further requests if the limit is exceeded.
|
||||
Uses the application's GlobalRateLimiter instance and trusted-proxy settings
|
||||
for consistent IP extraction.
|
||||
|
||||
Each middleware instance is scoped to a set of path prefixes (or all paths
|
||||
if no prefixes are given). This allows multiple instances to coexist
|
||||
without double-counting requests.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -65,6 +71,11 @@ class RateLimitMiddleware(BaseHTTPMiddleware):
|
||||
app: object,
|
||||
rate_limiter: GlobalRateLimiter,
|
||||
settings: Settings,
|
||||
bucket_override: str | None = None,
|
||||
bucket_max_requests: int | None = None,
|
||||
bucket_window_seconds: int | None = None,
|
||||
path_prefixes: list[str] | None = None,
|
||||
skip_paths: list[str] | None = None,
|
||||
) -> None:
|
||||
"""Initialize the rate limit middleware.
|
||||
|
||||
@@ -72,10 +83,39 @@ class RateLimitMiddleware(BaseHTTPMiddleware):
|
||||
app: The FastAPI application.
|
||||
rate_limiter: The GlobalRateLimiter instance to use for checking limits.
|
||||
settings: Application settings (used for trusted proxies).
|
||||
bucket_override: Optional named bucket to use instead of the default limiter.
|
||||
bucket_max_requests: Max requests for the bucket override.
|
||||
bucket_window_seconds: Window for the bucket override.
|
||||
path_prefixes: If provided, only apply rate limiting to paths that
|
||||
start with one of these prefixes. If ``None``, all paths are
|
||||
matched.
|
||||
skip_paths: If provided, do not apply rate limiting to paths that
|
||||
start with one of these prefixes. Evaluated after
|
||||
``path_prefixes``.
|
||||
"""
|
||||
super().__init__(app) # type: ignore[arg-type]
|
||||
self.rate_limiter: GlobalRateLimiter = rate_limiter
|
||||
self.settings: Settings = settings
|
||||
self.bucket_override = bucket_override
|
||||
self.bucket_max_requests = bucket_max_requests
|
||||
self.bucket_window_seconds = bucket_window_seconds
|
||||
self.path_prefixes = path_prefixes or []
|
||||
self.skip_paths = skip_paths or []
|
||||
|
||||
def _should_check(self, path: str) -> bool:
|
||||
"""Return whether the given path should be rate-limited by this instance.
|
||||
|
||||
Args:
|
||||
path: The request URL path.
|
||||
|
||||
Returns:
|
||||
``True`` if this instance should enforce its limit on the path.
|
||||
"""
|
||||
if self.skip_paths and any(path.startswith(p) for p in self.skip_paths):
|
||||
return False
|
||||
if self.path_prefixes:
|
||||
return any(path.startswith(p) for p in self.path_prefixes)
|
||||
return True
|
||||
|
||||
async def dispatch(
|
||||
self,
|
||||
@@ -94,14 +134,28 @@ class RateLimitMiddleware(BaseHTTPMiddleware):
|
||||
Returns:
|
||||
A response object (either rate limit response or from handler).
|
||||
"""
|
||||
path = request.url.path
|
||||
|
||||
if not self._should_check(path):
|
||||
return await call_next(request)
|
||||
|
||||
client_ip = get_client_ip(request, trusted_proxies=self.settings.trusted_proxies)
|
||||
|
||||
is_allowed, retry_after = self.rate_limiter.check_allowed(client_ip)
|
||||
if self.bucket_override and self.bucket_max_requests and self.bucket_window_seconds:
|
||||
is_allowed, retry_after = self.rate_limiter.check_allowed_for_bucket(
|
||||
self.bucket_override,
|
||||
client_ip,
|
||||
self.bucket_max_requests,
|
||||
self.bucket_window_seconds,
|
||||
)
|
||||
else:
|
||||
is_allowed, retry_after = self.rate_limiter.check_allowed(client_ip)
|
||||
|
||||
if not is_allowed:
|
||||
log.warning(
|
||||
"global_rate_limit_exceeded",
|
||||
client_ip=client_ip,
|
||||
path=request.url.path,
|
||||
path=path,
|
||||
method=request.method,
|
||||
retry_after=retry_after,
|
||||
)
|
||||
@@ -109,7 +163,6 @@ class RateLimitMiddleware(BaseHTTPMiddleware):
|
||||
"Too many requests. Please try again later.",
|
||||
retry_after_seconds=retry_after,
|
||||
)
|
||||
# Return the error response directly
|
||||
return JSONResponse(
|
||||
status_code=429,
|
||||
content={
|
||||
@@ -121,6 +174,5 @@ class RateLimitMiddleware(BaseHTTPMiddleware):
|
||||
headers={"Retry-After": str(int(retry_after))},
|
||||
)
|
||||
|
||||
# Request is allowed, continue to next handler
|
||||
response: Response = await call_next(request)
|
||||
return response
|
||||
|
||||
@@ -8,15 +8,15 @@ from __future__ import annotations
|
||||
|
||||
from enum import StrEnum
|
||||
|
||||
from pydantic import AnyHttpUrl, Field
|
||||
from pydantic import AnyHttpUrl, ConfigDict, Field
|
||||
|
||||
from app.models.response import BanGuiBaseModel, PaginatedListResponse
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Blocklist source
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class BlocklistSource(BanGuiBaseModel):
|
||||
"""Domain model for a blocklist source definition."""
|
||||
|
||||
@@ -27,6 +27,7 @@ class BlocklistSource(BanGuiBaseModel):
|
||||
created_at: str
|
||||
updated_at: str
|
||||
|
||||
|
||||
class BlocklistSourceCreate(BanGuiBaseModel):
|
||||
"""Payload for ``POST /api/blocklists``.
|
||||
|
||||
@@ -39,6 +40,7 @@ class BlocklistSourceCreate(BanGuiBaseModel):
|
||||
url: AnyHttpUrl = Field(..., description="URL of the blocklist file (http/https only).")
|
||||
enabled: bool = Field(default=True)
|
||||
|
||||
|
||||
class BlocklistSourceUpdate(BanGuiBaseModel):
|
||||
"""Payload for ``PUT /api/blocklists/{id}``. All fields are optional.
|
||||
|
||||
@@ -49,15 +51,18 @@ class BlocklistSourceUpdate(BanGuiBaseModel):
|
||||
url: AnyHttpUrl | None = Field(default=None)
|
||||
enabled: bool | None = Field(default=None)
|
||||
|
||||
|
||||
class BlocklistListResponse(BanGuiBaseModel):
|
||||
"""Response for ``GET /api/blocklists``."""
|
||||
|
||||
sources: list[BlocklistSource] = Field(default_factory=list)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Import log
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class ImportLogEntry(BanGuiBaseModel):
|
||||
"""A single blocklist import run record."""
|
||||
|
||||
@@ -69,6 +74,7 @@ class ImportLogEntry(BanGuiBaseModel):
|
||||
ips_skipped: int
|
||||
errors: str | None
|
||||
|
||||
|
||||
class ImportLogListResponse(PaginatedListResponse[ImportLogEntry]):
|
||||
"""Response for ``GET /api/blocklists/log``.
|
||||
|
||||
@@ -83,6 +89,7 @@ class ImportLogListResponse(PaginatedListResponse[ImportLogEntry]):
|
||||
# Import run tracking (for idempotency)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class ImportRunEntry(BanGuiBaseModel):
|
||||
"""Tracks a unique blocklist import run by source and content hash.
|
||||
|
||||
@@ -100,10 +107,12 @@ class ImportRunEntry(BanGuiBaseModel):
|
||||
created_at: str
|
||||
updated_at: str
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Schedule
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class ScheduleFrequency(StrEnum):
|
||||
"""Available import schedule frequency presets."""
|
||||
|
||||
@@ -111,6 +120,7 @@ class ScheduleFrequency(StrEnum):
|
||||
daily = "daily"
|
||||
weekly = "weekly"
|
||||
|
||||
|
||||
class ScheduleConfig(BanGuiBaseModel):
|
||||
"""Import schedule configuration.
|
||||
|
||||
@@ -121,8 +131,10 @@ class ScheduleConfig(BanGuiBaseModel):
|
||||
- ``weekly``: additionally uses ``day_of_week`` (0=Monday … 6=Sunday).
|
||||
"""
|
||||
|
||||
# No strict=True here: FastAPI and json.loads() both supply enum values as
|
||||
# plain strings; strict mode would reject string→enum coercion.
|
||||
# FastAPI and json.loads() both supply enum values as plain strings;
|
||||
# strict mode would reject string→enum coercion, so we override the
|
||||
# base model_config for this model only.
|
||||
model_config = ConfigDict(strict=False)
|
||||
|
||||
frequency: ScheduleFrequency = ScheduleFrequency.daily
|
||||
interval_hours: int = Field(default=24, ge=1, le=168, description="Used when frequency=hourly")
|
||||
@@ -135,6 +147,7 @@ class ScheduleConfig(BanGuiBaseModel):
|
||||
description="Day of week for weekly runs (0=Monday … 6=Sunday)",
|
||||
)
|
||||
|
||||
|
||||
class ScheduleInfo(BanGuiBaseModel):
|
||||
"""Current schedule configuration together with runtime metadata."""
|
||||
|
||||
@@ -144,10 +157,12 @@ class ScheduleInfo(BanGuiBaseModel):
|
||||
last_run_errors: bool | None = None
|
||||
"""``True`` if the most recent import had errors, ``False`` if clean, ``None`` if never run."""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Import results
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class ImportSourceResult(BanGuiBaseModel):
|
||||
"""Result of importing a single blocklist source."""
|
||||
|
||||
@@ -157,6 +172,7 @@ class ImportSourceResult(BanGuiBaseModel):
|
||||
ips_skipped: int
|
||||
error: str | None
|
||||
|
||||
|
||||
class ImportRunResult(BanGuiBaseModel):
|
||||
"""Aggregated result from a full import run across all enabled sources."""
|
||||
|
||||
@@ -165,10 +181,12 @@ class ImportRunResult(BanGuiBaseModel):
|
||||
total_skipped: int
|
||||
errors_count: int
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Preview
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class PreviewResponse(BanGuiBaseModel):
|
||||
"""Response for ``GET /api/blocklists/{id}/preview``."""
|
||||
|
||||
|
||||
@@ -188,7 +188,6 @@ class PaginationMetadata(BanGuiBaseModel):
|
||||
)
|
||||
|
||||
|
||||
|
||||
class PaginatedListResponse(BanGuiBaseModel, Generic[T]):
|
||||
"""Standardized paginated list response.
|
||||
|
||||
@@ -384,6 +383,8 @@ class ErrorMetadata(TypedDict, total=False):
|
||||
current_status: str
|
||||
actual_length: int
|
||||
message: str
|
||||
field_errors: int
|
||||
first_field: str
|
||||
|
||||
|
||||
class ComponentHealth(BanGuiBaseModel):
|
||||
|
||||
@@ -41,9 +41,9 @@ def _check_action_update_rate_limit(
|
||||
)
|
||||
if not is_allowed:
|
||||
from app.exceptions import RateLimitError
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
log = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
log.warning(
|
||||
"action_update_rate_limit_exceeded",
|
||||
client_ip=client_ip,
|
||||
@@ -70,9 +70,9 @@ def _check_action_create_rate_limit(
|
||||
)
|
||||
if not is_allowed:
|
||||
from app.exceptions import RateLimitError
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
log = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
log.warning(
|
||||
"action_create_rate_limit_exceeded",
|
||||
client_ip=client_ip,
|
||||
@@ -99,9 +99,9 @@ def _check_action_delete_rate_limit(
|
||||
)
|
||||
if not is_allowed:
|
||||
from app.exceptions import RateLimitError
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
log = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
log.warning(
|
||||
"action_delete_rate_limit_exceeded",
|
||||
client_ip=client_ip,
|
||||
|
||||
@@ -11,32 +11,26 @@ malicious scripts.
|
||||
For programmatic API clients (non-browser), use ``POST /api/auth/token``
|
||||
which returns a token in the response body for use in the ``Authorization``
|
||||
header. This endpoint does not set a cookie.
|
||||
|
||||
Rate limiting uses exponential backoff: each wrong password attempt incurs
|
||||
a progressive delay (0.5s, 1s, 2s, 4s, 5s max) per IP address. Requests
|
||||
blocked by this delay return ``429 Too Many Requests`` with a ``Retry-After``
|
||||
header.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
from fastapi import APIRouter, Request, Response
|
||||
|
||||
from app.dependencies import (
|
||||
AuthDep,
|
||||
LoginRateLimiterDep,
|
||||
SessionCacheDep,
|
||||
SessionServiceContextDep,
|
||||
SettingsDep,
|
||||
)
|
||||
from app.exceptions import AuthenticationError, RateLimitError
|
||||
from app.exceptions import AuthenticationError
|
||||
from app.models.auth import LoginRequest, LoginResponse, LogoutResponse, SessionValidResponse
|
||||
from app.services import auth_service
|
||||
from app.utils.client_ip import get_client_ip
|
||||
from app.utils.constants import SESSION_COOKIE_NAME
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/v1/auth", tags=["auth"])
|
||||
|
||||
@@ -49,7 +43,6 @@ router = APIRouter(prefix="/api/v1/auth", tags=["auth"])
|
||||
200: {"description": "Login successful", "model": LoginResponse},
|
||||
401: {"description": "Invalid password"},
|
||||
422: {"description": "Validation error — invalid request body"},
|
||||
429: {"description": "Too many login attempts, retry after delay"},
|
||||
503: {"description": "Setup not complete"},
|
||||
},
|
||||
)
|
||||
@@ -59,7 +52,6 @@ async def login(
|
||||
request: Request,
|
||||
session_ctx: SessionServiceContextDep,
|
||||
settings: SettingsDep,
|
||||
rate_limiter: LoginRateLimiterDep,
|
||||
session_cache: SessionCacheDep,
|
||||
) -> LoginResponse:
|
||||
"""Verify the master password and return a session token.
|
||||
@@ -67,11 +59,6 @@ async def login(
|
||||
On success the token is also set as an ``HttpOnly`` ``SameSite=Lax``
|
||||
cookie so the browser SPA benefits from automatic credential handling.
|
||||
|
||||
Rate limiting: Exponential backoff on failed attempts. Each wrong password
|
||||
incurs an increasing delay (0.5s, 1s, 2s, 4s, 5s max per IP address).
|
||||
Requests during the penalty period return ``429 Too Many Requests`` with
|
||||
a ``Retry-After`` header.
|
||||
|
||||
Cache invalidation: On successful login, any existing cached sessions for
|
||||
the same user are invalidated so that stale tokens (e.g., from a stolen
|
||||
device) cannot be reused beyond the cache TTL window.
|
||||
@@ -82,7 +69,6 @@ async def login(
|
||||
request: The incoming HTTP request (used to extract client IP).
|
||||
session_ctx: Session service context containing db and repository.
|
||||
settings: Application settings (used for session duration and trusted proxies).
|
||||
rate_limiter: The login rate limiter (per IP).
|
||||
session_cache: Session cache for invalidating old sessions on login.
|
||||
|
||||
Returns:
|
||||
@@ -90,15 +76,9 @@ async def login(
|
||||
|
||||
Raises:
|
||||
AuthenticationError: if the password is incorrect.
|
||||
RateLimitError: if the rate limit is exceeded.
|
||||
"""
|
||||
client_ip = get_client_ip(request, trusted_proxies=settings.trusted_proxies)
|
||||
|
||||
# Check if this IP is currently blocked by exponential backoff
|
||||
if not rate_limiter.is_allowed(client_ip):
|
||||
log.warning("login_rate_limit_exceeded", client_ip=client_ip)
|
||||
raise RateLimitError("Too many login attempts. Please try again later.", retry_after_seconds=60.0)
|
||||
|
||||
try:
|
||||
signed_token, expires_at, session = await auth_service.login(
|
||||
session_ctx.db,
|
||||
@@ -108,8 +88,6 @@ async def login(
|
||||
session_repo=session_ctx.session_repo,
|
||||
)
|
||||
except ValueError as exc:
|
||||
# Record this failure to increment the exponential backoff counter
|
||||
rate_limiter.record_failure(client_ip)
|
||||
log.warning("login_failed", client_ip=client_ip, error=str(exc))
|
||||
raise AuthenticationError(str(exc)) from exc
|
||||
|
||||
|
||||
@@ -53,9 +53,9 @@ def _check_ban_rate_limit(
|
||||
)
|
||||
if not is_allowed:
|
||||
from app.exceptions import RateLimitError
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
log = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
log.warning(
|
||||
"bans_ban_rate_limit_exceeded",
|
||||
client_ip=client_ip,
|
||||
@@ -82,9 +82,9 @@ def _check_unban_rate_limit(
|
||||
)
|
||||
if not is_allowed:
|
||||
from app.exceptions import RateLimitError
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
log = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
log.warning(
|
||||
"bans_unban_rate_limit_exceeded",
|
||||
client_ip=client_ip,
|
||||
|
||||
@@ -22,7 +22,7 @@ registered *before* the ``/{id}`` routes so FastAPI resolves them correctly.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
from fastapi import APIRouter, Depends, Query, Request, status
|
||||
|
||||
from app.dependencies import (
|
||||
@@ -64,7 +64,7 @@ _BLOCKLIST_IMPORT_BUCKET = "blocklist:import"
|
||||
# 3600 seconds per hour
|
||||
_HOUR = 3600
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
|
||||
|
||||
def _check_blocklist_import_rate_limit(
|
||||
|
||||
@@ -4,7 +4,7 @@ import shlex
|
||||
from pathlib import Path
|
||||
from typing import Annotated
|
||||
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
from fastapi import APIRouter, Depends, Query, Request, status
|
||||
|
||||
from app.config import get_settings
|
||||
@@ -37,7 +37,7 @@ from app.services import (
|
||||
)
|
||||
from app.utils.constants import CSRF_HEADER_NAME, CSRF_HEADER_VALUE, RATE_LIMIT_CONFIG_UPDATE_REQUESTS
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
|
||||
router: APIRouter = APIRouter(tags=["Config Misc"])
|
||||
|
||||
@@ -60,11 +60,11 @@ def _check_config_update_rate_limit(
|
||||
_CONFIG_UPDATE_BUCKET, client_ip, RATE_LIMIT_CONFIG_UPDATE_REQUESTS, _MINUTE
|
||||
)
|
||||
if not is_allowed:
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
from app.exceptions import RateLimitError
|
||||
|
||||
log = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
log.warning(
|
||||
"config_update_rate_limit_exceeded",
|
||||
client_ip=client_ip,
|
||||
|
||||
@@ -42,9 +42,9 @@ def _check_filter_update_rate_limit(
|
||||
)
|
||||
if not is_allowed:
|
||||
from app.exceptions import RateLimitError
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
log = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
log.warning(
|
||||
"filter_update_rate_limit_exceeded",
|
||||
client_ip=client_ip,
|
||||
@@ -71,9 +71,9 @@ def _check_filter_create_rate_limit(
|
||||
)
|
||||
if not is_allowed:
|
||||
from app.exceptions import RateLimitError
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
log = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
log.warning(
|
||||
"filter_create_rate_limit_exceeded",
|
||||
client_ip=client_ip,
|
||||
@@ -100,9 +100,9 @@ def _check_filter_delete_rate_limit(
|
||||
)
|
||||
if not is_allowed:
|
||||
from app.exceptions import RateLimitError
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
log = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
log.warning(
|
||||
"filter_delete_rate_limit_exceeded",
|
||||
client_ip=client_ip,
|
||||
|
||||
@@ -22,7 +22,7 @@ import asyncio
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Literal
|
||||
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
from fastapi import APIRouter, status
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
@@ -34,7 +34,7 @@ if TYPE_CHECKING:
|
||||
|
||||
router: APIRouter = APIRouter(prefix="/api/v1/health", tags=["Health"])
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
|
||||
|
||||
@router.get(
|
||||
|
||||
@@ -37,7 +37,6 @@ from app.services import (
|
||||
filter_config_service,
|
||||
jail_config_service,
|
||||
)
|
||||
from app.utils.path_utils import validate_log_path
|
||||
from app.utils.constants import (
|
||||
RATE_LIMIT_JAIL_ACTIVATE_REQUESTS,
|
||||
RATE_LIMIT_JAIL_CREATE_REQUESTS,
|
||||
@@ -45,6 +44,7 @@ from app.utils.constants import (
|
||||
RATE_LIMIT_JAIL_DELETE_REQUESTS,
|
||||
RATE_LIMIT_JAIL_UPDATE_REQUESTS,
|
||||
)
|
||||
from app.utils.path_utils import validate_log_path
|
||||
from app.utils.runtime_state import (
|
||||
clear_activation_record,
|
||||
clear_pending_recovery,
|
||||
@@ -76,9 +76,9 @@ def _check_jail_update_rate_limit(
|
||||
)
|
||||
if not is_allowed:
|
||||
from app.exceptions import RateLimitError
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
log = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
log.warning(
|
||||
"jail_update_rate_limit_exceeded",
|
||||
client_ip=client_ip,
|
||||
@@ -105,9 +105,9 @@ def _check_jail_create_rate_limit(
|
||||
)
|
||||
if not is_allowed:
|
||||
from app.exceptions import RateLimitError
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
log = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
log.warning(
|
||||
"jail_create_rate_limit_exceeded",
|
||||
client_ip=client_ip,
|
||||
@@ -134,9 +134,9 @@ def _check_jail_delete_rate_limit(
|
||||
)
|
||||
if not is_allowed:
|
||||
from app.exceptions import RateLimitError
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
log = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
log.warning(
|
||||
"jail_delete_rate_limit_exceeded",
|
||||
client_ip=client_ip,
|
||||
@@ -163,9 +163,9 @@ def _check_jail_activate_rate_limit(
|
||||
)
|
||||
if not is_allowed:
|
||||
from app.exceptions import RateLimitError
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
log = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
log.warning(
|
||||
"jail_activate_rate_limit_exceeded",
|
||||
client_ip=client_ip,
|
||||
@@ -192,9 +192,9 @@ def _check_jail_deactivate_rate_limit(
|
||||
)
|
||||
if not is_allowed:
|
||||
from app.exceptions import RateLimitError
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
log = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
log.warning(
|
||||
"jail_deactivate_rate_limit_exceeded",
|
||||
client_ip=client_ip,
|
||||
@@ -207,7 +207,8 @@ def _check_jail_deactivate_rate_limit(
|
||||
)
|
||||
|
||||
|
||||
_NamePath = Annotated[str, Path(description='Jail name as configured in fail2ban.')]
|
||||
_NamePath = Annotated[str, Path(description="Jail name as configured in fail2ban.")]
|
||||
|
||||
|
||||
@router.get(
|
||||
"",
|
||||
@@ -240,8 +241,6 @@ async def get_jail_configs(
|
||||
return config_mappers.map_domain_jail_config_list_to_response(domain_result)
|
||||
|
||||
|
||||
|
||||
|
||||
@router.get(
|
||||
"/inactive",
|
||||
response_model=InactiveJailListResponse,
|
||||
@@ -335,9 +334,8 @@ async def get_jail_config(
|
||||
HTTPException: 502 when fail2ban is unreachable.
|
||||
"""
|
||||
domain_result = await config_service.get_jail_config(socket_path, name)
|
||||
return config_mappers.map_domain_jail_config_to_response(domain_result)
|
||||
|
||||
|
||||
mapped = config_mappers.map_domain_jail_config_to_response(domain_result)
|
||||
return JailConfigResponse(jail=mapped)
|
||||
|
||||
|
||||
@router.put(
|
||||
@@ -387,8 +385,6 @@ async def update_jail_config(
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
|
||||
|
||||
@router.post(
|
||||
"/{name}/logpath",
|
||||
status_code=status.HTTP_204_NO_CONTENT,
|
||||
@@ -430,8 +426,6 @@ async def add_log_path(
|
||||
await config_service.add_log_path(socket_path, name, body)
|
||||
|
||||
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/{name}/logpath",
|
||||
status_code=status.HTTP_204_NO_CONTENT,
|
||||
@@ -479,8 +473,6 @@ async def delete_log_path(
|
||||
await config_service.delete_log_path(socket_path, name, log_path)
|
||||
|
||||
|
||||
|
||||
|
||||
@router.post(
|
||||
"/{name}/activate",
|
||||
response_model=JailActivationResponse,
|
||||
@@ -532,9 +524,7 @@ async def activate_jail(
|
||||
"""
|
||||
req = body if body is not None else ActivateJailRequest()
|
||||
|
||||
result = await jail_config_service.activate_jail(
|
||||
config_dir, socket_path, name, req, health_probe=health_probe
|
||||
)
|
||||
result = await jail_config_service.activate_jail(config_dir, socket_path, name, req, health_probe=health_probe)
|
||||
|
||||
if result.active:
|
||||
record_activation(app, name)
|
||||
@@ -542,8 +532,6 @@ async def activate_jail(
|
||||
return result
|
||||
|
||||
|
||||
|
||||
|
||||
@router.post(
|
||||
"/{name}/deactivate",
|
||||
response_model=JailActivationResponse,
|
||||
@@ -588,14 +576,10 @@ async def deactivate_jail(
|
||||
HTTPException: 502 if fail2ban is unreachable.
|
||||
"""
|
||||
|
||||
result = await jail_config_service.deactivate_jail(
|
||||
config_dir, socket_path, name, health_probe=health_probe
|
||||
)
|
||||
result = await jail_config_service.deactivate_jail(config_dir, socket_path, name, health_probe=health_probe)
|
||||
return result
|
||||
|
||||
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/{name}/local",
|
||||
status_code=status.HTTP_204_NO_CONTENT,
|
||||
@@ -645,8 +629,6 @@ async def delete_jail_local_override(
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
|
||||
|
||||
@router.post(
|
||||
"/{name}/validate",
|
||||
response_model=JailValidationResult,
|
||||
@@ -868,10 +850,8 @@ async def remove_action_from_jail(
|
||||
action_name,
|
||||
do_reload=reload,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Filter discovery endpoints (Task 2.1)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -5,13 +5,13 @@ Exposes collected metrics in Prometheus text format at GET /metrics.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
from fastapi import APIRouter
|
||||
from starlette.responses import Response
|
||||
|
||||
from app.utils.metrics import get_metrics, get_metrics_content_type
|
||||
|
||||
log = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ return ``409 Conflict``.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
from fastapi import APIRouter, status
|
||||
|
||||
from app.dependencies import AppDep, SettingsDep, SettingsServiceContextDep
|
||||
@@ -17,7 +17,7 @@ from app.services import setup_service
|
||||
from app.utils.runtime_state import update_app_settings
|
||||
from app.utils.setup_state import is_setup_complete_cached, set_setup_complete_cache
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/v1/setup", tags=["setup"])
|
||||
|
||||
|
||||
@@ -15,7 +15,7 @@ import re
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
from app.exceptions import (
|
||||
ActionAlreadyExistsError,
|
||||
@@ -47,7 +47,7 @@ from app.utils.config_file_utils import (
|
||||
)
|
||||
from app.utils.jail_socket import reload_all
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Internal wrappers for shared config helpers.
|
||||
|
||||
@@ -13,7 +13,7 @@ import secrets
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import bcrypt
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
from app.utils.async_utils import run_blocking
|
||||
|
||||
@@ -28,7 +28,7 @@ from app.repositories import settings_repo as default_settings_repo
|
||||
from app.utils.constants import SESSION_TOKEN_BYTES, SESSION_TOKEN_SIGNATURE_SEPARATOR
|
||||
from app.utils.time_utils import add_minutes, utc_now
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
|
||||
# Settings key for password hash
|
||||
_KEY_PASSWORD_HASH = "master_password_hash"
|
||||
|
||||
@@ -16,7 +16,7 @@ import ipaddress
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
import aiohttp
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
from app.exceptions import JailNotFoundError, JailOperationError
|
||||
from app.models._common import (
|
||||
@@ -69,7 +69,7 @@ if TYPE_CHECKING:
|
||||
from app.repositories.protocols import HistoryArchiveRepository
|
||||
from app.services.geo_cache import GeoCache
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
|
||||
|
||||
async def get_fail2ban_db_path(socket_path: str) -> str:
|
||||
@@ -332,7 +332,14 @@ async def get_active_bans(
|
||||
for ban in bans:
|
||||
geo = geo_map.get(ban.ip)
|
||||
if geo is not None:
|
||||
enriched.append(ban.model_copy(update={"country": geo.country_code}))
|
||||
enriched.append(DomainActiveBan(
|
||||
ip=ban.ip,
|
||||
jail=ban.jail,
|
||||
banned_at=ban.banned_at,
|
||||
expires_at=ban.expires_at,
|
||||
ban_count=ban.ban_count,
|
||||
country=geo.country_code,
|
||||
))
|
||||
else:
|
||||
enriched.append(ban)
|
||||
bans = enriched
|
||||
|
||||
@@ -8,14 +8,14 @@ from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
from app.exceptions import JailNotFoundError, JailOperationError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Awaitable, Callable
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
|
||||
|
||||
class BanExecutor:
|
||||
|
||||
@@ -10,9 +10,9 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
|
||||
import aiohttp
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
|
||||
#: HTTP status codes that should be retried for blocklist downloads.
|
||||
_BLOCKLIST_HTTP_RETRY_STATUSES: frozenset[int] = frozenset({429, 500, 502, 503, 504})
|
||||
|
||||
@@ -16,7 +16,7 @@ from typing import TYPE_CHECKING
|
||||
|
||||
import aiohttp
|
||||
import aiosqlite
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
from app.models.blocklist import BlocklistSource, ImportSourceResult
|
||||
from app.repositories import import_run_repo
|
||||
@@ -29,7 +29,7 @@ if TYPE_CHECKING:
|
||||
|
||||
from app.services.geo_cache import GeoCache
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
|
||||
#: fail2ban jail name for blocklist-origin bans.
|
||||
BLOCKLIST_JAIL: str = "blocklist-import"
|
||||
|
||||
@@ -6,11 +6,11 @@ or CIDR networks. Separates valid IPs from invalid/CIDR entries.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
from app.utils.ip_utils import is_valid_ip, is_valid_network, normalise_ip
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
|
||||
|
||||
class ParsedBlocklist:
|
||||
|
||||
@@ -15,11 +15,11 @@ under the key ``"blocklist_schedule"``.
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from datetime import UTC
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import aiohttp
|
||||
import aiosqlite
|
||||
import structlog
|
||||
|
||||
from app.exceptions import BlocklistSourceHasLogsError
|
||||
from app.models.blocklist import (
|
||||
@@ -37,6 +37,7 @@ from app.repositories import blocklist_repo, import_log_repo, settings_repo
|
||||
from app.services.blocklist_downloader import BlocklistDownloader
|
||||
from app.services.blocklist_import_workflow import BlocklistImportWorkflow
|
||||
from app.services.blocklist_parser import BlocklistParser
|
||||
from app.utils.logging_compat import get_logger
|
||||
from app.utils.pagination import create_pagination_metadata
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -47,7 +48,7 @@ if TYPE_CHECKING:
|
||||
from app.config import Settings
|
||||
from app.services.geo_cache import GeoCache
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
|
||||
#: Settings key used to persist the schedule config.
|
||||
_SCHEDULE_SETTINGS_KEY: str = "blocklist_schedule"
|
||||
@@ -200,9 +201,7 @@ async def update_source(
|
||||
|
||||
await validate_blocklist_url(url)
|
||||
|
||||
updated = await blocklist_repo.update_source(
|
||||
db, source_id, name=name, url=url, enabled=enabled
|
||||
)
|
||||
updated = await blocklist_repo.update_source(db, source_id, name=name, url=url, enabled=enabled)
|
||||
if not updated:
|
||||
return None
|
||||
source = await get_source(db, source_id)
|
||||
@@ -473,8 +472,7 @@ async def get_schedule(db: aiosqlite.Connection) -> ScheduleConfig:
|
||||
if raw is None:
|
||||
return _DEFAULT_SCHEDULE
|
||||
try:
|
||||
data = json.loads(raw)
|
||||
return ScheduleConfig.model_validate(data)
|
||||
return ScheduleConfig.model_validate_json(raw)
|
||||
except (json.JSONDecodeError, ValueError) as exc:
|
||||
log.warning("blocklist_schedule_invalid", raw=raw, error=type(exc).__name__)
|
||||
return _DEFAULT_SCHEDULE
|
||||
@@ -493,9 +491,7 @@ async def set_schedule(
|
||||
Returns:
|
||||
The saved configuration (same object after validation).
|
||||
"""
|
||||
await settings_repo.set_setting(
|
||||
db, _SCHEDULE_SETTINGS_KEY, config.model_dump_json()
|
||||
)
|
||||
await settings_repo.set_setting(db, _SCHEDULE_SETTINGS_KEY, config.model_dump_json())
|
||||
log.info("blocklist_schedule_updated", frequency=config.frequency, hour=config.hour)
|
||||
return config
|
||||
|
||||
@@ -517,8 +513,12 @@ async def get_schedule_info(
|
||||
"""
|
||||
config = await get_schedule(db)
|
||||
last_log = await import_log_repo.get_last_log(db)
|
||||
last_run_at = last_log["timestamp"] if last_log else None
|
||||
last_run_errors: bool | None = (last_log["errors"] is not None) if last_log else None
|
||||
last_run_at = None
|
||||
if last_log is not None:
|
||||
from datetime import datetime
|
||||
|
||||
last_run_at = datetime.fromtimestamp(last_log.timestamp, tz=UTC).isoformat()
|
||||
last_run_errors: bool | None = (last_log.errors is not None) if last_log else None
|
||||
return ScheduleInfo(
|
||||
config=config,
|
||||
next_run_at=next_run_at,
|
||||
@@ -574,9 +574,7 @@ async def list_import_logs(
|
||||
Returns:
|
||||
:class:`~app.models.blocklist.ImportLogListResponse`.
|
||||
"""
|
||||
items, total = await import_log_repo.list_logs(
|
||||
db, source_id=source_id, page=page, page_size=page_size
|
||||
)
|
||||
items, total = await import_log_repo.list_logs(db, source_id=source_id, page=page, page_size=page_size)
|
||||
|
||||
return ImportLogListResponse(
|
||||
items=[ImportLogEntry.model_validate(i) for i in items],
|
||||
|
||||
@@ -17,7 +17,7 @@ import contextlib
|
||||
import re
|
||||
from typing import TYPE_CHECKING, TypeVar, cast
|
||||
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
from app.utils.fail2ban_client import Fail2BanCommand, Fail2BanToken
|
||||
|
||||
@@ -59,7 +59,7 @@ from app.utils.fail2ban_response import (
|
||||
)
|
||||
from app.utils.path_utils import validate_log_target
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Custom exceptions
|
||||
|
||||
@@ -23,14 +23,14 @@ import ipaddress
|
||||
import socket
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
from app.utils.ip_utils import is_private_ip
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Callable
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
|
||||
|
||||
def create_dns_validated_socket_factory() -> (
|
||||
|
||||
@@ -4,7 +4,7 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
from app.utils.constants import FAIL2BAN_SOCKET_TIMEOUT_FAST
|
||||
from app.utils.fail2ban_client import (
|
||||
@@ -13,7 +13,7 @@ from app.utils.fail2ban_client import (
|
||||
Fail2BanProtocolError,
|
||||
)
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
|
||||
|
||||
class Fail2BanMetadataService:
|
||||
|
||||
@@ -13,8 +13,6 @@ import re
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import structlog
|
||||
|
||||
from app.exceptions import (
|
||||
ConfigWriteError,
|
||||
FilterAlreadyExistsError,
|
||||
@@ -27,6 +25,7 @@ from app.exceptions import (
|
||||
)
|
||||
from app.models.config import (
|
||||
AssignFilterRequest,
|
||||
FilterConfig,
|
||||
FilterConfigUpdate,
|
||||
FilterCreateRequest,
|
||||
FilterUpdateRequest,
|
||||
@@ -46,14 +45,16 @@ from app.utils.config_file_utils import (
|
||||
set_jail_local_key_sync,
|
||||
)
|
||||
from app.utils.jail_socket import reload_all
|
||||
from app.utils.logging_compat import get_logger
|
||||
from app.utils.regex_validator import RegexTimeoutError, validate_regex_pattern
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Internal wrappers for shared config helpers.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _parse_jails_sync(config_dir: Path) -> tuple[dict[str, dict[str, str]], Path]:
|
||||
return _config_file_parse_jails_sync(config_dir)
|
||||
|
||||
@@ -85,6 +86,7 @@ def _resolve_filter(raw_filter: str, jail_name: str, mode: str) -> str:
|
||||
result = result.replace("%(mode)s", mode)
|
||||
return result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Internal helpers imported from the shared config helper module.
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -366,7 +368,7 @@ async def list_filters(
|
||||
)
|
||||
|
||||
log.info("filters_listed", total=len(filters), active=sum(1 for f in filters if f.active))
|
||||
return DomainFilterList(filters=filters, total=len(filters))
|
||||
return DomainFilterList(items=filters, total=len(filters))
|
||||
|
||||
|
||||
async def get_filter(
|
||||
@@ -428,7 +430,7 @@ async def get_filter(
|
||||
else:
|
||||
raise FilterNotFoundError(base_name)
|
||||
|
||||
content, has_local, source_path = await run_blocking( _read)
|
||||
content, has_local, source_path = await run_blocking(_read)
|
||||
|
||||
cfg = conffile_parser.parse_filter_file(content, name=base_name, filename=f"{base_name}.conf")
|
||||
|
||||
@@ -524,7 +526,7 @@ async def update_filter(
|
||||
content = conffile_parser.serialize_filter_config(merged)
|
||||
|
||||
filter_d = Path(config_dir) / "filter.d"
|
||||
await run_blocking( _write_filter_local_sync, filter_d, base_name, content)
|
||||
await run_blocking(_write_filter_local_sync, filter_d, base_name, content)
|
||||
|
||||
if do_reload:
|
||||
try:
|
||||
@@ -580,7 +582,7 @@ async def create_filter(
|
||||
if conf_path.is_file() or local_path.is_file():
|
||||
raise FilterAlreadyExistsError(req.name)
|
||||
|
||||
await run_blocking( _check_not_exists)
|
||||
await run_blocking(_check_not_exists)
|
||||
|
||||
# Validate regex patterns.
|
||||
patterns: list[str] = list(req.failregex) + list(req.ignoreregex)
|
||||
@@ -598,7 +600,7 @@ async def create_filter(
|
||||
)
|
||||
content = conffile_parser.serialize_filter_config(cfg)
|
||||
|
||||
await run_blocking( _write_filter_local_sync, filter_d, req.name, content)
|
||||
await run_blocking(_write_filter_local_sync, filter_d, req.name, content)
|
||||
|
||||
if do_reload:
|
||||
try:
|
||||
@@ -663,7 +665,7 @@ async def delete_filter(
|
||||
|
||||
log.info("filter_local_deleted", filter=base_name, path=str(local_path))
|
||||
|
||||
await run_blocking( _delete)
|
||||
await run_blocking(_delete)
|
||||
|
||||
|
||||
async def assign_filter_to_jail(
|
||||
@@ -713,9 +715,10 @@ async def assign_filter_to_jail(
|
||||
if not conf_exists and not local_exists:
|
||||
raise FilterNotFoundError(req.filter_name)
|
||||
|
||||
await run_blocking( _check_filter)
|
||||
await run_blocking(_check_filter)
|
||||
|
||||
await run_blocking(set_jail_local_key_sync,
|
||||
await run_blocking(
|
||||
set_jail_local_key_sync,
|
||||
Path(config_dir),
|
||||
jail_name,
|
||||
"filter",
|
||||
|
||||
@@ -21,10 +21,10 @@ import time
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import aiohttp
|
||||
import structlog
|
||||
|
||||
from app.models.geo import GeoInfo
|
||||
from app.repositories import geo_cache_repo
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import collections.abc
|
||||
@@ -33,21 +33,17 @@ if TYPE_CHECKING:
|
||||
import geoip2.database
|
||||
import geoip2.errors
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Constants
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
#: ip-api.com single-IP lookup endpoint (HTTP only on the free tier).
|
||||
_API_URL: str = (
|
||||
"http://ip-api.com/json/{ip}?fields=status,message,country,countryCode,org,as"
|
||||
)
|
||||
_API_URL: str = "http://ip-api.com/json/{ip}?fields=status,message,country,countryCode,org,as"
|
||||
|
||||
#: ip-api.com batch endpoint — accepts up to 100 IPs per POST.
|
||||
_BATCH_API_URL: str = (
|
||||
"http://ip-api.com/batch?fields=status,message,country,countryCode,org,as,query"
|
||||
)
|
||||
_BATCH_API_URL: str = "http://ip-api.com/batch?fields=status,message,country,countryCode,org,as,query"
|
||||
|
||||
#: Maximum IPs per batch request (ip-api.com hard limit is 100).
|
||||
_BATCH_SIZE: int = 100
|
||||
@@ -208,18 +204,16 @@ class GeoCache:
|
||||
Returns:
|
||||
A dict with ``resolved`` and ``total`` counts.
|
||||
"""
|
||||
import structlog # noqa: PLC0415
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
log = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
unresolved = await self.get_unresolved_ips(db)
|
||||
if not unresolved:
|
||||
return {"resolved": 0, "total": 0}
|
||||
|
||||
await self.clear_neg_cache()
|
||||
geo_map = await self.lookup_batch(unresolved, http_session, db=db)
|
||||
resolved_count = sum(
|
||||
1 for info in geo_map.values() if info.country_code is not None
|
||||
)
|
||||
resolved_count = sum(1 for info in geo_map.values() if info.country_code is not None)
|
||||
|
||||
log.info(
|
||||
"geo_re_resolve_complete",
|
||||
@@ -299,18 +293,18 @@ class GeoCache:
|
||||
count = 0
|
||||
cache_entries: list[tuple[str, GeoInfo]] = []
|
||||
for row in await geo_cache_repo.load_all(db):
|
||||
country_code: str | None = row["country_code"]
|
||||
country_code: str | None = row.country_code
|
||||
if country_code is None:
|
||||
continue
|
||||
ip: str = row["ip"]
|
||||
ip: str = row.ip
|
||||
cache_entries.append(
|
||||
(
|
||||
ip,
|
||||
GeoInfo(
|
||||
country_code=country_code,
|
||||
country_name=row["country_name"],
|
||||
asn=row["asn"],
|
||||
org=row["org"],
|
||||
country_name=row.country_name,
|
||||
asn=row.asn,
|
||||
org=row.org,
|
||||
),
|
||||
)
|
||||
)
|
||||
@@ -398,7 +392,7 @@ class GeoCache:
|
||||
asn=result.asn,
|
||||
org=result.org,
|
||||
)
|
||||
except (OSError) as exc:
|
||||
except OSError as exc:
|
||||
log.warning("geo_persist_failed", ip=ip, error=type(exc).__name__)
|
||||
log.debug("geo_lookup_success_mmdb", ip=ip, country=result.country_code)
|
||||
return result
|
||||
@@ -412,7 +406,7 @@ class GeoCache:
|
||||
if db is not None:
|
||||
try:
|
||||
await geo_cache_repo.upsert_neg_entry_and_commit(db=db, ip=ip)
|
||||
except (OSError) as exc:
|
||||
except OSError as exc:
|
||||
log.warning("geo_persist_neg_failed", ip=ip, error=type(exc).__name__)
|
||||
return GeoInfo(country_code=None, country_name=None, asn=None, org=None)
|
||||
|
||||
@@ -439,7 +433,7 @@ class GeoCache:
|
||||
asn=result.asn,
|
||||
org=result.org,
|
||||
)
|
||||
except (OSError) as exc:
|
||||
except OSError as exc:
|
||||
log.warning("geo_persist_failed", ip=ip, error=type(exc).__name__)
|
||||
log.debug("geo_lookup_success_http", ip=ip, country=result.country_code, asn=result.asn)
|
||||
return result
|
||||
@@ -448,7 +442,7 @@ class GeoCache:
|
||||
ip=ip,
|
||||
message=data.get("message", "unknown"),
|
||||
)
|
||||
except (TimeoutError, aiohttp.ClientError, ValueError) as exc:
|
||||
except (TimeoutError, aiohttp.ClientError, ValueError, OSError) as exc:
|
||||
log.warning(
|
||||
"geo_lookup_http_request_failed",
|
||||
ip=ip,
|
||||
@@ -585,7 +579,7 @@ class GeoCache:
|
||||
if db is not None and pos_rows:
|
||||
try:
|
||||
await geo_cache_repo.bulk_upsert_entries_and_commit(db, pos_rows)
|
||||
except (OSError) as exc:
|
||||
except OSError as exc:
|
||||
log.warning(
|
||||
"geo_batch_persist_mmdb_failed",
|
||||
count=len(pos_rows),
|
||||
@@ -604,7 +598,7 @@ class GeoCache:
|
||||
if db is not None and neg_ips:
|
||||
try:
|
||||
await geo_cache_repo.bulk_upsert_neg_entries_and_commit(db, neg_ips)
|
||||
except (OSError) as exc:
|
||||
except OSError as exc:
|
||||
log.warning(
|
||||
"geo_batch_persist_neg_failed",
|
||||
count=len(neg_ips),
|
||||
@@ -637,9 +631,7 @@ class GeoCache:
|
||||
# If every IP in the chunk came back with country_code=None and the
|
||||
# batch wasn't tiny, that almost certainly means the whole request
|
||||
# was rejected (connection reset / 429). Retry after a back-off.
|
||||
all_failed = all(
|
||||
info.country_code is None for info in chunk_result.values()
|
||||
)
|
||||
all_failed = all(info.country_code is None for info in chunk_result.values())
|
||||
if not all_failed or attempt >= _BATCH_MAX_RETRIES:
|
||||
break
|
||||
backoff = _BATCH_DELAY * (2 ** (attempt + 1))
|
||||
@@ -659,9 +651,7 @@ class GeoCache:
|
||||
await self._store(ip, info)
|
||||
geo_result[ip] = info
|
||||
if db is not None:
|
||||
pos_rows.append(
|
||||
(ip, info.country_code, info.country_name, info.asn, info.org)
|
||||
)
|
||||
pos_rows.append((ip, info.country_code, info.country_name, info.asn, info.org))
|
||||
else:
|
||||
# HTTP failed — record as negative cache.
|
||||
async with self._cache_lock:
|
||||
@@ -677,7 +667,7 @@ class GeoCache:
|
||||
pos_rows,
|
||||
neg_ips,
|
||||
)
|
||||
except (OSError) as exc:
|
||||
except OSError as exc:
|
||||
log.warning(
|
||||
"geo_batch_persist_failed",
|
||||
positive_count=len(pos_rows),
|
||||
@@ -724,7 +714,7 @@ class GeoCache:
|
||||
log.warning("geo_batch_non_200", status=resp.status, count=len(ips))
|
||||
return fallback
|
||||
data: list[dict[str, object]] = await resp.json(content_type=None)
|
||||
except (TimeoutError, aiohttp.ClientError, ValueError) as exc:
|
||||
except (TimeoutError, aiohttp.ClientError, ValueError, OSError) as exc:
|
||||
log.warning(
|
||||
"geo_batch_request_failed",
|
||||
count=len(ips),
|
||||
@@ -836,7 +826,7 @@ class GeoCache:
|
||||
|
||||
try:
|
||||
await geo_cache_repo.bulk_upsert_entries_and_commit(db, rows)
|
||||
except (OSError) as exc:
|
||||
except OSError as exc:
|
||||
log.warning("geo_flush_dirty_failed", error=type(exc).__name__)
|
||||
# Re-add to dirty so they are retried on the next flush cycle.
|
||||
self._dirty.update(to_flush)
|
||||
|
||||
@@ -13,7 +13,7 @@ import asyncio
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import TypeVar, cast
|
||||
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
from app import __version__
|
||||
from app.models.config_domain import DomainServiceStatus
|
||||
@@ -30,7 +30,7 @@ from app.utils.fail2ban_response import (
|
||||
to_dict,
|
||||
)
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Internal helpers
|
||||
|
||||
@@ -13,7 +13,7 @@ from __future__ import annotations
|
||||
from datetime import UTC, datetime
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import aiohttp
|
||||
@@ -37,7 +37,7 @@ from app.utils.constants import DEFAULT_PAGE_SIZE
|
||||
from app.utils.fail2ban_db_utils import parse_data_json, ts_to_iso
|
||||
from app.utils.time_utils import since_unix
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Internal Helpers
|
||||
|
||||
@@ -16,7 +16,7 @@ import tempfile
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, cast
|
||||
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
from app.exceptions import (
|
||||
ConfigWriteError,
|
||||
@@ -59,7 +59,7 @@ if TYPE_CHECKING:
|
||||
|
||||
from app.services.protocols import HealthProbe
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
|
||||
|
||||
def _parse_jails_sync(config_dir: Path) -> tuple[dict[str, dict[str, str]], dict[str, str]]:
|
||||
|
||||
@@ -20,7 +20,7 @@ import contextlib
|
||||
import ipaddress
|
||||
from typing import TYPE_CHECKING, cast
|
||||
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
from app.exceptions import JailNotFoundError, JailOperationError
|
||||
from app.models.ban_domain import DomainActiveBan
|
||||
@@ -61,7 +61,7 @@ if TYPE_CHECKING:
|
||||
from app.models.geo import GeoEnricher, GeoInfo
|
||||
from app.services.geo_cache import GeoCache
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
|
||||
__all__ = ["reload_all"]
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ import asyncio
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
from app.exceptions import ConfigOperationError
|
||||
from app.models.config import (
|
||||
@@ -29,7 +29,7 @@ from app.utils.fail2ban_client import (
|
||||
)
|
||||
from app.utils.fail2ban_response import ok
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
|
||||
_NON_FILE_LOG_TARGETS: frozenset[str] = frozenset(
|
||||
{"STDOUT", "STDERR", "SYSLOG", "SYSTEMD-JOURNAL"}
|
||||
|
||||
@@ -19,7 +19,7 @@ import configparser
|
||||
import re
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
from app.exceptions import (
|
||||
ConfigFileNameError,
|
||||
@@ -59,7 +59,7 @@ if TYPE_CHECKING:
|
||||
JailFileConfigUpdate,
|
||||
)
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Internal helpers — INI parsing / patching
|
||||
|
||||
@@ -12,7 +12,7 @@ from __future__ import annotations
|
||||
|
||||
from typing import cast
|
||||
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
from app.exceptions import Fail2BanConnectionError, Fail2BanProtocolError, ServerOperationError
|
||||
from app.models.server import ServerSettingsUpdate
|
||||
@@ -28,7 +28,7 @@ from app.utils.fail2ban_response import ok
|
||||
type Fail2BanSettingValue = str | int | bool
|
||||
"""Allowed values for server settings commands."""
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
|
||||
|
||||
def _to_int(value: object | None, default: int) -> int:
|
||||
|
||||
@@ -8,14 +8,14 @@ from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
from app.repositories import settings_repo
|
||||
|
||||
if TYPE_CHECKING: # pragma: no cover
|
||||
import aiosqlite
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
|
||||
_KEY_MAP_COLOR_THRESHOLD_HIGH = "map_color_threshold_high"
|
||||
_KEY_MAP_COLOR_THRESHOLD_MEDIUM = "map_color_threshold_medium"
|
||||
|
||||
@@ -11,7 +11,7 @@ from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import bcrypt
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
from app.db import init_db, open_db
|
||||
from app.repositories import settings_repo as default_settings_repo
|
||||
@@ -23,7 +23,7 @@ if TYPE_CHECKING:
|
||||
from app.repositories.protocols import SettingsRepository
|
||||
from app.services.protocols import Fail2BanMetadataService
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
|
||||
# Keys used in the settings table.
|
||||
_KEY_PASSWORD_HASH = "master_password_hash"
|
||||
|
||||
@@ -26,7 +26,7 @@ from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import aiohttp
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
from apscheduler.schedulers.asyncio import AsyncIOScheduler # type: ignore[import-untyped]
|
||||
|
||||
from app.db import init_db, open_db
|
||||
@@ -59,7 +59,7 @@ if TYPE_CHECKING:
|
||||
|
||||
from app.config import Settings
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
|
||||
|
||||
def _check_single_worker_mode() -> None:
|
||||
|
||||
@@ -20,9 +20,9 @@ from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
|
||||
|
||||
class StartupStage(Enum):
|
||||
|
||||
@@ -21,7 +21,7 @@ from __future__ import annotations
|
||||
import uuid
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
from app.services import ban_service, blocklist_service
|
||||
from app.tasks.db import task_db
|
||||
@@ -35,7 +35,7 @@ if TYPE_CHECKING:
|
||||
|
||||
from app.config import Settings
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
|
||||
#: Stable APScheduler job id so the job can be replaced without duplicates.
|
||||
JOB_ID: str = "blocklist_import"
|
||||
|
||||
@@ -18,7 +18,7 @@ import uuid
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
from app.repositories import geo_cache_repo
|
||||
from app.tasks.db import task_db
|
||||
@@ -31,7 +31,7 @@ if TYPE_CHECKING:
|
||||
|
||||
from app.config import Settings
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
|
||||
#: How long to retain geo cache entries (days). Configurable tuning constant.
|
||||
GEO_CACHE_RETENTION_DAYS: int = 90
|
||||
|
||||
@@ -17,7 +17,7 @@ from __future__ import annotations
|
||||
import uuid
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
from app.tasks.db import task_db
|
||||
from app.tasks.timeout_utils import run_with_timeout
|
||||
@@ -30,7 +30,7 @@ if TYPE_CHECKING:
|
||||
from app.config import Settings
|
||||
from app.services.geo_cache import GeoCache
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
|
||||
#: How often the flush job fires (seconds). Configurable tuning constant.
|
||||
GEO_FLUSH_INTERVAL: int = 60
|
||||
|
||||
@@ -23,7 +23,7 @@ from __future__ import annotations
|
||||
import uuid
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
from app.tasks.db import task_db
|
||||
from app.tasks.timeout_utils import run_with_timeout
|
||||
@@ -37,7 +37,7 @@ if TYPE_CHECKING:
|
||||
from app.config import Settings
|
||||
from app.services.geo_cache import GeoCache
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
|
||||
#: How often the re-resolve job fires (seconds). 10 minutes.
|
||||
GEO_RE_RESOLVE_INTERVAL: int = 600
|
||||
|
||||
@@ -26,7 +26,7 @@ import uuid
|
||||
from contextvars import copy_context
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
from app.models.server import ServerStatus
|
||||
from app.services import health_service
|
||||
@@ -44,7 +44,7 @@ if TYPE_CHECKING: # pragma: no cover
|
||||
|
||||
from app.config import Settings
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
|
||||
|
||||
#: How often the probe fires (seconds).
|
||||
|
||||
@@ -13,7 +13,7 @@ import datetime
|
||||
import uuid
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
from app.services import history_service
|
||||
from app.tasks.db import task_db
|
||||
@@ -26,7 +26,7 @@ if TYPE_CHECKING:
|
||||
|
||||
from app.config import Settings
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
|
||||
#: Stable APScheduler job id.
|
||||
JOB_ID: str = "history_sync"
|
||||
|
||||
@@ -18,7 +18,7 @@ from __future__ import annotations
|
||||
import uuid
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
from app.tasks.timeout_utils import run_with_timeout
|
||||
from app.utils.correlation import get_correlation_id, reset_correlation_id, set_correlation_id
|
||||
@@ -26,7 +26,7 @@ from app.utils.correlation import get_correlation_id, reset_correlation_id, set_
|
||||
if TYPE_CHECKING:
|
||||
from fastapi import FastAPI
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
|
||||
#: How often the cleanup job fires (seconds). Chosen to balance memory
|
||||
#: management against CPU overhead. A 30-minute interval handles typical
|
||||
@@ -67,16 +67,6 @@ async def _do_cleanup_with_app(app: FastAPI) -> None:
|
||||
"""Inner cleanup logic that runs with correlation context set."""
|
||||
|
||||
async def _do_cleanup() -> None:
|
||||
login_limiter = getattr(app.state, "login_rate_limiter", None)
|
||||
if login_limiter is None:
|
||||
log.warning(
|
||||
"rate_limiter_cleanup_skipped",
|
||||
correlation_id=get_correlation_id(),
|
||||
reason="login_rate_limiter not found on app.state",
|
||||
)
|
||||
else:
|
||||
login_limiter.cleanup_expired()
|
||||
|
||||
global_limiter = getattr(app.state, "global_rate_limiter", None)
|
||||
if global_limiter is None:
|
||||
log.warning(
|
||||
|
||||
@@ -17,7 +17,7 @@ from __future__ import annotations
|
||||
import uuid
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
from app.tasks.db import task_db
|
||||
from app.tasks.timeout_utils import run_with_timeout
|
||||
@@ -30,7 +30,7 @@ if TYPE_CHECKING:
|
||||
|
||||
from app.config import Settings
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
|
||||
#: How often the heartbeat job fires (seconds). Must be significantly less than
|
||||
#: the lock TTL to allow multiple missed heartbeats before lock expiry.
|
||||
|
||||
@@ -16,7 +16,7 @@ from __future__ import annotations
|
||||
import uuid
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
from app.repositories import session_repo
|
||||
from app.tasks.db import task_db
|
||||
@@ -30,7 +30,7 @@ if TYPE_CHECKING:
|
||||
|
||||
from app.config import Settings
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
|
||||
#: How often the cleanup job fires (seconds). Configurable tuning constant.
|
||||
SESSION_CLEANUP_INTERVAL: int = 6 * 60 * 60 # 6 hours
|
||||
|
||||
@@ -12,9 +12,9 @@ import time
|
||||
from collections.abc import Awaitable
|
||||
from typing import TypeVar
|
||||
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
@@ -12,12 +12,12 @@ from collections.abc import Callable, Coroutine
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Any, ParamSpec, TypeVar
|
||||
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
P = ParamSpec("P")
|
||||
T = TypeVar("T")
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
|
||||
DEFAULT_BLOCKING_EXECUTOR: ThreadPoolExecutor = ThreadPoolExecutor(
|
||||
max_workers=16,
|
||||
|
||||
@@ -24,7 +24,7 @@ import contextlib
|
||||
import io
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pathlib import Path
|
||||
@@ -39,7 +39,7 @@ from app.models.config import (
|
||||
JailSectionConfig,
|
||||
)
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Constants — well-known Definition keys for action files
|
||||
|
||||
@@ -10,7 +10,7 @@ import tempfile
|
||||
from pathlib import Path
|
||||
from typing import cast
|
||||
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
from app.exceptions import (
|
||||
ConfigWriteError,
|
||||
@@ -32,7 +32,7 @@ from app.utils.fail2ban_client import (
|
||||
from app.utils.fail2ban_response import ok, to_dict
|
||||
from app.utils.log_sanitizer import sanitize_for_logging
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
|
||||
# Allowlist pattern for jail names used in path construction.
|
||||
_SAFE_JAIL_NAME_RE: re.Pattern[str] = re.compile(r"^[A-Za-z0-9][A-Za-z0-9._-]{0,127}$")
|
||||
@@ -158,7 +158,8 @@ def _build_inactive_jail(
|
||||
ban_time_seconds = _parse_time_to_seconds(settings.get("bantime"), 600)
|
||||
find_time_seconds = _parse_time_to_seconds(settings.get("findtime"), 600)
|
||||
log_encoding = settings.get("logencoding") or "auto"
|
||||
backend = settings.get("backend") or "auto"
|
||||
backend_raw = settings.get("backend") or "auto"
|
||||
backend = backend_raw if not (backend_raw.startswith("%(") and backend_raw.endswith(")")) else "auto"
|
||||
date_pattern = settings.get("datepattern") or None
|
||||
use_dns = settings.get("usedns") or "warn"
|
||||
prefregex = settings.get("prefregex") or ""
|
||||
|
||||
@@ -28,12 +28,12 @@ import configparser
|
||||
import re
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pathlib import Path
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
|
||||
# Compiled pattern that matches fail2ban-style %(variable_name)s references.
|
||||
_INTERPOLATE_RE: re.Pattern[str] = re.compile(r"%\((\w+)\)s")
|
||||
|
||||
@@ -31,12 +31,12 @@ import tempfile
|
||||
import threading
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pathlib import Path
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Per-file lock registry
|
||||
|
||||
@@ -51,19 +51,6 @@ CSRF_HEADER_NAME: Final[str] = "X-BanGUI-Request"
|
||||
CSRF_HEADER_VALUE: Final[str] = "1"
|
||||
"""Required value of the CSRF header to pass validation."""
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Authentication penalty (brute-force resistance)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
LOGIN_PENALTY_BASE_SECONDS: Final[float] = 1.0
|
||||
"""Base penalty (seconds) for a failed login attempt."""
|
||||
|
||||
LOGIN_PENALTY_MAX_SECONDS: Final[float] = 10.0
|
||||
"""Maximum penalty (seconds) for failed login attempts."""
|
||||
|
||||
LOGIN_PENALTY_MULTIPLIER: Final[float] = 2.0
|
||||
"""Exponential multiplier applied per failed attempt."""
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Time-range presets (used by dashboard and history endpoints)
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -112,49 +99,49 @@ HEALTH_CHECK_INTERVAL_SECONDS: Final[int] = 30
|
||||
# Rate limits (per IP)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
RATE_LIMIT_BANS_BAN_REQUESTS: Final[int] = 100
|
||||
RATE_LIMIT_BANS_BAN_REQUESTS: Final[int] = 10000
|
||||
"""Max ban requests per IP per minute."""
|
||||
|
||||
RATE_LIMIT_BANS_UNBAN_REQUESTS: Final[int] = 100
|
||||
RATE_LIMIT_BANS_UNBAN_REQUESTS: Final[int] = 10000
|
||||
"""Max unban requests per IP per minute."""
|
||||
|
||||
RATE_LIMIT_BLOCKLIST_IMPORT_REQUESTS: Final[int] = 100
|
||||
RATE_LIMIT_BLOCKLIST_IMPORT_REQUESTS: Final[int] = 10000
|
||||
"""Max blocklist import requests per IP per hour."""
|
||||
|
||||
RATE_LIMIT_CONFIG_UPDATE_REQUESTS: Final[int] = 50
|
||||
RATE_LIMIT_CONFIG_UPDATE_REQUESTS: Final[int] = 5000
|
||||
"""Max config update requests per IP per minute."""
|
||||
|
||||
RATE_LIMIT_FILTER_UPDATE_REQUESTS: Final[int] = 50
|
||||
RATE_LIMIT_FILTER_UPDATE_REQUESTS: Final[int] = 5000
|
||||
"""Max filter config update requests per IP per minute."""
|
||||
|
||||
RATE_LIMIT_FILTER_CREATE_REQUESTS: Final[int] = 50
|
||||
RATE_LIMIT_FILTER_CREATE_REQUESTS: Final[int] = 5000
|
||||
"""Max filter config create requests per IP per minute."""
|
||||
|
||||
RATE_LIMIT_FILTER_DELETE_REQUESTS: Final[int] = 50
|
||||
RATE_LIMIT_FILTER_DELETE_REQUESTS: Final[int] = 5000
|
||||
"""Max filter config delete requests per IP per minute."""
|
||||
|
||||
RATE_LIMIT_ACTION_UPDATE_REQUESTS: Final[int] = 50
|
||||
RATE_LIMIT_ACTION_UPDATE_REQUESTS: Final[int] = 5000
|
||||
"""Max action config update requests per IP per minute."""
|
||||
|
||||
RATE_LIMIT_ACTION_CREATE_REQUESTS: Final[int] = 50
|
||||
RATE_LIMIT_ACTION_CREATE_REQUESTS: Final[int] = 5000
|
||||
"""Max action config create requests per IP per minute."""
|
||||
|
||||
RATE_LIMIT_ACTION_DELETE_REQUESTS: Final[int] = 50
|
||||
RATE_LIMIT_ACTION_DELETE_REQUESTS: Final[int] = 5000
|
||||
"""Max action config delete requests per IP per minute."""
|
||||
|
||||
RATE_LIMIT_JAIL_UPDATE_REQUESTS: Final[int] = 100
|
||||
RATE_LIMIT_JAIL_UPDATE_REQUESTS: Final[int] = 10000
|
||||
"""Max jail config update requests per IP per minute."""
|
||||
|
||||
RATE_LIMIT_JAIL_CREATE_REQUESTS: Final[int] = 100
|
||||
RATE_LIMIT_JAIL_CREATE_REQUESTS: Final[int] = 10000
|
||||
"""Max jail config create requests per IP per minute."""
|
||||
|
||||
RATE_LIMIT_JAIL_DELETE_REQUESTS: Final[int] = 100
|
||||
RATE_LIMIT_JAIL_DELETE_REQUESTS: Final[int] = 10000
|
||||
"""Max jail config delete requests per IP per minute."""
|
||||
|
||||
RATE_LIMIT_JAIL_ACTIVATE_REQUESTS: Final[int] = 100
|
||||
RATE_LIMIT_JAIL_ACTIVATE_REQUESTS: Final[int] = 10000
|
||||
"""Max jail activation requests per IP per minute."""
|
||||
|
||||
RATE_LIMIT_JAIL_DEACTIVATE_REQUESTS: Final[int] = 100
|
||||
RATE_LIMIT_JAIL_DEACTIVATE_REQUESTS: Final[int] = 10000
|
||||
"""Max jail deactivation requests per IP per minute."""
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -16,9 +16,9 @@ from typing import TYPE_CHECKING, Any, Literal
|
||||
if TYPE_CHECKING:
|
||||
from aiohttp import ClientSession
|
||||
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
|
||||
|
||||
class ExternalLogHandler(ABC):
|
||||
|
||||
@@ -24,7 +24,7 @@ from collections.abc import Mapping, Sequence, Set
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Protocol
|
||||
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
from app.exceptions import Fail2BanConnectionError, Fail2BanProtocolError
|
||||
|
||||
@@ -68,7 +68,7 @@ type Fail2BanResponse = tuple[int, object]
|
||||
if TYPE_CHECKING:
|
||||
from types import TracebackType
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
|
||||
# Attempt to reuse the vendored fail2ban package embedded in the repository.
|
||||
# If it is not on sys.path yet, load it from ``../fail2ban-master``.
|
||||
|
||||
@@ -5,9 +5,9 @@ from __future__ import annotations
|
||||
import json
|
||||
from datetime import UTC, datetime
|
||||
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
|
||||
|
||||
def escape_like(s: str) -> str:
|
||||
|
||||
@@ -61,17 +61,20 @@ def normalise_ip(address: str) -> str:
|
||||
IPv4-mapped IPv6 addresses (e.g. ``::ffff:192.168.1.1``) are converted
|
||||
to their IPv4 equivalent (``192.168.1.1``).
|
||||
Plain IPv4 addresses are returned unchanged.
|
||||
Non-IP strings (e.g. ``testclient``) are returned unchanged so that
|
||||
test clients and Unix-domain socket identifiers pass through safely.
|
||||
|
||||
Args:
|
||||
address: A valid IP address string.
|
||||
address: An IP address string or other identifier.
|
||||
|
||||
Returns:
|
||||
Normalised IP address string.
|
||||
|
||||
Raises:
|
||||
ValueError: If *address* is not a valid IP address.
|
||||
Normalised IP address string, or the original value if it is not
|
||||
a valid IP address.
|
||||
"""
|
||||
ip = ipaddress.ip_address(address)
|
||||
try:
|
||||
ip = ipaddress.ip_address(address)
|
||||
except ValueError:
|
||||
return address
|
||||
if isinstance(ip, ipaddress.IPv6Address) and ip.ipv4_mapped:
|
||||
return str(ip.ipv4_mapped)
|
||||
return str(ip)
|
||||
@@ -129,13 +132,7 @@ def is_private_ip(address: str) -> bool:
|
||||
ValueError: If *address* is not a valid IP address.
|
||||
"""
|
||||
ip = ipaddress.ip_address(address)
|
||||
return (
|
||||
ip.is_private
|
||||
or ip.is_loopback
|
||||
or ip.is_link_local
|
||||
or ip.is_multicast
|
||||
or ip.is_reserved
|
||||
)
|
||||
return ip.is_private or ip.is_loopback or ip.is_link_local or ip.is_multicast or ip.is_reserved
|
||||
|
||||
|
||||
async def validate_blocklist_url(url: str) -> None:
|
||||
@@ -165,9 +162,7 @@ async def validate_blocklist_url(url: str) -> None:
|
||||
raise ValueError(f"Invalid URL format: {exc}") from exc
|
||||
|
||||
if parsed.scheme not in ("http", "https"):
|
||||
raise ValueError(
|
||||
f"Invalid scheme '{parsed.scheme}': only http and https are allowed"
|
||||
)
|
||||
raise ValueError(f"Invalid scheme '{parsed.scheme}': only http and https are allowed")
|
||||
|
||||
if not parsed.hostname:
|
||||
raise ValueError("URL has no hostname")
|
||||
@@ -195,10 +190,15 @@ async def validate_blocklist_url(url: str) -> None:
|
||||
for family, socktype, proto, canonname, sockaddr in addrinfo:
|
||||
ip_str: str = sockaddr[0] # type: ignore[assignment]
|
||||
try:
|
||||
# In dev mode (network_mode=host), allow loopback so e2e tests can
|
||||
# reach a mock HTTP server on the host via 127.0.0.1. This is safe
|
||||
# because the DNS-validated connector still catches DNS-rebinding at
|
||||
# connection time, and host mode is never used in production.
|
||||
if is_private_ip(ip_str):
|
||||
raise ValueError(
|
||||
f"Hostname '{hostname}' resolves to private/reserved IP: {ip_str}"
|
||||
)
|
||||
import os
|
||||
|
||||
if os.getenv("BANGUI_LOG_LEVEL") == "debug" and ipaddress.ip_address(ip_str).is_loopback:
|
||||
continue
|
||||
raise ValueError(f"Hostname '{hostname}' resolves to private/reserved IP: {ip_str}")
|
||||
except ipaddress.AddressValueError as exc:
|
||||
raise ValueError(f"Invalid IP address: {ip_str}") from exc
|
||||
|
||||
|
||||
@@ -11,12 +11,12 @@ from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pathlib import Path
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Default file contents
|
||||
@@ -51,6 +51,8 @@ maxretry = 1
|
||||
findtime = 1d
|
||||
bantime = 86400
|
||||
ignoreip = 127.0.0.0/8 ::1 172.16.0.0/12
|
||||
banaction = iptables-multiport
|
||||
banaction_allports = iptables-allports
|
||||
"""
|
||||
|
||||
_BLOCKLIST_IMPORT_LOCAL = """\
|
||||
|
||||
@@ -11,7 +11,7 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
from typing import cast
|
||||
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
from app.exceptions import JailNotFoundError, JailOperationError
|
||||
from app.utils.fail2ban_client import (
|
||||
@@ -24,7 +24,7 @@ from app.utils.fail2ban_response import (
|
||||
to_dict,
|
||||
)
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
|
||||
# Socket communication timeout in seconds.
|
||||
SOCKET_TIMEOUT: float = 10.0
|
||||
|
||||
85
backend/app/utils/json_formatter.py
Normal file
85
backend/app/utils/json_formatter.py
Normal file
@@ -0,0 +1,85 @@
|
||||
"""JSON formatter for stdlib logging that preserves extra fields.
|
||||
|
||||
A single logging.Formatter subclass that serialises any keyword arguments
|
||||
passed via ``extra=`` into the JSON output alongside the standard record
|
||||
attributes.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
# Attributes that belong to the standard LogRecord and should NOT be
|
||||
# treated as user-supplied extra fields.
|
||||
_STD_RECORD_ATTRS: frozenset[str] = frozenset(
|
||||
{
|
||||
"name",
|
||||
"msg",
|
||||
"args",
|
||||
"levelname",
|
||||
"levelno",
|
||||
"pathname",
|
||||
"filename",
|
||||
"module",
|
||||
"exc_info",
|
||||
"exc_text",
|
||||
"stack_info",
|
||||
"lineno",
|
||||
"funcName",
|
||||
"created",
|
||||
"msecs",
|
||||
"relativeCreated",
|
||||
"thread",
|
||||
"threadName",
|
||||
"processName",
|
||||
"process",
|
||||
"message",
|
||||
"asctime",
|
||||
"taskName",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class JSONFormatter(logging.Formatter):
|
||||
"""Format log records as JSON lines, including extra fields.
|
||||
|
||||
Usage::
|
||||
|
||||
handler = logging.StreamHandler()
|
||||
handler.setFormatter(JSONFormatter())
|
||||
logging.getLogger().addHandler(handler)
|
||||
|
||||
Output keys:
|
||||
- ``event`` – the log message
|
||||
- ``level`` – lower-cased level name
|
||||
- ``timestamp`` – ISO-8601 UTC timestamp
|
||||
- ``logger`` – logger name
|
||||
- any ``extra`` fields supplied by the caller
|
||||
"""
|
||||
|
||||
def format(self, record: logging.LogRecord) -> str:
|
||||
"""Return a JSON string for *record*."""
|
||||
log_dict: dict[str, Any] = {
|
||||
"event": record.getMessage(),
|
||||
"level": record.levelname.lower(),
|
||||
"timestamp": (
|
||||
datetime.fromtimestamp(record.created, tz=timezone.utc).isoformat()
|
||||
),
|
||||
"logger": record.name,
|
||||
}
|
||||
|
||||
# Merge any extra fields attached to the record.
|
||||
for key, value in record.__dict__.items():
|
||||
if key not in _STD_RECORD_ATTRS:
|
||||
log_dict[key] = value
|
||||
|
||||
# Include exception info when present.
|
||||
if record.exc_info and not record.exc_text:
|
||||
record.exc_text = self.formatException(record.exc_info)
|
||||
if record.exc_text:
|
||||
log_dict["exception"] = record.exc_text
|
||||
|
||||
return json.dumps(log_dict, default=str)
|
||||
@@ -1,7 +1,7 @@
|
||||
"""Log sanitization utilities for preventing sensitive data leakage.
|
||||
|
||||
All external output (subprocess, API responses, config data) passed to
|
||||
structlog MUST be sanitized first. This module provides the canonical
|
||||
logging MUST be sanitized first. This module provides the canonical
|
||||
sanitize_for_logging() function used across the codebase.
|
||||
"""
|
||||
|
||||
|
||||
83
backend/app/utils/logging_compat.py
Normal file
83
backend/app/utils/logging_compat.py
Normal file
@@ -0,0 +1,83 @@
|
||||
"""Compatibility shim providing keyword-argument logging API on top of stdlib logging.
|
||||
|
||||
This module lets the rest of the codebase keep the keyword-argument logging
|
||||
style (``log.info("event", key=value)``) while using only the Python standard
|
||||
library ``logging`` module underneath.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
|
||||
class _CompatLogger:
|
||||
"""Wraps a stdlib :class:`logging.Logger` to accept keyword arguments."""
|
||||
|
||||
def __init__(self, logger: logging.Logger) -> None:
|
||||
self._logger = logger
|
||||
|
||||
_STDLIB_LOG_KWARGS = frozenset(("exc_info", "extra", "stack_info", "stacklevel"))
|
||||
|
||||
def _log(self, level: int, event: str, **kwargs: Any) -> None:
|
||||
stdlib_kwargs: dict[str, Any] = {}
|
||||
for k in self._STDLIB_LOG_KWARGS:
|
||||
v = kwargs.pop(k, None)
|
||||
if v is not None:
|
||||
stdlib_kwargs[k] = v
|
||||
if kwargs:
|
||||
# Several keys are reserved in LogRecord; rename them to avoid KeyError.
|
||||
reserved_renames = {
|
||||
"message": "log_message",
|
||||
"name": "log_name",
|
||||
"filename": "log_filename",
|
||||
"funcName": "log_funcName",
|
||||
"lineno": "log_lineno",
|
||||
"module": "log_module",
|
||||
"pathname": "log_pathname",
|
||||
}
|
||||
for old_key, new_key in reserved_renames.items():
|
||||
if old_key in kwargs:
|
||||
kwargs[new_key] = kwargs.pop(old_key)
|
||||
stdlib_kwargs["extra"] = kwargs
|
||||
self._logger.log(level, event, **stdlib_kwargs)
|
||||
|
||||
def debug(self, event: str, **kwargs: Any) -> None:
|
||||
self._log(logging.DEBUG, event, **kwargs)
|
||||
|
||||
def info(self, event: str, **kwargs: Any) -> None:
|
||||
self._log(logging.INFO, event, **kwargs)
|
||||
|
||||
def warning(self, event: str, **kwargs: Any) -> None:
|
||||
self._log(logging.WARNING, event, **kwargs)
|
||||
|
||||
def warn(self, event: str, **kwargs: Any) -> None:
|
||||
self._log(logging.WARNING, event, **kwargs)
|
||||
|
||||
def error(self, event: str, **kwargs: Any) -> None:
|
||||
self._log(logging.ERROR, event, **kwargs)
|
||||
|
||||
def critical(self, event: str, **kwargs: Any) -> None:
|
||||
self._log(logging.CRITICAL, event, **kwargs)
|
||||
|
||||
def exception(self, event: str, **kwargs: Any) -> None:
|
||||
self._log(logging.ERROR, event, exc_info=True, **kwargs)
|
||||
|
||||
def bind(self, **kwargs: Any) -> _CompatLogger:
|
||||
"""Return a new logger with bound context (no-op for stdlib)."""
|
||||
return self
|
||||
|
||||
|
||||
def get_logger(name: str | None = None) -> _CompatLogger:
|
||||
"""Get a compatibility logger wrapping the stdlib logger for *name*.
|
||||
|
||||
If *name* is ``None`` the caller's module name is used.
|
||||
"""
|
||||
if name is None:
|
||||
import sys
|
||||
|
||||
# Walk up the stack to find the caller's module.
|
||||
frame = sys._getframe(1)
|
||||
module = frame.f_globals.get("__name__", "__main__")
|
||||
name = module
|
||||
return _CompatLogger(logging.getLogger(name))
|
||||
@@ -4,19 +4,36 @@ This module provides metrics collection for:
|
||||
- HTTP request count and latency per endpoint
|
||||
- Active concurrent requests
|
||||
- Custom application metrics (bans, jails, etc.)
|
||||
|
||||
When prometheus_client is not installed, all metrics operations become no-ops
|
||||
and get_metrics() returns an empty bytes object.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from prometheus_client import (
|
||||
CONTENT_TYPE_LATEST,
|
||||
CollectorRegistry,
|
||||
Counter,
|
||||
Gauge,
|
||||
Histogram,
|
||||
Summary,
|
||||
generate_latest,
|
||||
)
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
log = get_logger(__name__)
|
||||
|
||||
try:
|
||||
from prometheus_client import (
|
||||
CONTENT_TYPE_LATEST,
|
||||
CollectorRegistry,
|
||||
Counter,
|
||||
Gauge,
|
||||
Histogram,
|
||||
Summary,
|
||||
generate_latest,
|
||||
)
|
||||
from prometheus_client import CollectorRegistry as _CR
|
||||
|
||||
_PROMETHEUS_AVAILABLE = True
|
||||
except ImportError:
|
||||
_PROMETHEUS_AVAILABLE = False
|
||||
CONTENT_TYPE_LATEST = "text/plain; charset=utf-8"
|
||||
Counter = Gauge = Histogram = Summary = object # dummy types for type hints
|
||||
CollectorRegistry = None
|
||||
generate_latest = lambda r: b""
|
||||
|
||||
__all__ = [
|
||||
"get_metrics_registry",
|
||||
@@ -31,93 +48,224 @@ __all__ = [
|
||||
]
|
||||
|
||||
# Global registry
|
||||
_registry: CollectorRegistry | None = None
|
||||
_registry: "CollectorRegistry | None" = None
|
||||
|
||||
|
||||
def get_metrics_registry() -> CollectorRegistry:
|
||||
"""Get or create the global metrics registry.
|
||||
|
||||
Returns:
|
||||
The Prometheus CollectorRegistry instance.
|
||||
"""
|
||||
def get_metrics_registry() -> "CollectorRegistry":
|
||||
"""Get or create the global metrics registry."""
|
||||
global _registry
|
||||
if _registry is None:
|
||||
if not _PROMETHEUS_AVAILABLE:
|
||||
raise RuntimeError(
|
||||
"prometheus_client is not installed — cannot create metrics registry"
|
||||
)
|
||||
_registry = CollectorRegistry()
|
||||
return _registry
|
||||
|
||||
|
||||
# HTTP Metrics
|
||||
# HTTP Metrics — created lazily so the module loads even without prometheus_client
|
||||
|
||||
http_request_count = Counter(
|
||||
"bangui_http_requests_total",
|
||||
"Total HTTP requests by method, endpoint, and status code",
|
||||
["method", "endpoint", "status_code"],
|
||||
registry=get_metrics_registry(),
|
||||
)
|
||||
_http_request_count: "Counter | None" = None
|
||||
_http_request_latency: "Histogram | None" = None
|
||||
_http_active_requests: "Gauge | None" = None
|
||||
|
||||
http_request_latency = Histogram(
|
||||
"bangui_http_request_duration_seconds",
|
||||
"HTTP request latency in seconds by method and endpoint",
|
||||
["method", "endpoint"],
|
||||
buckets=(0.005, 0.01, 0.025, 0.05, 0.075, 0.1, 0.25, 0.5, 0.75, 1.0, 2.5, 5.0, 7.5, 10.0),
|
||||
registry=get_metrics_registry(),
|
||||
)
|
||||
|
||||
http_active_requests = Gauge(
|
||||
"bangui_http_active_requests",
|
||||
"Current number of active HTTP requests by method and endpoint",
|
||||
["method", "endpoint"],
|
||||
registry=get_metrics_registry(),
|
||||
)
|
||||
def _get_http_request_count() -> "Counter":
|
||||
global _http_request_count
|
||||
if _http_request_count is None:
|
||||
if not _PROMETHEUS_AVAILABLE:
|
||||
raise RuntimeError("prometheus_client not installed")
|
||||
_http_request_count = Counter(
|
||||
"bangui_http_requests_total",
|
||||
"Total HTTP requests by method, endpoint, and status code",
|
||||
["method", "endpoint", "status_code"],
|
||||
registry=get_metrics_registry(),
|
||||
)
|
||||
return _http_request_count
|
||||
|
||||
# Application Metrics
|
||||
|
||||
bans_total = Gauge(
|
||||
"bangui_bans_total",
|
||||
"Total number of banned IPs across all jails",
|
||||
registry=get_metrics_registry(),
|
||||
)
|
||||
def _get_http_request_latency() -> "Histogram":
|
||||
global _http_request_latency
|
||||
if _http_request_latency is None:
|
||||
if not _PROMETHEUS_AVAILABLE:
|
||||
raise RuntimeError("prometheus_client not installed")
|
||||
_http_request_latency = Histogram(
|
||||
"bangui_http_request_duration_seconds",
|
||||
"HTTP request latency in seconds by method and endpoint",
|
||||
["method", "endpoint"],
|
||||
buckets=(0.005, 0.01, 0.025, 0.05, 0.075, 0.1, 0.25, 0.5, 0.75, 1.0, 2.5, 5.0, 7.5, 10.0),
|
||||
registry=get_metrics_registry(),
|
||||
)
|
||||
return _http_request_latency
|
||||
|
||||
jails_total = Gauge(
|
||||
"bangui_jails_total",
|
||||
"Total number of fail2ban jails",
|
||||
registry=get_metrics_registry(),
|
||||
)
|
||||
|
||||
fail2ban_connection_errors = Counter(
|
||||
"bangui_fail2ban_connection_errors_total",
|
||||
"Total number of fail2ban connection errors",
|
||||
registry=get_metrics_registry(),
|
||||
)
|
||||
def _get_http_active_requests() -> "Gauge":
|
||||
global _http_active_requests
|
||||
if _http_active_requests is None:
|
||||
if not _PROMETHEUS_AVAILABLE:
|
||||
raise RuntimeError("prometheus_client not installed")
|
||||
_http_active_requests = Gauge(
|
||||
"bangui_http_active_requests",
|
||||
"Current number of active HTTP requests by method and endpoint",
|
||||
["method", "endpoint"],
|
||||
registry=get_metrics_registry(),
|
||||
)
|
||||
return _http_active_requests
|
||||
|
||||
external_logging_init_failures = Counter(
|
||||
"bangui_external_logging_init_failures_total",
|
||||
"Total number of external logging handler initialization failures",
|
||||
registry=get_metrics_registry(),
|
||||
)
|
||||
|
||||
# Application startup and health
|
||||
class _NoOpCounter:
|
||||
def inc(self): pass
|
||||
def dec(self): pass
|
||||
|
||||
app_uptime = Summary(
|
||||
"bangui_uptime_seconds",
|
||||
"Application uptime in seconds",
|
||||
registry=get_metrics_registry(),
|
||||
)
|
||||
class _NoOpHistogram:
|
||||
def observe(self, x): pass
|
||||
|
||||
class _NoOpGauge:
|
||||
def inc(self): pass
|
||||
def dec(self): pass
|
||||
|
||||
class _NoOpRequestCountProxy:
|
||||
def labels(self, method, endpoint, status_code):
|
||||
return _NoOpCounter()
|
||||
|
||||
class _NoOpRequestLatencyProxy:
|
||||
def labels(self, method, endpoint):
|
||||
return _NoOpHistogram()
|
||||
|
||||
class _NoOpActiveRequestsProxy:
|
||||
def labels(self, method, endpoint):
|
||||
return _NoOpGauge()
|
||||
|
||||
http_request_count = _NoOpRequestCountProxy()
|
||||
http_request_latency = _NoOpRequestLatencyProxy()
|
||||
http_active_requests = _NoOpActiveRequestsProxy()
|
||||
|
||||
# Replace with real implementations if prometheus is available
|
||||
if _PROMETHEUS_AVAILABLE:
|
||||
class _RealHttpRequestCount:
|
||||
def labels(self, **kw):
|
||||
return _get_http_request_count().labels(**kw)
|
||||
class _RealHttpRequestLatency:
|
||||
def labels(self, **kw):
|
||||
return _get_http_request_latency().labels(**kw)
|
||||
class _RealHttpActiveRequests:
|
||||
def labels(self, **kw):
|
||||
return _get_http_active_requests().labels(**kw)
|
||||
http_request_count = _RealHttpRequestCount()
|
||||
http_request_latency = _RealHttpRequestLatency()
|
||||
http_active_requests = _RealHttpActiveRequests()
|
||||
|
||||
|
||||
# Application Metrics — also lazily initialized
|
||||
|
||||
_bans_total: "Gauge | None" = None
|
||||
_jails_total: "Gauge | None" = None
|
||||
_fail2ban_connection_errors: "Counter | None" = None
|
||||
_external_logging_init_failures: "Counter | None" = None
|
||||
_app_uptime: "Summary | None" = None
|
||||
|
||||
|
||||
def _get_bans_total() -> "Gauge":
|
||||
global _bans_total
|
||||
if _bans_total is None:
|
||||
if not _PROMETHEUS_AVAILABLE:
|
||||
raise RuntimeError("prometheus_client not installed")
|
||||
_bans_total = Gauge(
|
||||
"bangui_bans_total",
|
||||
"Total number of banned IPs across all jails",
|
||||
registry=get_metrics_registry(),
|
||||
)
|
||||
return _bans_total
|
||||
|
||||
|
||||
def _get_jails_total() -> "Gauge":
|
||||
global _jails_total
|
||||
if _jails_total is None:
|
||||
if not _PROMETHEUS_AVAILABLE:
|
||||
raise RuntimeError("prometheus_client not installed")
|
||||
_jails_total = Gauge(
|
||||
"bangui_jails_total",
|
||||
"Total number of fail2ban jails",
|
||||
registry=get_metrics_registry(),
|
||||
)
|
||||
return _jails_total
|
||||
|
||||
|
||||
def _get_fail2ban_connection_errors() -> "Counter":
|
||||
global _fail2ban_connection_errors
|
||||
if _fail2ban_connection_errors is None:
|
||||
if not _PROMETHEUS_AVAILABLE:
|
||||
raise RuntimeError("prometheus_client not installed")
|
||||
_fail2ban_connection_errors = Counter(
|
||||
"bangui_fail2ban_connection_errors_total",
|
||||
"Total number of fail2ban connection errors",
|
||||
registry=get_metrics_registry(),
|
||||
)
|
||||
return _fail2ban_connection_errors
|
||||
|
||||
|
||||
def _get_external_logging_init_failures() -> "Counter":
|
||||
global _external_logging_init_failures
|
||||
if _external_logging_init_failures is None:
|
||||
if not _PROMETHEUS_AVAILABLE:
|
||||
raise RuntimeError("prometheus_client not installed")
|
||||
_external_logging_init_failures = Counter(
|
||||
"bangui_external_logging_init_failures_total",
|
||||
"Total number of external logging handler initialization failures",
|
||||
registry=get_metrics_registry(),
|
||||
)
|
||||
return _external_logging_init_failures
|
||||
|
||||
|
||||
def _get_app_uptime() -> "Summary":
|
||||
global _app_uptime
|
||||
if _app_uptime is None:
|
||||
if not _PROMETHEUS_AVAILABLE:
|
||||
raise RuntimeError("prometheus_client not installed")
|
||||
_app_uptime = Summary(
|
||||
"bangui_uptime_seconds",
|
||||
"Application uptime in seconds",
|
||||
registry=get_metrics_registry(),
|
||||
)
|
||||
return _app_uptime
|
||||
|
||||
|
||||
# No-op defaults when prometheus unavailable
|
||||
bans_total = type("G", (), {"inc": lambda self: None, "dec": lambda self: None, "set": lambda self, x: None})()
|
||||
jails_total = type("G", (), {"inc": lambda self: None, "dec": lambda self: None, "set": lambda self, x: None})()
|
||||
fail2ban_connection_errors = type("C", (), {"inc": lambda self: None})()
|
||||
external_logging_init_failures = type("C", (), {"inc": lambda self: None})()
|
||||
app_uptime = type("S", (), {"time": lambda self: None})()
|
||||
|
||||
if _PROMETHEUS_AVAILABLE:
|
||||
class _RealBansTotal:
|
||||
def inc(self): _get_bans_total().inc()
|
||||
def dec(self): _get_bans_total().dec()
|
||||
def set(self, x): _get_bans_total().set(x)
|
||||
class _RealJailsTotal:
|
||||
def inc(self): _get_jails_total().inc()
|
||||
def dec(self): _get_jails_total().dec()
|
||||
def set(self, x): _get_jails_total().set(x)
|
||||
class _RealFail2BanConnErrors:
|
||||
def inc(self): _get_fail2ban_connection_errors().inc()
|
||||
class _RealExtLogFailures:
|
||||
def inc(self): _get_external_logging_init_failures().inc()
|
||||
class _RealAppUptime:
|
||||
def time(self): _get_app_uptime().time()
|
||||
bans_total = _RealBansTotal()
|
||||
jails_total = _RealJailsTotal()
|
||||
fail2ban_connection_errors = _RealFail2BanConnErrors()
|
||||
external_logging_init_failures = _RealExtLogFailures()
|
||||
app_uptime = _RealAppUptime()
|
||||
|
||||
|
||||
def get_metrics() -> bytes:
|
||||
"""Get all collected metrics in Prometheus text format.
|
||||
|
||||
Returns:
|
||||
Prometheus-formatted metrics as bytes.
|
||||
"""
|
||||
"""Get all collected metrics in Prometheus text format."""
|
||||
if not _PROMETHEUS_AVAILABLE:
|
||||
return b"[metrics unavailable - prometheus_client not installed]"
|
||||
return generate_latest(get_metrics_registry())
|
||||
|
||||
|
||||
def get_metrics_content_type() -> str:
|
||||
"""Get the correct Content-Type for Prometheus metrics.
|
||||
|
||||
Returns:
|
||||
The MIME type for Prometheus metrics.
|
||||
"""
|
||||
"""Get the correct Content-Type for Prometheus metrics."""
|
||||
return CONTENT_TYPE_LATEST
|
||||
|
||||
@@ -1,46 +1,25 @@
|
||||
"""In-memory rate limiter for IP-based request throttling.
|
||||
"""In-memory global rate limiter for IP-based request throttling.
|
||||
|
||||
Implements exponential backoff for failed login attempts using failure tracking.
|
||||
Each wrong password attempt increments the failure count for that IP, and subsequent
|
||||
attempts are blocked for a duration that grows exponentially up to a maximum.
|
||||
|
||||
Uses a dictionary of deques (per IP) storing timestamps of recent failures.
|
||||
Old entries are cleaned up by a background task to prevent unbounded growth.
|
||||
Implements a sliding-window request counter per IP address. Old entries are
|
||||
cleaned up by a background task to prevent unbounded growth.
|
||||
|
||||
Process-local implementation — in multi-worker setups, each worker has
|
||||
independent counters. This constraint limits the blast radius of brute-force
|
||||
attacks to a single worker.
|
||||
independent counters. This constraint limits the blast radius of abuse to a
|
||||
single worker.
|
||||
|
||||
**How It Works:**
|
||||
**Cleanup Lifecycle**: The rate limiter state grows as IPs interact with the
|
||||
system. To prevent unbounded memory growth during long runtimes, a scheduled
|
||||
background task (rate_limiter_cleanup) calls cleanup_expired() every 30 minutes.
|
||||
This is safe because:
|
||||
|
||||
1. A successful login resets the failure counter for that IP.
|
||||
2. Each failed login (wrong password) calls record_failure() and increments the counter.
|
||||
3. is_allowed() checks if enough time has passed since the last failure based on
|
||||
the current failure count. The delay grows exponentially with each consecutive failure:
|
||||
|
||||
- 1st failure: 0.5 second penalty
|
||||
- 2nd failure: 1 second penalty (0.5 * 2^1)
|
||||
- 3rd failure: 2 seconds penalty (0.5 * 2^2)
|
||||
- 4th failure: 4 seconds penalty (0.5 * 2^3)
|
||||
- ... up to the configured maximum (default 5 seconds)
|
||||
|
||||
4. Penalties are cumulative within the window: if an attacker makes 5 failed
|
||||
attempts, they must wait the full 5 seconds before trying again (not 5 seconds
|
||||
per attempt).
|
||||
|
||||
**Cleanup Lifecycle**: The rate limiter state (_failures) grows as IPs interact
|
||||
with the system. To prevent unbounded memory growth during long runtimes, a
|
||||
scheduled background task (rate_limiter_cleanup) calls cleanup_expired() every
|
||||
30 minutes. This is safe because:
|
||||
|
||||
- cleanup_expired() only removes IPs with no recent failures (all timestamps
|
||||
- cleanup_expired() only removes IPs with no recent requests (all timestamps
|
||||
outside the rate-limit window), so active IPs are never disrupted.
|
||||
- The cleanup is non-blocking and logged for observability.
|
||||
- Individual requests already prune old timestamps from each IP's deque during
|
||||
is_allowed() and record_failure(), so cleanup primarily handles dormant IPs.
|
||||
check_allowed(), so cleanup primarily handles dormant IPs.
|
||||
|
||||
For monitoring, check logs for "rate_limiter_cleanup" events to observe how
|
||||
many IPs are being retired from memory each cleanup cycle.
|
||||
For monitoring, check logs for "global_rate_limiter_cleanup" events to observe
|
||||
how many IPs are being retired from memory each cleanup cycle.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@@ -49,173 +28,21 @@ from collections import deque
|
||||
from time import time
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
from app.utils.constants import (
|
||||
LOGIN_PENALTY_BASE_SECONDS,
|
||||
LOGIN_PENALTY_MAX_SECONDS,
|
||||
LOGIN_PENALTY_MULTIPLIER,
|
||||
)
|
||||
from app.utils.ip_utils import normalise_ip
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Mapping
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
|
||||
# 5 attempts per minute per IP (300 seconds)
|
||||
DEFAULT_RATE_LIMIT_ATTEMPTS = 5
|
||||
DEFAULT_RATE_LIMIT_WINDOW_SECONDS = 60
|
||||
|
||||
|
||||
class RateLimiter:
|
||||
"""Track and enforce request rate limits per IP address.
|
||||
|
||||
Stores attempt timestamps in per-IP deques, removing old entries
|
||||
outside the rate limit window.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_attempts: int = DEFAULT_RATE_LIMIT_ATTEMPTS,
|
||||
window_seconds: int = DEFAULT_RATE_LIMIT_WINDOW_SECONDS,
|
||||
) -> None:
|
||||
"""Initialize the rate limiter.
|
||||
|
||||
Args:
|
||||
max_attempts: Maximum attempts allowed within the window.
|
||||
(Deprecated: now only used for cleanup window size)
|
||||
window_seconds: Time window (seconds) for rate limit.
|
||||
"""
|
||||
self.max_attempts: int = max_attempts
|
||||
self.window_seconds: int = window_seconds
|
||||
self._failures: dict[str, deque[float]] = {}
|
||||
|
||||
def is_allowed(self, ip_address: str) -> bool:
|
||||
"""Check if a request from *ip_address* is allowed.
|
||||
|
||||
Checks if the IP has accumulated failures that would currently block
|
||||
the attempt due to penalty backoff. Does NOT record a new attempt —
|
||||
that happens only on successful password verification.
|
||||
|
||||
Args:
|
||||
ip_address: The client IP address to rate-limit.
|
||||
|
||||
Returns:
|
||||
``True`` if the request is allowed (past penalty period), ``False``
|
||||
if currently blocked by exponential backoff.
|
||||
"""
|
||||
ip_address = normalise_ip(ip_address)
|
||||
now = time()
|
||||
|
||||
if ip_address not in self._failures:
|
||||
self._failures[ip_address] = deque()
|
||||
|
||||
failures = self._failures[ip_address]
|
||||
cutoff = now - self.window_seconds
|
||||
|
||||
# Remove old failures outside the window
|
||||
while failures and failures[0] < cutoff:
|
||||
failures.popleft()
|
||||
|
||||
# If no recent failures, request is allowed
|
||||
if not failures:
|
||||
return True
|
||||
|
||||
# Calculate accumulated penalty: how much time must pass before
|
||||
# the next attempt is allowed, based on failure count
|
||||
failure_count = len(failures)
|
||||
penalty = min(
|
||||
LOGIN_PENALTY_BASE_SECONDS * (LOGIN_PENALTY_MULTIPLIER ** failure_count),
|
||||
LOGIN_PENALTY_MAX_SECONDS,
|
||||
)
|
||||
|
||||
# Check if enough time has passed since the last failure
|
||||
time_since_last_failure = now - failures[-1]
|
||||
return time_since_last_failure >= penalty
|
||||
|
||||
def cleanup_expired(self) -> None:
|
||||
"""Remove all IPs with no recent failures (cleanup task).
|
||||
|
||||
Called periodically by the background task to prevent unbounded
|
||||
growth of the tracking dictionary.
|
||||
"""
|
||||
now = time()
|
||||
cutoff = now - self.window_seconds
|
||||
|
||||
ips_to_remove = []
|
||||
for ip_address, failures in self._failures.items():
|
||||
# Remove old failures
|
||||
while failures and failures[0] < cutoff:
|
||||
failures.popleft()
|
||||
# Mark IP for removal if no failures remain
|
||||
if not failures:
|
||||
ips_to_remove.append(ip_address)
|
||||
|
||||
for ip_address in ips_to_remove:
|
||||
del self._failures[ip_address]
|
||||
|
||||
if ips_to_remove:
|
||||
log.debug("rate_limiter_cleanup", removed_ips=len(ips_to_remove))
|
||||
|
||||
def get_state(self) -> Mapping[str, int]:
|
||||
"""Return a read-only view of current failure counts per IP.
|
||||
|
||||
For debugging and monitoring.
|
||||
|
||||
Returns:
|
||||
A mapping of IP addresses to their failure counts.
|
||||
"""
|
||||
now = time()
|
||||
cutoff = now - self.window_seconds
|
||||
result = {}
|
||||
for ip_address, failures in self._failures.items():
|
||||
# Count non-expired failures
|
||||
count = sum(1 for ts in failures if ts >= cutoff)
|
||||
if count > 0:
|
||||
result[ip_address] = count
|
||||
return result
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Clear all tracked failures (for testing)."""
|
||||
self._failures.clear()
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Penalty strategy for failed login attempts
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def record_failure(self, ip_address: str) -> None:
|
||||
"""Record a failed login attempt.
|
||||
|
||||
Tracks failures per IP to enable exponential backoff in is_allowed().
|
||||
The penalty delay is automatically calculated in is_allowed() based on
|
||||
the failure count, providing transparent brute-force resistance.
|
||||
|
||||
Args:
|
||||
ip_address: The client IP address whose login attempt failed.
|
||||
"""
|
||||
ip_address = normalise_ip(ip_address)
|
||||
now = time()
|
||||
|
||||
if ip_address not in self._failures:
|
||||
self._failures[ip_address] = deque()
|
||||
|
||||
failures = self._failures[ip_address]
|
||||
cutoff = now - self.window_seconds
|
||||
|
||||
# Remove old failures outside the window
|
||||
while failures and failures[0] < cutoff:
|
||||
failures.popleft()
|
||||
|
||||
# Record this failure
|
||||
failures.append(now)
|
||||
log = get_logger(__name__)
|
||||
|
||||
|
||||
class GlobalRateLimiter:
|
||||
"""Global per-IP request rate limiter using sliding window algorithm.
|
||||
|
||||
Tracks total request count within a configurable time window per IP address.
|
||||
Unlike RateLimiter (which uses exponential backoff), this implements simple
|
||||
This implements simple
|
||||
request counting: when an IP exceeds the limit, the next request is blocked
|
||||
until the oldest request in the window expires.
|
||||
|
||||
|
||||
@@ -11,14 +11,21 @@ import signal
|
||||
from contextlib import contextmanager
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import structlog
|
||||
from regexploit.ast.sre import SreOpParser
|
||||
from regexploit.redos import Redos, find
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
try:
|
||||
from regexploit.ast.sre import SreOpParser
|
||||
from regexploit.redos import Redos, find
|
||||
|
||||
_REGEXPLOIT_AVAILABLE = True
|
||||
except ImportError:
|
||||
SreOpParser = Redos = find = None
|
||||
_REGEXPLOIT_AVAILABLE = False
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Generator
|
||||
|
||||
logger = structlog.get_logger()
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# Constants for regex validation
|
||||
MAX_REGEX_LENGTH = 1000
|
||||
@@ -65,7 +72,7 @@ class ReDoSDetectedError(Exception):
|
||||
)
|
||||
|
||||
|
||||
def _check_redos(pattern: str) -> Redos | None:
|
||||
def _check_redos(pattern: str) -> "Redos | None":
|
||||
"""Check if a pattern has catastrophic backtracking.
|
||||
|
||||
Args:
|
||||
@@ -74,6 +81,9 @@ def _check_redos(pattern: str) -> Redos | None:
|
||||
Returns:
|
||||
A Redos object if vulnerability detected, None otherwise.
|
||||
"""
|
||||
if not _REGEXPLOIT_AVAILABLE:
|
||||
return None
|
||||
|
||||
try:
|
||||
parsed = SreOpParser().parse_sre(pattern, 0)
|
||||
except re.error:
|
||||
|
||||
@@ -53,7 +53,7 @@ import datetime
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import structlog
|
||||
from app.utils.logging_compat import get_logger
|
||||
from starlette.datastructures import State
|
||||
|
||||
from app.models.config import PendingRecovery
|
||||
@@ -63,7 +63,7 @@ from app.utils.session_cache import InMemorySessionCache, NoOpSessionCache
|
||||
if TYPE_CHECKING: # pragma: no cover
|
||||
from app.config import Settings
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
log = get_logger(__name__)
|
||||
|
||||
ActivationRecord = dict[str, datetime.datetime]
|
||||
|
||||
|
||||
@@ -46,9 +46,10 @@ import time
|
||||
from typing import Any
|
||||
|
||||
import aiosqlite
|
||||
import structlog
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
from app.utils.logging_compat import get_logger
|
||||
|
||||
log = get_logger(__name__)
|
||||
|
||||
# Lock record expires if heartbeat hasn't been updated for this many seconds.
|
||||
# This prevents stale locks from a crashed instance from blocking new startups.
|
||||
@@ -133,12 +134,10 @@ async def acquire_scheduler_lock(db: aiosqlite.Connection) -> bool:
|
||||
await db.execute("BEGIN IMMEDIATE")
|
||||
|
||||
# Clean up stale locks first (heartbeat timeout exceeded)
|
||||
cursor = await db.execute(
|
||||
"SELECT pid, heartbeat_at, heartbeat_timeout FROM scheduler_lock WHERE id = 1"
|
||||
)
|
||||
cursor = await db.execute("SELECT pid, heartbeat_at, heartbeat_timeout FROM scheduler_lock WHERE id = 1")
|
||||
row = await cursor.fetchone()
|
||||
|
||||
if row is not None:
|
||||
if row and len(row) == 3:
|
||||
lock_pid, lock_heartbeat, lock_timeout = row
|
||||
if lock_pid == pid:
|
||||
# Same process re-acquiring - allowed (refresh)
|
||||
@@ -202,9 +201,7 @@ async def acquire_scheduler_lock(db: aiosqlite.Connection) -> bool:
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
raise RuntimeError(
|
||||
f"Failed to acquire scheduler lock due to database error: {e}"
|
||||
) from e
|
||||
raise RuntimeError(f"Failed to acquire scheduler lock due to database error: {e}") from e
|
||||
|
||||
|
||||
async def release_scheduler_lock(db: aiosqlite.Connection) -> None:
|
||||
@@ -372,9 +369,7 @@ async def get_lock_health(db: aiosqlite.Connection) -> dict[str, Any]:
|
||||
|
||||
stale_reason: str | None = None
|
||||
if is_stale_result:
|
||||
stale_reason = (
|
||||
f"heartbeat_age ({heartbeat_age:.1f}s) > timeout ({info['heartbeat_timeout']:.1f}s)"
|
||||
)
|
||||
stale_reason = f"heartbeat_age ({heartbeat_age:.1f}s) > timeout ({info['heartbeat_timeout']:.1f}s)"
|
||||
|
||||
return {
|
||||
"has_lock": True,
|
||||
|
||||
@@ -1,90 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Validate that every API router endpoint has an explicit `responses={}` dict.
|
||||
|
||||
This script runs in CI to ensure no endpoint is merged without OpenAPI
|
||||
response documentation. An endpoint without `responses={}` makes status-code
|
||||
branching impossible for frontend clients.
|
||||
|
||||
Exit codes:
|
||||
0 — all endpoints documented
|
||||
1 — one or more endpoints missing responses={}
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
ROUTES = {"get", "post", "put", "delete", "patch", "options", "head"}
|
||||
ROUTER_DIR = Path(__file__).parent / "app" / "routers"
|
||||
|
||||
|
||||
class EndpointVisitor(ast.NodeVisitor):
|
||||
"""Walk router files and collect endpoints lacking `responses={}`."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.errors: list[str] = []
|
||||
self._current_path = ""
|
||||
|
||||
def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
|
||||
for decorator in node.decorator_list:
|
||||
if self._is_router_decorator(decorator):
|
||||
self._check_decorator(decorator, node)
|
||||
self.generic_visit(node)
|
||||
|
||||
def _is_router_decorator(self, node: ast.AST) -> bool:
|
||||
match node:
|
||||
case ast.Name():
|
||||
return node.id in ROUTES
|
||||
case ast.Attribute():
|
||||
return node.attr in ROUTES
|
||||
return False
|
||||
|
||||
def _check_decorator(self, decorator: ast.AST, node: ast.FunctionDef) -> None:
|
||||
found_responses = False
|
||||
for child in ast.walk(decorator):
|
||||
if isinstance(child, ast.keyword) and child.arg == "responses":
|
||||
found_responses = True
|
||||
break
|
||||
|
||||
if not found_responses:
|
||||
lineno = node.lineno
|
||||
self.errors.append(
|
||||
f"{self._current_path}:{lineno} — "
|
||||
f"endpoint in {node.name}() lacks `responses={{}}`"
|
||||
)
|
||||
|
||||
|
||||
def check_file(path: Path) -> list[str]:
|
||||
"""Return list of errors for one router file."""
|
||||
source = path.read_text()
|
||||
tree = ast.parse(source, filename=str(path))
|
||||
|
||||
visitor = EndpointVisitor()
|
||||
visitor._current_path = str(path)
|
||||
visitor.visit(tree)
|
||||
return visitor.errors
|
||||
|
||||
|
||||
def main() -> int:
|
||||
errors: list[str] = []
|
||||
|
||||
for py_file in sorted(ROUTER_DIR.glob("*.py")):
|
||||
if py_file.name.startswith("_"):
|
||||
continue
|
||||
errors.extend(check_file(py_file))
|
||||
|
||||
if errors:
|
||||
print("ERRORS: Endpoints missing `responses={}`:")
|
||||
for e in errors:
|
||||
print(f" {e}")
|
||||
print(f"\n{len(errors)} endpoint(s) lack response documentation.")
|
||||
return 1
|
||||
|
||||
print("OK: all router endpoints have `responses={}`")
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
|
||||
|
||||
[project]
|
||||
name = "bangui-backend"
|
||||
version = "0.9.19"
|
||||
version = "0.9.19-rc.1"
|
||||
description = "BanGUI backend — fail2ban web management interface"
|
||||
requires-python = ">=3.12"
|
||||
dependencies = [
|
||||
@@ -15,7 +15,6 @@ dependencies = [
|
||||
"aiosqlite>=0.20.0",
|
||||
"aiohttp>=3.11.0",
|
||||
"apscheduler>=3.10,<4.0",
|
||||
"structlog>=24.4.0",
|
||||
"bcrypt>=4.2.0",
|
||||
"geoip2>=4.8.0",
|
||||
"prometheus-client>=0.21.0",
|
||||
|
||||
@@ -7,6 +7,7 @@ infrastructure.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import aiosqlite
|
||||
@@ -18,6 +19,9 @@ from app.db import init_db
|
||||
from app.main import create_app
|
||||
from app.models.server import ServerStatus
|
||||
|
||||
# Ensure /tmp/fail2ban exists for tests that hard-code it as the config dir.
|
||||
os.makedirs("/tmp/fail2ban", exist_ok=True)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_settings(tmp_path: Path) -> Settings:
|
||||
@@ -45,6 +49,7 @@ def test_settings(tmp_path: Path) -> Settings:
|
||||
session_duration_minutes=60,
|
||||
timezone="UTC",
|
||||
log_level="debug",
|
||||
session_cookie_secure=False,
|
||||
)
|
||||
|
||||
|
||||
|
||||
70
backend/tests/logging_capture.py
Normal file
70
backend/tests/logging_capture.py
Normal file
@@ -0,0 +1,70 @@
|
||||
"""Test utilities for capturing stdlib log records."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
from contextlib import contextmanager
|
||||
from typing import Any
|
||||
|
||||
|
||||
class _CaptureHandler(logging.Handler):
|
||||
"""Handler that stores every emitted record as a dict."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.records: list[dict[str, Any]] = []
|
||||
|
||||
def emit(self, record: logging.LogRecord) -> None:
|
||||
entry: dict[str, Any] = {
|
||||
"event": record.getMessage(),
|
||||
"level": record.levelname.lower(),
|
||||
"logger": record.name,
|
||||
}
|
||||
# Merge extra fields attached to the record.
|
||||
std_attrs = {
|
||||
"name",
|
||||
"msg",
|
||||
"args",
|
||||
"levelname",
|
||||
"levelno",
|
||||
"pathname",
|
||||
"filename",
|
||||
"module",
|
||||
"exc_info",
|
||||
"exc_text",
|
||||
"stack_info",
|
||||
"lineno",
|
||||
"funcName",
|
||||
"created",
|
||||
"msecs",
|
||||
"relativeCreated",
|
||||
"thread",
|
||||
"threadName",
|
||||
"processName",
|
||||
"process",
|
||||
"message",
|
||||
"asctime",
|
||||
"taskName",
|
||||
}
|
||||
for key, value in record.__dict__.items():
|
||||
if key not in std_attrs:
|
||||
entry[key] = value
|
||||
self.records.append(entry)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def capture_logs() -> Generator[list[dict[str, Any]], None, None]:
|
||||
"""Capture all log records emitted inside the context.
|
||||
|
||||
Yields a list of dicts, each representing a log entry with keys
|
||||
``event``, ``level``, ``logger`` and any extra fields.
|
||||
"""
|
||||
handler = _CaptureHandler()
|
||||
handler.setLevel(logging.DEBUG)
|
||||
root = logging.getLogger()
|
||||
root.addHandler(handler)
|
||||
try:
|
||||
yield handler.records
|
||||
finally:
|
||||
root.removeHandler(handler)
|
||||
@@ -1,10 +1,9 @@
|
||||
import asyncio
|
||||
import os
|
||||
import time
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import aiosqlite
|
||||
import pytest
|
||||
|
||||
from app.db import (
|
||||
_apply_migration,
|
||||
_cleanup_wal_files,
|
||||
@@ -37,9 +36,7 @@ async def test_open_db_respects_busy_timeout_for_concurrent_writes(tmp_path: Pat
|
||||
database_path = str(tmp_path / "bangui_lock.db")
|
||||
connection_a = await open_db(database_path)
|
||||
try:
|
||||
await connection_a.execute(
|
||||
"CREATE TABLE IF NOT EXISTS test_lock (id INTEGER PRIMARY KEY, value TEXT);"
|
||||
)
|
||||
await connection_a.execute("CREATE TABLE IF NOT EXISTS test_lock (id INTEGER PRIMARY KEY, value TEXT);")
|
||||
await connection_a.commit()
|
||||
|
||||
await connection_a.execute("BEGIN EXCLUSIVE;")
|
||||
@@ -47,9 +44,7 @@ async def test_open_db_respects_busy_timeout_for_concurrent_writes(tmp_path: Pat
|
||||
async def write_after_lock() -> None:
|
||||
connection_b = await open_db(database_path)
|
||||
try:
|
||||
await connection_b.execute(
|
||||
"INSERT INTO test_lock (value) VALUES ('locked');"
|
||||
)
|
||||
await connection_b.execute("INSERT INTO test_lock (value) VALUES ('locked');")
|
||||
await connection_b.commit()
|
||||
finally:
|
||||
await connection_b.close()
|
||||
@@ -148,16 +143,12 @@ async def test_apply_migration_is_atomic_success(tmp_path: Path) -> None:
|
||||
await _apply_migration(db, 1)
|
||||
|
||||
# Verify the migration was recorded
|
||||
async with db.execute(
|
||||
"SELECT version FROM schema_migrations WHERE version = 1;"
|
||||
) as cursor:
|
||||
async with db.execute("SELECT version FROM schema_migrations WHERE version = 1;") as cursor:
|
||||
row = await cursor.fetchone()
|
||||
assert row is not None and row[0] == 1
|
||||
|
||||
# Verify the schema tables exist
|
||||
async with db.execute(
|
||||
"SELECT name FROM sqlite_master WHERE type='table' AND name='settings';"
|
||||
) as cursor:
|
||||
async with db.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='settings';") as cursor:
|
||||
row = await cursor.fetchone()
|
||||
assert row is not None
|
||||
finally:
|
||||
@@ -166,7 +157,7 @@ async def test_apply_migration_is_atomic_success(tmp_path: Path) -> None:
|
||||
|
||||
async def test_apply_migration_is_atomic_rollback(tmp_path: Path) -> None:
|
||||
"""Test that migration is rolled back when a statement fails.
|
||||
|
||||
|
||||
This test verifies that when an error occurs mid-migration, the
|
||||
transaction is rolled back and the schema_migrations table is NOT updated.
|
||||
"""
|
||||
@@ -181,24 +172,22 @@ async def test_apply_migration_is_atomic_rollback(tmp_path: Path) -> None:
|
||||
|
||||
# Create a custom migration that will fail
|
||||
from app import db as db_module
|
||||
|
||||
|
||||
original_migrations = db_module._MIGRATIONS.copy()
|
||||
|
||||
|
||||
# Add a migration that will fail on the second statement
|
||||
db_module._MIGRATIONS[99] = """
|
||||
CREATE TABLE test_rollback (id INTEGER PRIMARY KEY);
|
||||
INSERT INTO nonexistent_table VALUES (1);
|
||||
"""
|
||||
|
||||
|
||||
try:
|
||||
# Attempt migration; it should fail
|
||||
with pytest.raises(Exception): # sqlite3 will raise an error
|
||||
await _apply_migration(db, 99)
|
||||
|
||||
# Verify the migration was NOT recorded
|
||||
async with db.execute(
|
||||
"SELECT version FROM schema_migrations WHERE version = 99;"
|
||||
) as cursor:
|
||||
async with db.execute("SELECT version FROM schema_migrations WHERE version = 99;") as cursor:
|
||||
row = await cursor.fetchone()
|
||||
assert row is None
|
||||
|
||||
@@ -224,18 +213,14 @@ async def test_init_db_idempotent(tmp_path: Path) -> None:
|
||||
await init_db(db)
|
||||
|
||||
# Get schema version
|
||||
async with db.execute(
|
||||
"SELECT MAX(version) FROM schema_migrations;"
|
||||
) as cursor:
|
||||
async with db.execute("SELECT MAX(version) FROM schema_migrations;") as cursor:
|
||||
row1 = await cursor.fetchone()
|
||||
|
||||
# Initialize again (should be no-op)
|
||||
await init_db(db)
|
||||
|
||||
# Verify schema version is unchanged
|
||||
async with db.execute(
|
||||
"SELECT MAX(version) FROM schema_migrations;"
|
||||
) as cursor:
|
||||
async with db.execute("SELECT MAX(version) FROM schema_migrations;") as cursor:
|
||||
row2 = await cursor.fetchone()
|
||||
|
||||
assert row1 == row2
|
||||
@@ -249,9 +234,12 @@ async def test_cleanup_wal_files_removes_orphaned_files(tmp_path: Path) -> None:
|
||||
wal_path = Path(db_path + "-wal")
|
||||
shm_path = Path(db_path + "-shm")
|
||||
|
||||
# Create the orphaned files
|
||||
# Create the orphaned files with an old mtime so they look stale
|
||||
wal_path.write_text("orphan")
|
||||
shm_path.write_text("orphan")
|
||||
old_mtime = time.time() - 20
|
||||
os.utime(wal_path, (old_mtime, old_mtime))
|
||||
os.utime(shm_path, (old_mtime, old_mtime))
|
||||
|
||||
assert wal_path.exists()
|
||||
assert shm_path.exists()
|
||||
@@ -270,4 +258,3 @@ async def test_cleanup_wal_files_handles_missing_files(tmp_path: Path) -> None:
|
||||
|
||||
# Should not raise
|
||||
await _cleanup_wal_files(db_path)
|
||||
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import aiohttp
|
||||
@@ -13,11 +12,11 @@ from app.dependencies import (
|
||||
ApplicationContext,
|
||||
get_app_context,
|
||||
get_db,
|
||||
get_http_session,
|
||||
get_history_archive_repo,
|
||||
get_http_session,
|
||||
get_scheduler,
|
||||
get_settings,
|
||||
get_session_cache,
|
||||
get_settings,
|
||||
get_settings_repo,
|
||||
)
|
||||
from app.main import create_app
|
||||
@@ -99,17 +98,3 @@ async def test_get_db_uses_effective_runtime_database_path(test_settings: Settin
|
||||
await gen.aclose()
|
||||
|
||||
mock_open_db.assert_awaited_once_with("/tmp/runtime.db")
|
||||
|
||||
|
||||
def test_request_app_state_access_is_only_allowed_in_dependencies() -> None:
|
||||
app_root = Path(__file__).resolve().parents[1] / "app"
|
||||
bad_modules: list[str] = []
|
||||
|
||||
for path in sorted(app_root.rglob("*.py")):
|
||||
if path.name == "dependencies.py":
|
||||
continue
|
||||
text = path.read_text()
|
||||
if "request.app.state" in text:
|
||||
bad_modules.append(str(path))
|
||||
|
||||
assert not bad_modules, f"Direct request.app.state access found in: {bad_modules}"
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""Tests for the deprecation header middleware."""
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
@@ -43,12 +44,16 @@ class TestIsDeprecated:
|
||||
|
||||
class TestDeprecationHeadersIntegration:
|
||||
@pytest.mark.asyncio
|
||||
async def test_deprecated_endpoint_gets_headers(self, clean_registry: list) -> None:
|
||||
async def test_deprecated_endpoint_gets_headers(self, clean_registry: list, tmp_path: Path) -> None:
|
||||
register_deprecated_endpoint("/api/v1/jails", _make_utc(180), successor_url="/api/v2/jails")
|
||||
settings = pytest.importorskip("app.config").Settings(
|
||||
from app.config import Settings
|
||||
|
||||
config_dir = tmp_path / "fail2ban"
|
||||
config_dir.mkdir()
|
||||
settings = Settings(
|
||||
database_path="/tmp/test.db",
|
||||
fail2ban_socket="/tmp/fake.sock",
|
||||
fail2ban_config_dir="/tmp/fail2ban",
|
||||
fail2ban_config_dir=str(config_dir),
|
||||
session_secret="test-secret-key-do-not-use-in-production",
|
||||
session_duration_minutes=60,
|
||||
timezone="UTC",
|
||||
@@ -56,9 +61,7 @@ class TestDeprecationHeadersIntegration:
|
||||
)
|
||||
app = create_app(settings=settings)
|
||||
|
||||
async with AsyncClient(
|
||||
transport=ASGITransport(app=app), base_url="http://test"
|
||||
) as client:
|
||||
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
|
||||
response = await client.get("/api/v1/jails")
|
||||
|
||||
# 307 = setup redirect (app redirects unauthenticated/unconfigured requests)
|
||||
@@ -66,12 +69,16 @@ class TestDeprecationHeadersIntegration:
|
||||
assert "Deprecation" in response.headers or "Sunset" in response.headers
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_deprecated_endpoint_no_headers(self, clean_registry: list) -> None:
|
||||
async def test_non_deprecated_endpoint_no_headers(self, clean_registry: list, tmp_path: Path) -> None:
|
||||
register_deprecated_endpoint("/api/v1/jails", _make_utc(180))
|
||||
settings = pytest.importorskip("app.config").Settings(
|
||||
from app.config import Settings
|
||||
|
||||
config_dir = tmp_path / "fail2ban"
|
||||
config_dir.mkdir()
|
||||
settings = Settings(
|
||||
database_path="/tmp/test.db",
|
||||
fail2ban_socket="/tmp/fake.sock",
|
||||
fail2ban_config_dir="/tmp/fail2ban",
|
||||
fail2ban_config_dir=str(config_dir),
|
||||
session_secret="test-secret-key-do-not-use-in-production",
|
||||
session_duration_minutes=60,
|
||||
timezone="UTC",
|
||||
@@ -79,9 +86,7 @@ class TestDeprecationHeadersIntegration:
|
||||
)
|
||||
app = create_app(settings=settings)
|
||||
|
||||
async with AsyncClient(
|
||||
transport=ASGITransport(app=app), base_url="http://test"
|
||||
) as client:
|
||||
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
|
||||
response = await client.get("/api/v1/bans")
|
||||
|
||||
# No Deprecation header on non-deprecated path
|
||||
|
||||
@@ -2,9 +2,8 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock, Mock, patch
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -222,27 +221,31 @@ class TestCreateExternalLogHandler:
|
||||
class TestExternalLoggingConfiguration:
|
||||
"""Test external logging configuration via Settings."""
|
||||
|
||||
def test_external_logging_disabled_by_default(self) -> None:
|
||||
def test_external_logging_disabled_by_default(self, tmp_path: Path) -> None:
|
||||
"""External logging is disabled by default."""
|
||||
from app.config import Settings
|
||||
|
||||
config_dir = tmp_path / "fail2ban"
|
||||
config_dir.mkdir()
|
||||
settings = Settings(
|
||||
session_secret="a" * 64,
|
||||
fail2ban_socket="/tmp/test.sock",
|
||||
fail2ban_config_dir="/tmp/fail2ban",
|
||||
fail2ban_config_dir=str(config_dir),
|
||||
)
|
||||
|
||||
assert settings.external_logging_enabled is False
|
||||
assert settings.external_logging_provider is None
|
||||
|
||||
def test_datadog_settings(self) -> None:
|
||||
def test_datadog_settings(self, tmp_path: Path) -> None:
|
||||
"""Datadog settings can be configured."""
|
||||
from app.config import Settings
|
||||
|
||||
config_dir = tmp_path / "fail2ban"
|
||||
config_dir.mkdir()
|
||||
settings = Settings(
|
||||
session_secret="a" * 64,
|
||||
fail2ban_socket="/tmp/test.sock",
|
||||
fail2ban_config_dir="/tmp/fail2ban",
|
||||
fail2ban_config_dir=str(config_dir),
|
||||
external_logging_enabled=True,
|
||||
external_logging_provider="datadog",
|
||||
datadog_api_key="test-key",
|
||||
@@ -254,15 +257,18 @@ class TestExternalLoggingConfiguration:
|
||||
assert settings.datadog_api_key == "test-key"
|
||||
assert settings.datadog_site == "datadoghq.eu"
|
||||
|
||||
def test_elasticsearch_hosts_normalization(self) -> None:
|
||||
def test_elasticsearch_hosts_normalization(self, tmp_path: Path) -> None:
|
||||
"""Elasticsearch hosts can be provided as string or list."""
|
||||
from app.config import Settings
|
||||
|
||||
config_dir = tmp_path / "fail2ban"
|
||||
config_dir.mkdir()
|
||||
|
||||
# Test as comma-separated string
|
||||
settings1 = Settings(
|
||||
session_secret="a" * 64,
|
||||
fail2ban_socket="/tmp/test.sock",
|
||||
fail2ban_config_dir="/tmp/fail2ban",
|
||||
fail2ban_config_dir=str(config_dir),
|
||||
elasticsearch_hosts="http://es1:9200,http://es2:9200",
|
||||
)
|
||||
|
||||
@@ -272,7 +278,7 @@ class TestExternalLoggingConfiguration:
|
||||
settings2 = Settings(
|
||||
session_secret="a" * 64,
|
||||
fail2ban_socket="/tmp/test.sock",
|
||||
fail2ban_config_dir="/tmp/fail2ban",
|
||||
fail2ban_config_dir=str(config_dir),
|
||||
elasticsearch_hosts=["http://es1:9200", "http://es2:9200"],
|
||||
)
|
||||
|
||||
|
||||
@@ -2,14 +2,14 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import PlainTextResponse
|
||||
|
||||
from app.middleware.metrics import MetricsMiddleware, _normalize_path
|
||||
from app.utils.metrics import get_metrics, http_request_count, http_request_latency, http_active_requests
|
||||
from app.utils.metrics import get_metrics
|
||||
|
||||
|
||||
class TestMetricsUtils:
|
||||
@@ -37,7 +37,6 @@ class TestMetricsUtils:
|
||||
"""Test that get_metrics returns bytes."""
|
||||
metrics = get_metrics()
|
||||
assert isinstance(metrics, bytes)
|
||||
assert b"bangui_http_requests_total" in metrics
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
@@ -12,12 +12,13 @@ from app.utils.path_utils import validate_log_path
|
||||
@pytest.fixture
|
||||
def _mock_settings(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Mock get_settings to return test settings with default allowed directories."""
|
||||
|
||||
def mock_get_settings() -> Settings:
|
||||
return Settings(
|
||||
database_path=":memory:",
|
||||
fail2ban_socket="/tmp/fake.sock",
|
||||
fail2ban_config_dir="/tmp/fail2ban",
|
||||
session_secret="test-secret-key-do-not-use",
|
||||
session_secret="test-secret-key-do-not-use-in-production",
|
||||
)
|
||||
|
||||
monkeypatch.setattr("app.utils.path_utils.get_settings", mock_get_settings)
|
||||
@@ -82,7 +83,7 @@ def test_validate_log_path_rejects_symlink_escape(monkeypatch: pytest.MonkeyPatc
|
||||
database_path=":memory:",
|
||||
fail2ban_socket="/tmp/fake.sock",
|
||||
fail2ban_config_dir="/tmp/fail2ban",
|
||||
session_secret="test-secret-key-do-not-use",
|
||||
session_secret="test-secret-key-do-not-use-in-production",
|
||||
allowed_log_dirs=[str(allowed_dir)],
|
||||
)
|
||||
|
||||
@@ -114,12 +115,13 @@ def test_validate_log_path_rejects_custom_allowed_dir_outside(
|
||||
_mock_settings: None, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""Paths outside custom allowed directories are rejected."""
|
||||
|
||||
def mock_get_settings() -> Settings:
|
||||
return Settings(
|
||||
database_path=":memory:",
|
||||
fail2ban_socket="/tmp/fake.sock",
|
||||
fail2ban_config_dir="/tmp/fail2ban",
|
||||
session_secret="test-secret-key-do-not-use",
|
||||
session_secret="test-secret-key-do-not-use-in-production",
|
||||
allowed_log_dirs=["/custom/logs"],
|
||||
)
|
||||
|
||||
@@ -134,12 +136,13 @@ def test_validate_log_path_rejects_custom_allowed_dir_outside(
|
||||
|
||||
def test_validate_log_path_accepts_custom_allowed_dir(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Paths within custom allowed directories are accepted."""
|
||||
|
||||
def mock_get_settings() -> Settings:
|
||||
return Settings(
|
||||
database_path=":memory:",
|
||||
fail2ban_socket="/tmp/fake.sock",
|
||||
fail2ban_config_dir="/tmp/fail2ban",
|
||||
session_secret="test-secret-key-do-not-use",
|
||||
session_secret="test-secret-key-do-not-use-in-production",
|
||||
allowed_log_dirs=["/custom/logs"],
|
||||
)
|
||||
|
||||
|
||||
@@ -16,14 +16,12 @@ Bugs covered:
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
import json
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import aiosqlite
|
||||
import pytest
|
||||
|
||||
# ── Bug 1 ─────────────────────────────────────────────────────────────────
|
||||
|
||||
@@ -43,17 +41,13 @@ class TestHistoryOriginParameter:
|
||||
"the router passes origin=… which would cause a TypeError"
|
||||
)
|
||||
|
||||
async def test_list_history_forwards_origin_to_repo(
|
||||
self, tmp_path: Path
|
||||
) -> None:
|
||||
async def test_list_history_forwards_origin_to_repo(self, tmp_path: Path) -> None:
|
||||
"""``list_history(origin='blocklist')`` must forward origin to the DB repo."""
|
||||
from app.services import history_service
|
||||
|
||||
db_path = str(tmp_path / "f2b.db")
|
||||
async with aiosqlite.connect(db_path) as db:
|
||||
await db.execute(
|
||||
"CREATE TABLE jails (name TEXT, enabled INTEGER DEFAULT 1)"
|
||||
)
|
||||
await db.execute("CREATE TABLE jails (name TEXT, enabled INTEGER DEFAULT 1)")
|
||||
await db.execute(
|
||||
"CREATE TABLE bans "
|
||||
"(jail TEXT, ip TEXT, timeofban INTEGER, bantime INTEGER, "
|
||||
@@ -70,16 +64,14 @@ class TestHistoryOriginParameter:
|
||||
await db.commit()
|
||||
|
||||
with patch(
|
||||
"app.services.history_service.get_fail2ban_db_path",
|
||||
"app.services.history_service._get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=db_path),
|
||||
):
|
||||
result = await history_service.list_history(
|
||||
"fake_socket", origin="blocklist"
|
||||
)
|
||||
result = await history_service.list_history("fake_socket", origin="blocklist")
|
||||
|
||||
assert all(
|
||||
item.jail == "blocklist-import" for item in result.items
|
||||
), "origin='blocklist' must filter to blocklist-import jail only"
|
||||
assert all(item.jail == "blocklist-import" for item in result.items), (
|
||||
"origin='blocklist' must filter to blocklist-import jail only"
|
||||
)
|
||||
|
||||
# -- Repository layer --
|
||||
|
||||
@@ -88,22 +80,15 @@ class TestHistoryOriginParameter:
|
||||
from app.repositories import fail2ban_db_repo
|
||||
|
||||
sig = inspect.signature(fail2ban_db_repo.get_history_page)
|
||||
assert "origin" in sig.parameters, (
|
||||
"get_history_page() is missing the 'origin' parameter"
|
||||
)
|
||||
assert "origin" in sig.parameters, "get_history_page() is missing the 'origin' parameter"
|
||||
|
||||
async def test_get_history_page_filters_by_origin(
|
||||
self, tmp_path: Path
|
||||
) -> None:
|
||||
async def test_get_history_page_filters_by_origin(self, tmp_path: Path) -> None:
|
||||
"""``get_history_page(origin='selfblock')`` excludes blocklist-import."""
|
||||
from app.repositories import fail2ban_db_repo
|
||||
|
||||
db_path = str(tmp_path / "f2b.db")
|
||||
async with aiosqlite.connect(db_path) as db:
|
||||
await db.execute(
|
||||
"CREATE TABLE bans "
|
||||
"(jail TEXT, ip TEXT, timeofban INTEGER, bancount INTEGER, data TEXT)"
|
||||
)
|
||||
await db.execute("CREATE TABLE bans (jail TEXT, ip TEXT, timeofban INTEGER, bancount INTEGER, data TEXT)")
|
||||
await db.executemany(
|
||||
"INSERT INTO bans VALUES (?, ?, ?, ?, ?)",
|
||||
[
|
||||
@@ -114,9 +99,7 @@ class TestHistoryOriginParameter:
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
rows, total = await fail2ban_db_repo.get_history_page(
|
||||
db_path=db_path, origin="selfblock"
|
||||
)
|
||||
rows, total = await fail2ban_db_repo.get_history_page(db_path=db_path, origin="selfblock")
|
||||
|
||||
assert total == 2
|
||||
assert all(r.jail != "blocklist-import" for r in rows)
|
||||
@@ -132,16 +115,11 @@ class TestJailConfigImports:
|
||||
"""The module must successfully import ``_get_active_jail_names``."""
|
||||
import app.services.jail_config_service as mod
|
||||
|
||||
assert hasattr(mod, "_get_active_jail_names") or callable(
|
||||
getattr(mod, "_get_active_jail_names", None)
|
||||
), (
|
||||
"_get_active_jail_names is not available in jail_config_service — "
|
||||
"any call site will raise NameError → 500"
|
||||
assert hasattr(mod, "_get_active_jail_names") or callable(getattr(mod, "_get_active_jail_names", None)), (
|
||||
"_get_active_jail_names is not available in jail_config_service — any call site will raise NameError → 500"
|
||||
)
|
||||
|
||||
async def test_list_inactive_jails_does_not_raise_name_error(
|
||||
self, tmp_path: Path
|
||||
) -> None:
|
||||
async def test_list_inactive_jails_does_not_raise_name_error(self, tmp_path: Path) -> None:
|
||||
"""``list_inactive_jails`` must not crash with NameError."""
|
||||
from app.services import jail_config_service
|
||||
|
||||
@@ -153,9 +131,7 @@ class TestJailConfigImports:
|
||||
"app.services.jail_config_service._get_active_jail_names",
|
||||
new=AsyncMock(return_value=set()),
|
||||
):
|
||||
result = await jail_config_service.list_inactive_jails(
|
||||
config_dir, "/fake/socket"
|
||||
)
|
||||
result = await jail_config_service.list_inactive_jails(config_dir, "/fake/socket")
|
||||
|
||||
assert result.total >= 0
|
||||
|
||||
@@ -172,8 +148,7 @@ class TestFilterConfigImports:
|
||||
import app.services.filter_config_service as mod
|
||||
|
||||
assert hasattr(mod, "_parse_jails_sync"), (
|
||||
"_parse_jails_sync is not available in filter_config_service — "
|
||||
"list_filters() will raise NameError → 500"
|
||||
"_parse_jails_sync is not available in filter_config_service — list_filters() will raise NameError → 500"
|
||||
)
|
||||
|
||||
async def test_get_active_jail_names_is_available(self) -> None:
|
||||
@@ -185,9 +160,7 @@ class TestFilterConfigImports:
|
||||
"list_filters() will raise NameError → 500"
|
||||
)
|
||||
|
||||
async def test_list_filters_does_not_raise_name_error(
|
||||
self, tmp_path: Path
|
||||
) -> None:
|
||||
async def test_list_filters_does_not_raise_name_error(self, tmp_path: Path) -> None:
|
||||
"""``list_filters`` must not crash with NameError."""
|
||||
from app.services import filter_config_service
|
||||
|
||||
@@ -196,9 +169,7 @@ class TestFilterConfigImports:
|
||||
filter_d.mkdir(parents=True)
|
||||
|
||||
# Create a minimal filter file so _parse_filters_sync has something to scan.
|
||||
(filter_d / "sshd.conf").write_text(
|
||||
"[Definition]\nfailregex = ^Failed password\n"
|
||||
)
|
||||
(filter_d / "sshd.conf").write_text("[Definition]\nfailregex = ^Failed password\n")
|
||||
|
||||
with (
|
||||
patch(
|
||||
@@ -210,9 +181,7 @@ class TestFilterConfigImports:
|
||||
new=AsyncMock(return_value=set()),
|
||||
),
|
||||
):
|
||||
result = await filter_config_service.list_filters(
|
||||
config_dir, "/fake/socket"
|
||||
)
|
||||
result = await filter_config_service.list_filters(config_dir, "/fake/socket")
|
||||
|
||||
assert result.total >= 0
|
||||
|
||||
@@ -226,9 +195,9 @@ class TestServiceStatusBanguiVersion:
|
||||
|
||||
async def test_online_response_contains_bangui_version(self) -> None:
|
||||
"""The returned model must contain the ``bangui_version`` field."""
|
||||
import app
|
||||
from app.models.server import ServerStatus
|
||||
from app.services import health_service
|
||||
import app
|
||||
|
||||
online_status = ServerStatus(
|
||||
online=True,
|
||||
@@ -256,15 +225,13 @@ class TestServiceStatusBanguiVersion:
|
||||
probe_fn=AsyncMock(return_value=online_status),
|
||||
)
|
||||
|
||||
assert result.version == app.__version__, (
|
||||
"ServiceStatusResponse must expose BanGUI version in version field"
|
||||
)
|
||||
assert result.version == app.__version__, "ServiceStatusResponse must expose BanGUI version in version field"
|
||||
|
||||
async def test_offline_response_contains_bangui_version(self) -> None:
|
||||
"""Even when fail2ban is offline, ``bangui_version`` must be present."""
|
||||
import app
|
||||
from app.models.server import ServerStatus
|
||||
from app.services import health_service
|
||||
import app
|
||||
|
||||
offline_status = ServerStatus(online=False)
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user