Compare commits
76 Commits
61daa8bbc0
...
v0.9.12
| Author | SHA1 | Date | |
|---|---|---|---|
| d909f93efc | |||
| 965cdd765b | |||
| 0663740b08 | |||
| 29587f2353 | |||
| 798ed08ddd | |||
| ed184f1c84 | |||
| 8e1b4fa978 | |||
| e604e3aadf | |||
| cf721513e8 | |||
| a32cc82851 | |||
| 26af69e2a3 | |||
| 00e702a2c0 | |||
| ee73373111 | |||
| a1f97bd78f | |||
| 99fbddb0e7 | |||
| b15629a078 | |||
| 136f21ecbe | |||
| bf2abda595 | |||
| 335f89c554 | |||
| 05dc9fa1e3 | |||
| 471eed9664 | |||
| 1f272dc348 | |||
| f9cec2a975 | |||
| cc235b95c6 | |||
| 29415da421 | |||
| 8a6bcc4d94 | |||
| a442836c5c | |||
| 3aba2b6446 | |||
| 28a7610276 | |||
| d30d138146 | |||
| 8c4fe767de | |||
| 52b0936200 | |||
| 1c0bac1353 | |||
| bdcdd5d672 | |||
| 482399c9e2 | |||
| ce59a66973 | |||
| dfbe126368 | |||
| c9e688cc52 | |||
| 1ce5da9e23 | |||
| 93f0feabde | |||
| 376c13370d | |||
| fb6d0e588f | |||
| e44caccb3c | |||
| 15e4a5434e | |||
| 1cc9968d31 | |||
| 80a6bac33e | |||
| 133ab2e82c | |||
| 60f2f35b25 | |||
| 59da34dc3b | |||
| 90f54cf39c | |||
| 93d26e3c60 | |||
| 954dcf7ea6 | |||
| bf8144916a | |||
| 481daa4e1a | |||
| 889976c7ee | |||
| d3d2cb0915 | |||
| bf82e38b6e | |||
| e98fd1de93 | |||
| 8f515893ea | |||
| 81f99d0b50 | |||
| 030bca09b7 | |||
| 5b7d1a4360 | |||
| e7834a888e | |||
| abb224e01b | |||
| 57cf93b1e5 | |||
| c41165c294 | |||
| cdf73e2d65 | |||
| 21753c4f06 | |||
| eb859af371 | |||
| 5a5c619a34 | |||
| 00119ed68d | |||
| b81e0cdbb4 | |||
| 41dcd60225 | |||
| 12f04bd8d6 | |||
| d4d04491d2 | |||
| 93dc699825 |
@@ -10,7 +10,7 @@
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
|
||||
# ── Stage 1: build dependencies ──────────────────────────────
|
||||
FROM python:3.12-slim AS builder
|
||||
FROM docker.io/library/python:3.12-slim AS builder
|
||||
|
||||
WORKDIR /build
|
||||
|
||||
@@ -28,7 +28,7 @@ RUN pip install --no-cache-dir --upgrade pip \
|
||||
&& pip install --no-cache-dir .
|
||||
|
||||
# ── Stage 2: runtime image ───────────────────────────────────
|
||||
FROM python:3.12-slim AS runtime
|
||||
FROM docker.io/library/python:3.12-slim AS runtime
|
||||
|
||||
LABEL maintainer="BanGUI" \
|
||||
description="BanGUI backend — fail2ban web management API"
|
||||
|
||||
@@ -10,7 +10,7 @@
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
|
||||
# ── Stage 1: install & build ─────────────────────────────────
|
||||
FROM node:22-alpine AS builder
|
||||
FROM docker.io/library/node:22-alpine AS builder
|
||||
|
||||
WORKDIR /build
|
||||
|
||||
@@ -23,7 +23,7 @@ COPY frontend/ /build/
|
||||
RUN npm run build
|
||||
|
||||
# ── Stage 2: serve with nginx ────────────────────────────────
|
||||
FROM nginx:1.27-alpine AS runtime
|
||||
FROM docker.io/library/nginx:1.27-alpine AS runtime
|
||||
|
||||
LABEL maintainer="BanGUI" \
|
||||
description="BanGUI frontend — fail2ban web management UI"
|
||||
|
||||
1
Docker/VERSION
Normal file
1
Docker/VERSION
Normal file
@@ -0,0 +1 @@
|
||||
v0.9.12
|
||||
@@ -2,7 +2,7 @@
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
# check_ban_status.sh
|
||||
#
|
||||
# Queries the bangui-sim jail inside the running fail2ban
|
||||
# Queries the manual-Jail jail inside the running fail2ban
|
||||
# container and optionally unbans a specific IP.
|
||||
#
|
||||
# Usage:
|
||||
@@ -17,7 +17,7 @@
|
||||
set -euo pipefail
|
||||
|
||||
readonly CONTAINER="bangui-fail2ban-dev"
|
||||
readonly JAIL="bangui-sim"
|
||||
readonly JAIL="manual-Jail"
|
||||
|
||||
# ── Helper: run a fail2ban-client command inside the container ─
|
||||
f2b() {
|
||||
|
||||
73
Docker/docker-compose.yml
Normal file
73
Docker/docker-compose.yml
Normal file
@@ -0,0 +1,73 @@
|
||||
version: '3.8'
|
||||
services:
|
||||
fail2ban:
|
||||
image: lscr.io/linuxserver/fail2ban:latest
|
||||
container_name: fail2ban
|
||||
cap_add:
|
||||
- NET_ADMIN
|
||||
- NET_RAW
|
||||
network_mode: host
|
||||
environment:
|
||||
- PUID=1011
|
||||
- PGID=1001
|
||||
- TZ=Etc/UTC
|
||||
- VERBOSITY=-vv #optional
|
||||
|
||||
volumes:
|
||||
- /server/server_fail2ban/config:/config
|
||||
- /server/server_fail2ban/fail2ban-run:/var/run/fail2ban
|
||||
- /var/log:/var/log
|
||||
- /server/server_nextcloud/config/nextcloud.log:/remotelogs/nextcloud/nextcloud.log:ro #optional
|
||||
- /server/server_nginx/data/logs:/remotelogs/nginx:ro #optional
|
||||
- /server/server_gitea/log/gitea.log:/remotelogs/gitea/gitea.log:ro #optional
|
||||
|
||||
|
||||
#- /path/to/homeassistant/log:/remotelogs/homeassistant:ro #optional
|
||||
#- /path/to/unificontroller/log:/remotelogs/unificontroller:ro #optional
|
||||
#- /path/to/vaultwarden/log:/remotelogs/vaultwarden:ro #optional
|
||||
restart: unless-stopped
|
||||
|
||||
backend:
|
||||
image: git.lpl-mind.de/lukas.pupkalipinski/bangui/backend:latest
|
||||
container_name: bangui-backend
|
||||
restart: unless-stopped
|
||||
depends_on:
|
||||
fail2ban:
|
||||
condition: service_started
|
||||
environment:
|
||||
- PUID=1011
|
||||
- PGID=1001
|
||||
- BANGUI_DATABASE_PATH=/data/bangui.db
|
||||
- BANGUI_FAIL2BAN_SOCKET=/var/run/fail2ban/fail2ban.sock
|
||||
- BANGUI_FAIL2BAN_CONFIG_DIR=/config/fail2ban
|
||||
- BANGUI_LOG_LEVEL=info
|
||||
- BANGUI_SESSION_SECRET=${BANGUI_SESSION_SECRET:?Set BANGUI_SESSION_SECRET}
|
||||
- BANGUI_TIMEZONE=${BANGUI_TIMEZONE:-UTC}
|
||||
volumes:
|
||||
- /server/server_fail2ban/bangui-data:/data
|
||||
- /server/server_fail2ban/fail2ban-run:/var/run/fail2ban:ro
|
||||
- /server/server_fail2ban/config:/config:rw
|
||||
expose:
|
||||
- "8000"
|
||||
networks:
|
||||
- bangui-net
|
||||
|
||||
# ── Frontend (nginx serving built SPA + API proxy) ──────────
|
||||
frontend:
|
||||
image: git.lpl-mind.de/lukas.pupkalipinski/bangui/frontend:latest
|
||||
container_name: bangui-frontend
|
||||
restart: unless-stopped
|
||||
environment:
|
||||
- PUID=1011
|
||||
- PGID=1001
|
||||
ports:
|
||||
- "${BANGUI_PORT:-8080}:80"
|
||||
depends_on:
|
||||
backend:
|
||||
condition: service_started
|
||||
networks:
|
||||
- bangui-net
|
||||
|
||||
networks:
|
||||
bangui-net:
|
||||
name: bangui-net
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
This directory contains the fail2ban configuration and supporting scripts for a
|
||||
self-contained development test environment. A simulation script writes fake
|
||||
authentication-failure log lines, fail2ban detects them via the `bangui-sim`
|
||||
authentication-failure log lines, fail2ban detects them via the `manual-Jail`
|
||||
jail, and bans the offending IP — giving a fully reproducible ban/unban cycle
|
||||
without a real service.
|
||||
|
||||
@@ -71,14 +71,14 @@ Chains steps 1–3 automatically with appropriate sleep intervals.
|
||||
|
||||
| File | Purpose |
|
||||
|------|---------|
|
||||
| `fail2ban/filter.d/bangui-sim.conf` | Defines the `failregex` that matches simulation log lines |
|
||||
| `fail2ban/jail.d/bangui-sim.conf` | Jail settings: `maxretry=3`, `bantime=60s`, `findtime=120s` |
|
||||
| `fail2ban/filter.d/manual-Jail.conf` | Defines the `failregex` that matches simulation log lines |
|
||||
| `fail2ban/jail.d/manual-Jail.conf` | Jail settings: `maxretry=3`, `bantime=60s`, `findtime=120s` |
|
||||
| `Docker/logs/auth.log` | Log file written by the simulation script (host path) |
|
||||
|
||||
Inside the container the log file is mounted at `/remotelogs/bangui/auth.log`
|
||||
(see `fail2ban/paths-lsio.conf` — `remote_logs_path = /remotelogs`).
|
||||
|
||||
To change sensitivity, edit `fail2ban/jail.d/bangui-sim.conf`:
|
||||
To change sensitivity, edit `fail2ban/jail.d/manual-Jail.conf`:
|
||||
|
||||
```ini
|
||||
maxretry = 3 # failures before a ban
|
||||
@@ -108,14 +108,14 @@ Test the regex manually:
|
||||
|
||||
```bash
|
||||
docker exec bangui-fail2ban-dev \
|
||||
fail2ban-regex /remotelogs/bangui/auth.log bangui-sim
|
||||
fail2ban-regex /remotelogs/bangui/auth.log manual-Jail
|
||||
```
|
||||
|
||||
The output should show matched lines. If nothing matches, check that the log
|
||||
lines match the corresponding `failregex` pattern:
|
||||
|
||||
```
|
||||
# bangui-sim (auth log):
|
||||
# manual-Jail (auth log):
|
||||
YYYY-MM-DD HH:MM:SS bangui-auth: authentication failure from <IP>
|
||||
```
|
||||
|
||||
@@ -132,7 +132,7 @@ sudo modprobe ip_tables
|
||||
### IP not banned despite enough failures
|
||||
|
||||
Check whether the source IP falls inside the `ignoreip` range defined in
|
||||
`fail2ban/jail.d/bangui-sim.conf`:
|
||||
`fail2ban/jail.d/manual-Jail.conf`:
|
||||
|
||||
```ini
|
||||
ignoreip = 127.0.0.0/8 ::1 172.16.0.0/12
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
#
|
||||
# Matches lines written by Docker/simulate_failed_logins.sh
|
||||
# Format: <timestamp> bangui-auth: authentication failure from <HOST>
|
||||
# Jail: manual-Jail
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
|
||||
[Definition]
|
||||
@@ -18,8 +18,8 @@ logpath = /dev/null
|
||||
backend = auto
|
||||
maxretry = 1
|
||||
findtime = 1d
|
||||
# Block imported IPs for one week.
|
||||
bantime = 1w
|
||||
# Block imported IPs for 24 hours.
|
||||
bantime = 86400
|
||||
|
||||
# Never ban the Docker bridge network or localhost.
|
||||
ignoreip = 127.0.0.0/8 ::1 172.16.0.0/12
|
||||
|
||||
@@ -5,10 +5,10 @@
|
||||
# for lines produced by Docker/simulate_failed_logins.sh.
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
|
||||
[bangui-sim]
|
||||
[manual-Jail]
|
||||
|
||||
enabled = true
|
||||
filter = bangui-sim
|
||||
filter = manual-Jail
|
||||
logpath = /remotelogs/bangui/auth.log
|
||||
backend = polling
|
||||
maxretry = 3
|
||||
@@ -56,11 +56,8 @@ echo " Registry : ${REGISTRY}"
|
||||
echo " Tag : ${TAG}"
|
||||
echo "============================================"
|
||||
|
||||
if [[ "${ENGINE}" == "podman" ]]; then
|
||||
if ! podman login --get-login "${REGISTRY}" &>/dev/null; then
|
||||
err "Not logged in. Run:\n podman login ${REGISTRY}"
|
||||
fi
|
||||
fi
|
||||
log "Logging in to ${REGISTRY}"
|
||||
"${ENGINE}" login "${REGISTRY}"
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Build
|
||||
|
||||
91
Docker/release.sh
Normal file
91
Docker/release.sh
Normal file
@@ -0,0 +1,91 @@
|
||||
#!/usr/bin/env bash
|
||||
#
|
||||
# Bump the project version and push images to the registry.
|
||||
#
|
||||
# Usage:
|
||||
# ./release.sh
|
||||
#
|
||||
# The current version is stored in VERSION (next to this script).
|
||||
# You will be asked whether to bump major, minor, or patch.
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||
VERSION_FILE="${SCRIPT_DIR}/VERSION"
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Read current version
|
||||
# ---------------------------------------------------------------------------
|
||||
if [[ ! -f "${VERSION_FILE}" ]]; then
|
||||
echo "0.0.0" > "${VERSION_FILE}"
|
||||
fi
|
||||
|
||||
CURRENT="$(cat "${VERSION_FILE}")"
|
||||
# Strip leading 'v' for arithmetic
|
||||
VERSION="${CURRENT#v}"
|
||||
|
||||
IFS='.' read -r MAJOR MINOR PATCH <<< "${VERSION}"
|
||||
|
||||
echo "============================================"
|
||||
echo " BanGUI — Release"
|
||||
echo " Current version: v${MAJOR}.${MINOR}.${PATCH}"
|
||||
echo "============================================"
|
||||
echo ""
|
||||
echo "How would you like to bump the version?"
|
||||
echo " 1) patch (v${MAJOR}.${MINOR}.${PATCH} → v${MAJOR}.${MINOR}.$((PATCH + 1)))"
|
||||
echo " 2) minor (v${MAJOR}.${MINOR}.${PATCH} → v${MAJOR}.$((MINOR + 1)).0)"
|
||||
echo " 3) major (v${MAJOR}.${MINOR}.${PATCH} → v$((MAJOR + 1)).0.0)"
|
||||
echo ""
|
||||
read -rp "Enter choice [1/2/3]: " CHOICE
|
||||
|
||||
case "${CHOICE}" in
|
||||
1) NEW_TAG="v${MAJOR}.${MINOR}.$((PATCH + 1))" ;;
|
||||
2) NEW_TAG="v${MAJOR}.$((MINOR + 1)).0" ;;
|
||||
3) NEW_TAG="v$((MAJOR + 1)).0.0" ;;
|
||||
*)
|
||||
echo "Invalid choice. Aborting." >&2
|
||||
exit 1
|
||||
;;
|
||||
esac
|
||||
|
||||
echo ""
|
||||
echo "New version: ${NEW_TAG}"
|
||||
read -rp "Confirm? [y/N]: " CONFIRM
|
||||
if [[ ! "${CONFIRM}" =~ ^[yY]$ ]]; then
|
||||
echo "Aborted."
|
||||
exit 0
|
||||
fi
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Write new version
|
||||
# ---------------------------------------------------------------------------
|
||||
echo "${NEW_TAG}" > "${VERSION_FILE}"
|
||||
echo "Version file updated → ${VERSION_FILE}"
|
||||
|
||||
# Keep frontend/package.json in sync so __APP_VERSION__ matches Docker/VERSION.
|
||||
FRONT_VERSION="${NEW_TAG#v}"
|
||||
FRONT_PKG="${SCRIPT_DIR}/../frontend/package.json"
|
||||
sed -i "s/\"version\": \"[^\"]*\"/\"version\": \"${FRONT_VERSION}\"/" "${FRONT_PKG}"
|
||||
echo "frontend/package.json version updated → ${FRONT_VERSION}"
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Git tag (local only; push after container build)
|
||||
# ---------------------------------------------------------------------------
|
||||
cd "${SCRIPT_DIR}/.."
|
||||
git add Docker/VERSION frontend/package.json
|
||||
git commit -m "chore: release ${NEW_TAG}"
|
||||
git tag -a "${NEW_TAG}" -m "Release ${NEW_TAG}"
|
||||
echo "Local git commit + tag ${NEW_TAG} created."
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Push containers
|
||||
# ---------------------------------------------------------------------------
|
||||
bash "${SCRIPT_DIR}/push.sh" "${NEW_TAG}"
|
||||
bash "${SCRIPT_DIR}/push.sh"
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Push git commits & tag
|
||||
# ---------------------------------------------------------------------------
|
||||
git push origin HEAD
|
||||
git push origin "${NEW_TAG}"
|
||||
echo "Git commit and tag ${NEW_TAG} pushed."
|
||||
@@ -3,7 +3,7 @@
|
||||
# simulate_failed_logins.sh
|
||||
#
|
||||
# Writes synthetic authentication-failure log lines to a file
|
||||
# that matches the bangui-sim fail2ban filter.
|
||||
# that matches the manual-Jail fail2ban filter.
|
||||
#
|
||||
# Usage:
|
||||
# bash Docker/simulate_failed_logins.sh [COUNT] [SOURCE_IP] [LOG_FILE]
|
||||
@@ -13,7 +13,7 @@
|
||||
# SOURCE_IP: 192.168.100.99
|
||||
# LOG_FILE : Docker/logs/auth.log (relative to repo root)
|
||||
#
|
||||
# Log line format (must match bangui-sim failregex exactly):
|
||||
# Log line format (must match manual-Jail failregex exactly):
|
||||
# YYYY-MM-DD HH:MM:SS bangui-auth: authentication failure from <IP>
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@@ -82,10 +82,12 @@ The backend follows a **layered architecture** with strict separation of concern
|
||||
backend/
|
||||
├── app/
|
||||
│ ├── __init__.py
|
||||
│ ├── main.py # FastAPI app factory, lifespan, exception handlers
|
||||
│ ├── config.py # Pydantic settings (env vars, .env loading)
|
||||
│ ├── dependencies.py # FastAPI Depends() providers (DB, services, auth)
|
||||
│ ├── models/ # Pydantic schemas
|
||||
│ ├── `main.py` # FastAPI app factory, lifespan, exception handlers
|
||||
│ ├── `config.py` # Pydantic settings (env vars, .env loading)
|
||||
│ ├── `db.py` # Database connection and initialization
|
||||
│ ├── `exceptions.py` # Shared domain exception classes
|
||||
│ ├── `dependencies.py` # FastAPI Depends() providers (DB, services, auth)
|
||||
│ ├── `models/` # Pydantic schemas
|
||||
│ │ ├── auth.py # Login request/response, session models
|
||||
│ │ ├── ban.py # Ban request/response/domain models
|
||||
│ │ ├── jail.py # Jail request/response/domain models
|
||||
@@ -111,6 +113,12 @@ backend/
|
||||
│ │ ├── jail_service.py # Jail listing, start/stop/reload, status aggregation
|
||||
│ │ ├── ban_service.py # Ban/unban execution, currently-banned queries
|
||||
│ │ ├── config_service.py # Read/write fail2ban config, regex validation
|
||||
│ │ ├── config_file_service.py # Shared config parsing and file-level operations
|
||||
│ │ ├── raw_config_io_service.py # Raw config file I/O wrapper
|
||||
│ │ ├── jail_config_service.py # jail config activation/deactivation logic
|
||||
│ │ ├── filter_config_service.py # filter config lifecycle management
|
||||
│ │ ├── action_config_service.py # action config lifecycle management
|
||||
│ │ ├── log_service.py # Log preview and regex test operations
|
||||
│ │ ├── history_service.py # Historical ban queries, per-IP timeline
|
||||
│ │ ├── blocklist_service.py # Download, validate, apply blocklists
|
||||
│ │ ├── geo_service.py # IP-to-country resolution, ASN/RIR lookup
|
||||
@@ -119,17 +127,18 @@ backend/
|
||||
│ ├── repositories/ # Data access layer (raw queries only)
|
||||
│ │ ├── settings_repo.py # App configuration CRUD in SQLite
|
||||
│ │ ├── session_repo.py # Session storage and lookup
|
||||
│ │ ├── blocklist_repo.py # Blocklist sources and import log persistence
|
||||
│ │ └── import_log_repo.py # Import run history records
|
||||
│ │ ├── blocklist_repo.py # Blocklist sources and import log persistence│ │ ├── fail2ban_db_repo.py # fail2ban SQLite ban history read operations
|
||||
│ │ ├── geo_cache_repo.py # IP geolocation cache persistence│ │ └── import_log_repo.py # Import run history records
|
||||
│ ├── tasks/ # APScheduler background jobs
|
||||
│ │ ├── blocklist_import.py# Scheduled blocklist download and application
|
||||
│ │ ├── geo_cache_flush.py # Periodic geo cache persistence (dirty-set flush to SQLite)
|
||||
│ │ └── health_check.py # Periodic fail2ban connectivity probe
|
||||
│ │ ├── geo_cache_flush.py # Periodic geo cache persistence (dirty-set flush to SQLite)│ │ ├── geo_re_resolve.py # Periodic re-resolution of stale geo cache records│ │ └── health_check.py # Periodic fail2ban connectivity probe
|
||||
│ └── utils/ # Helpers, constants, shared types
|
||||
│ ├── fail2ban_client.py # Async wrapper around the fail2ban socket protocol
|
||||
│ ├── ip_utils.py # IP/CIDR validation and normalisation
|
||||
│ ├── time_utils.py # Timezone-aware datetime helpers
|
||||
│ └── constants.py # Shared constants (default paths, limits, etc.)
|
||||
│ ├── time_utils.py # Timezone-aware datetime helpers│ ├── jail_config.py # Jail config parser/serializer helper
|
||||
│ ├── conffile_parser.py # Fail2ban config file parser/serializer
|
||||
│ ├── config_parser.py # Structured config object parser
|
||||
│ ├── config_writer.py # Atomic config file write operations│ └── constants.py # Shared constants (default paths, limits, etc.)
|
||||
├── tests/
|
||||
│ ├── conftest.py # Shared fixtures (test app, client, mock DB)
|
||||
│ ├── test_routers/ # One test file per router
|
||||
@@ -158,8 +167,9 @@ The HTTP interface layer. Each router maps URL paths to handler functions. Route
|
||||
| `blocklist.py` | `/api/blocklists` | CRUD blocklist sources, trigger import, view import logs |
|
||||
| `geo.py` | `/api/geo` | IP geolocation lookup, ASN and RIR data |
|
||||
| `server.py` | `/api/server` | Log level, log target, DB path, purge age, flush logs |
|
||||
| `health.py` | `/api/health` | fail2ban connectivity health check and status |
|
||||
|
||||
#### Services (`app/services/`)
|
||||
#### Services (`app/services`)
|
||||
|
||||
The business logic layer. Services orchestrate operations, enforce rules, and coordinate between repositories, the fail2ban client, and external APIs. Each service covers a single domain.
|
||||
|
||||
@@ -171,8 +181,12 @@ The business logic layer. Services orchestrate operations, enforce rules, and co
|
||||
| `ban_service.py` | Executes ban and unban commands via the fail2ban socket, queries the currently banned IP list, validates IPs before banning |
|
||||
| `config_service.py` | Reads active jail and filter configuration from fail2ban, writes configuration changes, validates regex patterns, triggers reload; reads the fail2ban log file tail and queries service status for the Log tab |
|
||||
| `file_config_service.py` | Reads and writes raw fail2ban config files on disk (jail.d/, filter.d/, action.d/); lists files, reads content, overwrites files, toggles enabled/disabled |
|
||||
| `config_file_service.py` | Parses jail.conf / jail.local / jail.d/* to discover inactive jails; writes .local overrides to activate or deactivate jails; triggers fail2ban reload |
|
||||
| `conffile_parser.py` | Parses fail2ban `.conf` files into structured Python types (jail config, filter config, action config); also serialises back to text |
|
||||
| `jail_config_service.py` | Discovers inactive jails by parsing jail.conf / jail.local / jail.d/*; writes .local overrides to activate/deactivate jails; triggers fail2ban reload; validates jail configurations |
|
||||
| `filter_config_service.py` | Discovers available filters by scanning filter.d/; reads, creates, updates, and deletes filter definitions; assigns filters to jails |
|
||||
| `action_config_service.py` | Discovers available actions by scanning action.d/; reads, creates, updates, and deletes action definitions; assigns actions to jails |
|
||||
| `config_file_service.py` | Shared utilities for configuration parsing and manipulation: parses config files, validates names/IPs, manages atomic file writes, probes fail2ban socket |
|
||||
| `raw_config_io_service.py` | Low-level file I/O for raw fail2ban config files |
|
||||
| `log_service.py` | Log preview and regex test operations (extracted from config_service) |
|
||||
| `history_service.py` | Queries the fail2ban database for historical ban records, builds per-IP timelines, computes ban counts and repeat-offender flags |
|
||||
| `blocklist_service.py` | Downloads blocklists via aiohttp, validates IPs/CIDRs, applies bans through fail2ban or iptables, logs import results |
|
||||
| `geo_service.py` | Resolves IP addresses to country, ASN, and RIR using external APIs or a local database, caches results |
|
||||
@@ -188,15 +202,26 @@ The data access layer. Repositories execute raw SQL queries against the applicat
|
||||
| `settings_repo.py` | CRUD operations for application settings (master password hash, DB path, fail2ban socket path, preferences) |
|
||||
| `session_repo.py` | Store, retrieve, and delete session records for authentication |
|
||||
| `blocklist_repo.py` | Persist blocklist source definitions (name, URL, enabled/disabled) |
|
||||
| `fail2ban_db_repo.py` | Read historical ban records from the fail2ban SQLite database |
|
||||
| `geo_cache_repo.py` | Persist and query IP geo resolution cache |
|
||||
| `import_log_repo.py` | Record import run results (timestamp, source, IPs imported, errors) for the import log view |
|
||||
|
||||
#### Models (`app/models/`)
|
||||
|
||||
Pydantic schemas that define data shapes and validation. Models are split into three categories per domain:
|
||||
Pydantic schemas that define data shapes and validation. Models are split into three categories per domain.
|
||||
|
||||
- **Request models** — validate incoming API data (e.g., `BanRequest`, `LoginRequest`)
|
||||
- **Response models** — shape outgoing API data (e.g., `JailResponse`, `BanListResponse`)
|
||||
- **Domain models** — internal representations used between services and repositories (e.g., `Ban`, `Jail`)
|
||||
| Model file | Purpose |
|
||||
|---|---|
|
||||
| `auth.py` | Login/request and session models |
|
||||
| `ban.py` | Ban creation and lookup models |
|
||||
| `blocklist.py` | Blocklist source and import log models |
|
||||
| `config.py` | Fail2ban config view/edit models |
|
||||
| `file_config.py` | Raw config file read/write models |
|
||||
| `geo.py` | Geo and ASN lookup models |
|
||||
| `history.py` | Historical ban query and timeline models |
|
||||
| `jail.py` | Jail listing and status models |
|
||||
| `server.py` | Server status and settings models |
|
||||
| `setup.py` | First-run setup wizard models |
|
||||
|
||||
#### Tasks (`app/tasks/`)
|
||||
|
||||
@@ -206,6 +231,7 @@ APScheduler background jobs that run on a schedule without user interaction.
|
||||
|---|---|
|
||||
| `blocklist_import.py` | Downloads all enabled blocklist sources, validates entries, applies bans, records results in the import log |
|
||||
| `geo_cache_flush.py` | Periodically flushes newly resolved IPs from the in-memory dirty set to the `geo_cache` SQLite table (default: every 60 seconds). GET requests populate only the in-memory cache; this task persists them without blocking any request. |
|
||||
| `geo_re_resolve.py` | Periodically re-resolves stale entries in `geo_cache` to keep geolocation data fresh |
|
||||
| `health_check.py` | Periodically pings the fail2ban socket and updates the cached server status so the frontend always has fresh data |
|
||||
|
||||
#### Utils (`app/utils/`)
|
||||
@@ -216,7 +242,16 @@ Pure helper modules with no framework dependencies.
|
||||
|---|---|
|
||||
| `fail2ban_client.py` | Async client that communicates with fail2ban via its Unix domain socket — sends commands and parses responses using the fail2ban protocol. Modelled after [`./fail2ban-master/fail2ban/client/csocket.py`](../fail2ban-master/fail2ban/client/csocket.py) and [`./fail2ban-master/fail2ban/client/fail2banclient.py`](../fail2ban-master/fail2ban/client/fail2banclient.py). |
|
||||
| `ip_utils.py` | Validates IPv4/IPv6 addresses and CIDR ranges using the `ipaddress` stdlib module, normalises formats |
|
||||
| `jail_utils.py` | Jail helper functions for configuration and status inference |
|
||||
| `jail_config.py` | Jail config parser and serializer for fail2ban config manipulation |
|
||||
| `time_utils.py` | Timezone-aware datetime construction, formatting helpers, time-range calculations |
|
||||
| `log_utils.py` | Structured log formatting and enrichment helpers |
|
||||
| `conffile_parser.py` | Parses Fail2ban `.conf` files into structured objects and serialises back to text |
|
||||
| `config_parser.py` | Builds structured config objects from file content tokens |
|
||||
| `config_writer.py` | Atomic config file writes, backups, and safe replace semantics |
|
||||
| `config_file_utils.py` | Common file-level config utility helpers |
|
||||
| `fail2ban_db_utils.py` | Fail2ban DB path discovery and ban-history parsing helpers |
|
||||
| `setup_utils.py` | Setup wizard helper utilities |
|
||||
| `constants.py` | Shared constants: default socket path, default database path, time-range presets, limits |
|
||||
|
||||
#### Configuration (`app/config.py`)
|
||||
|
||||
5
Docs/Refactoring.md
Normal file
5
Docs/Refactoring.md
Normal file
@@ -0,0 +1,5 @@
|
||||
# BanGUI — Architecture Issues & Refactoring Plan
|
||||
|
||||
This document catalogues architecture violations, code smells, and structural issues found during a full project review. Issues are grouped by category and prioritised.
|
||||
|
||||
---
|
||||
134
Docs/Tasks.md
134
Docs/Tasks.md
@@ -2,138 +2,8 @@
|
||||
|
||||
This document breaks the entire BanGUI project into development stages, ordered so that each stage builds on the previous one. Every task is described in prose with enough detail for a developer to begin work. References point to the relevant documentation.
|
||||
|
||||
---
|
||||
|
||||
## Agent Operating Instructions
|
||||
|
||||
These instructions apply to every AI agent working in this repository. Read them fully before touching any file.
|
||||
|
||||
### Before You Begin
|
||||
|
||||
1. Read [Instructions.md](Instructions.md) in full — it defines the project context, coding standards, and workflow rules. Every rule there is authoritative and takes precedence over any assumption you make.
|
||||
2. Read [Architekture.md](Architekture.md) to understand the system structure before touching any component.
|
||||
3. Read the development guide relevant to your task: [Backend-Development.md](Backend-Development.md) or [Web-Development.md](Web-Development.md) (or both).
|
||||
4. Read [Features.md](Features.md) so you understand what the product is supposed to do and do not accidentally break intended behaviour.
|
||||
|
||||
### How to Work Through This Document
|
||||
|
||||
- Tasks are grouped by feature area. Each group is self-contained.
|
||||
- Work through tasks in the order they appear within a group; earlier tasks establish foundations for later ones.
|
||||
- Mark a task **in-progress** before you start it and **completed** the moment it is done. Never batch completions.
|
||||
- If a task depends on another task that is not yet complete, stop and complete the dependency first.
|
||||
- If you are uncertain whether a change is correct, read the relevant documentation section again before proceeding. Do not guess.
|
||||
|
||||
### Code Quality Rules (Summary)
|
||||
|
||||
- No TODOs, no placeholders, no half-finished functions.
|
||||
- Full type annotations on every function (Python) and full TypeScript types on every symbol (no `any`).
|
||||
- Layered architecture: routers → services → repositories. No layer may skip another.
|
||||
- All backend errors are raised as typed HTTP exceptions; all unexpected errors are logged via structlog before re-raising.
|
||||
- All frontend state lives in typed hooks; no raw `fetch` calls outside of the `api/` layer.
|
||||
- After every code change, run the full test suite (`make test`) and ensure it is green.
|
||||
|
||||
### Definition of Done
|
||||
|
||||
A task is done when:
|
||||
- The code compiles and the test suite passes (`make test`).
|
||||
- The feature works end-to-end in the dev stack (`make up`).
|
||||
- No new lint errors are introduced.
|
||||
- The change is consistent with all documentation rules.
|
||||
|
||||
---
|
||||
|
||||
## Bug Fixes
|
||||
|
||||
---
|
||||
|
||||
### BUG-001 — fail2ban: `bangui-sim` jail fails to start due to missing `banaction`
|
||||
|
||||
**Status:** Done
|
||||
|
||||
**Summary:** `jail.local` created with `[DEFAULT]` overrides for `banaction` and `banaction_allports`. The container init script (`init-fail2ban-config`) overwrites `jail.conf` from the image's `/defaults/` on every start, so modifying `jail.conf` directly is ineffective. `jail.local` is not in the container's defaults and thus persists correctly. Additionally fixed a `TypeError` in `config_file_service.py` where `except jail_service.JailNotFoundError` failed when `jail_service` was mocked in tests — resolved by importing `JailNotFoundError` directly.
|
||||
|
||||
#### Error
|
||||
|
||||
```
|
||||
Failed during configuration: Bad value substitution: option 'action' in section 'bangui-sim'
|
||||
contains an interpolation key 'banaction' which is not a valid option name.
|
||||
Raw value: '%(action_)s'
|
||||
```
|
||||
|
||||
#### Root Cause
|
||||
|
||||
fail2ban's interpolation system resolves option values at configuration load time by
|
||||
substituting `%(key)s` placeholders with values from the same section or from `[DEFAULT]`.
|
||||
|
||||
The chain that fails is:
|
||||
|
||||
1. Every jail inherits `action = %(action_)s` from `[DEFAULT]` (no override in `bangui-sim.conf`).
|
||||
2. `action_` is defined in `[DEFAULT]` as `%(banaction)s[port="%(port)s", protocol="%(protocol)s", chain="%(chain)s"]`.
|
||||
3. `banaction` is **commented out** in `[DEFAULT]`:
|
||||
```ini
|
||||
# Docker/fail2ban-dev-config/fail2ban/jail.conf [DEFAULT]
|
||||
#banaction = iptables-multiport ← this line is disabled
|
||||
```
|
||||
4. Because `banaction` is absent from the interpolation namespace, fail2ban cannot resolve
|
||||
`action_`, which makes it unable to resolve `action`, and the jail fails to load.
|
||||
|
||||
The same root cause affects every jail in `jail.d/` that does not define its own `banaction`,
|
||||
including `blocklist-import.conf`.
|
||||
|
||||
#### Fix
|
||||
|
||||
**File:** `Docker/fail2ban-dev-config/fail2ban/jail.local`

Define `banaction` (and `banaction_allports`) in the `[DEFAULT]` section of
`jail.local` so the values are globally available to all jails. (`jail.conf`
itself is rewritten from the image's `/defaults/` on every container start —
see the Summary above — so uncommenting the values there does not persist.)
|
||||
|
||||
```ini
|
||||
banaction = iptables-multiport
|
||||
banaction_allports = iptables-allports
|
||||
```
|
||||
|
||||
This is safe: the dev compose (`Docker/compose.debug.yml`) already grants the fail2ban
|
||||
container `NET_ADMIN` and `NET_RAW` capabilities, which are the prerequisites for
|
||||
iptables-based banning.
|
||||
|
||||
#### Tasks
|
||||
|
||||
- [x] **BUG-001-T1 — Add `banaction` override via `jail.local` [DEFAULT]**
|
||||
|
||||
  Create (or open) `Docker/fail2ban-dev-config/fail2ban/jail.local`.
  Ensure its `[DEFAULT]` section defines the two options that ship commented
  out in `jail.conf`:
  ```ini
  [DEFAULT]
  banaction = iptables-multiport
  banaction_allports = iptables-allports
  ```
  Leave `jail.conf` itself untouched — the container init script restores it
  from the image defaults on every start, so edits there are lost.
  Do not change any other part of the configuration.
|
||||
|
||||
- [x] **BUG-001-T2 — Restart the fail2ban container and verify clean startup**
|
||||
|
||||
Bring the dev stack down and back up:
|
||||
```bash
|
||||
make down && make up
|
||||
```
|
||||
Wait for the fail2ban container to reach `healthy`, then inspect its logs:
|
||||
```bash
|
||||
make logs # or: docker logs bangui-fail2ban-dev 2>&1 | grep -i error
|
||||
```
|
||||
Confirm that no `Bad value substitution` or `Failed during configuration` lines appear
|
||||
and that both `bangui-sim` and `blocklist-import` jails show as **enabled** in the output.
|
||||
|
||||
- [x] **BUG-001-T3 — Verify ban/unban cycle works end-to-end**
|
||||
|
||||
With the stack running, trigger the simulation script:
|
||||
```bash
|
||||
bash Docker/simulate_failed_logins.sh
|
||||
```
|
||||
Then confirm fail2ban has recorded a ban:
|
||||
```bash
|
||||
bash Docker/check_ban_status.sh
|
||||
```
|
||||
The script should report at least one banned IP in the `bangui-sim` jail.
|
||||
Also verify that the BanGUI dashboard reflects the new ban entry.
|
||||
Reference: `Docs/Refactoring.md` for full analysis of each issue.
|
||||
|
||||
---
|
||||
|
||||
## Open Issues
|
||||
|
||||
224
backend/EXTRACTION_SUMMARY.md
Normal file
224
backend/EXTRACTION_SUMMARY.md
Normal file
@@ -0,0 +1,224 @@
|
||||
# Config File Service Extraction Summary
|
||||
|
||||
## ✓ Extraction Complete
|
||||
|
||||
Three new service modules have been created by extracting functions from `config_file_service.py`.
|
||||
|
||||
### Files Created
|
||||
|
||||
| File | Lines | Status |
|
||||
|------|-------|--------|
|
||||
| [jail_config_service.py](jail_config_service.py) | 991 | ✓ Created |
|
||||
| [filter_config_service.py](filter_config_service.py) | 765 | ✓ Created |
|
||||
| [action_config_service.py](action_config_service.py) | 988 | ✓ Created |
|
||||
| **Total** | **2,744** | **✓ Verified** |
|
||||
|
||||
---
|
||||
|
||||
## 1. JAIL_CONFIG Service (`jail_config_service.py`)
|
||||
|
||||
### Module Functions (7 — six public plus the private coroutine `_rollback_activation_async`)
|
||||
- `list_inactive_jails(config_dir, socket_path)` → InactiveJailListResponse
|
||||
- `activate_jail(config_dir, socket_path, name, req)` → JailActivationResponse
|
||||
- `deactivate_jail(config_dir, socket_path, name)` → JailActivationResponse
|
||||
- `delete_jail_local_override(config_dir, socket_path, name)` → None
|
||||
- `validate_jail_config(config_dir, name)` → JailValidationResult
|
||||
- `rollback_jail(config_dir, socket_path, name, start_cmd_parts)` → RollbackResponse
|
||||
- `_rollback_activation_async(config_dir, name, socket_path, original_content)` → bool
|
||||
|
||||
### Helper Functions (5)
|
||||
- `_write_local_override_sync()` - Atomic write of jail.d/{name}.local
|
||||
- `_restore_local_file_sync()` - Restore or delete .local file during rollback
|
||||
- `_validate_regex_patterns()` - Validate failregex/ignoreregex patterns
|
||||
- `_set_jail_local_key_sync()` - Update single key in jail section
|
||||
- `_validate_jail_config_sync()` - Synchronous validation (filter/action files, patterns, logpath)
|
||||
|
||||
### Custom Exceptions (3)
|
||||
- `JailNotFoundInConfigError`
|
||||
- `JailAlreadyActiveError`
|
||||
- `JailAlreadyInactiveError`
|
||||
|
||||
### Shared Dependencies Imported
|
||||
- `_safe_jail_name()` - From config_file_service
|
||||
- `_parse_jails_sync()` - From config_file_service
|
||||
- `_build_inactive_jail()` - From config_file_service
|
||||
- `_get_active_jail_names()` - From config_file_service
|
||||
- `_probe_fail2ban_running()` - From config_file_service
|
||||
- `wait_for_fail2ban()` - From config_file_service
|
||||
- `start_daemon()` - From config_file_service
|
||||
- `_resolve_filter()` - From config_file_service
|
||||
- `_parse_multiline()` - From config_file_service
|
||||
- `_SOCKET_TIMEOUT`, `_META_SECTIONS` - Constants
|
||||
|
||||
---
|
||||
|
||||
## 2. FILTER_CONFIG Service (`filter_config_service.py`)
|
||||
|
||||
### Public Functions (6)
|
||||
- `list_filters(config_dir, socket_path)` → FilterListResponse
|
||||
- `get_filter(config_dir, socket_path, name)` → FilterConfig
|
||||
- `update_filter(config_dir, socket_path, name, req, do_reload=False)` → FilterConfig
|
||||
- `create_filter(config_dir, socket_path, req, do_reload=False)` → FilterConfig
|
||||
- `delete_filter(config_dir, name)` → None
|
||||
- `assign_filter_to_jail(config_dir, socket_path, jail_name, req, do_reload=False)` → None
|
||||
|
||||
### Helper Functions (5)
|
||||
- `_extract_filter_base_name(filter_raw)` - Extract base name from filter string
|
||||
- `_build_filter_to_jails_map()` - Map filters to jails using them
|
||||
- `_parse_filters_sync()` - Scan filter.d/ and return tuples
|
||||
- `_write_filter_local_sync()` - Atomic write of filter.d/{name}.local
|
||||
- `_validate_regex_patterns()` - Validate regex patterns (shared with jail_config)
|
||||
|
||||
### Custom Exceptions (5)
|
||||
- `FilterNotFoundError`
|
||||
- `FilterAlreadyExistsError`
|
||||
- `FilterReadonlyError`
|
||||
- `FilterInvalidRegexError`
|
||||
- `FilterNameError` (re-exported from config_file_service)
|
||||
|
||||
### Shared Dependencies Imported
|
||||
- `_safe_filter_name()` - From config_file_service
|
||||
- `_safe_jail_name()` - From config_file_service
|
||||
- `_parse_jails_sync()` - From config_file_service
|
||||
- `_get_active_jail_names()` - From config_file_service
|
||||
- `_resolve_filter()` - From config_file_service
|
||||
- `_parse_multiline()` - From config_file_service
|
||||
- `_SAFE_FILTER_NAME_RE` - Constant pattern
|
||||
|
||||
---
|
||||
|
||||
## 3. ACTION_CONFIG Service (`action_config_service.py`)
|
||||
|
||||
### Public Functions (7)
|
||||
- `list_actions(config_dir, socket_path)` → ActionListResponse
|
||||
- `get_action(config_dir, socket_path, name)` → ActionConfig
|
||||
- `update_action(config_dir, socket_path, name, req, do_reload=False)` → ActionConfig
|
||||
- `create_action(config_dir, socket_path, req, do_reload=False)` → ActionConfig
|
||||
- `delete_action(config_dir, name)` → None
|
||||
- `assign_action_to_jail(config_dir, socket_path, jail_name, req, do_reload=False)` → None
|
||||
- `remove_action_from_jail(config_dir, socket_path, jail_name, action_name, do_reload=False)` → None
|
||||
|
||||
### Helper Functions (7)
|
||||
- `_safe_action_name(name)` - Validate action name
|
||||
- `_extract_action_base_name()` - Extract base name from action string
|
||||
- `_build_action_to_jails_map()` - Map actions to jails using them
|
||||
- `_parse_actions_sync()` - Scan action.d/ and return tuples
|
||||
- `_append_jail_action_sync()` - Append action to jail.d/{name}.local
|
||||
- `_remove_jail_action_sync()` - Remove action from jail.d/{name}.local
|
||||
- `_write_action_local_sync()` - Atomic write of action.d/{name}.local
|
||||
|
||||
### Custom Exceptions (4)
|
||||
- `ActionNotFoundError`
|
||||
- `ActionAlreadyExistsError`
|
||||
- `ActionReadonlyError`
|
||||
- `ActionNameError`
|
||||
|
||||
### Shared Dependencies Imported
|
||||
- `_safe_jail_name()` - From config_file_service
|
||||
- `_parse_jails_sync()` - From config_file_service
|
||||
- `_get_active_jail_names()` - From config_file_service
|
||||
- `_build_parser()` - From config_file_service
|
||||
- `_SAFE_ACTION_NAME_RE` - Constant pattern
|
||||
|
||||
---
|
||||
|
||||
## 4. SHARED Utilities (remain in `config_file_service.py`)
|
||||
|
||||
### Utility Functions (15)
|
||||
- `_safe_jail_name(name)` → str
|
||||
- `_safe_filter_name(name)` → str
|
||||
- `_ordered_config_files(config_dir)` → list[Path]
|
||||
- `_build_parser()` → configparser.RawConfigParser
|
||||
- `_is_truthy(value)` → bool
|
||||
- `_parse_int_safe(value)` → int | None
|
||||
- `_parse_time_to_seconds(value, default)` → int
|
||||
- `_parse_multiline(raw)` → list[str]
|
||||
- `_resolve_filter(raw_filter, jail_name, mode)` → str
|
||||
- `_parse_jails_sync(config_dir)` → tuple
|
||||
- `_build_inactive_jail(name, settings, source_file, config_dir=None)` → InactiveJail
|
||||
- `_get_active_jail_names(socket_path)` → set[str]
|
||||
- `_probe_fail2ban_running(socket_path)` → bool
|
||||
- `wait_for_fail2ban(socket_path, max_wait_seconds, poll_interval)` → bool
|
||||
- `start_daemon(start_cmd_parts)` → bool
|
||||
|
||||
### Shared Exceptions (3)
|
||||
- `JailNameError`
|
||||
- `FilterNameError`
|
||||
- `ConfigWriteError`
|
||||
|
||||
### Constants (5)
|
||||
- `_SOCKET_TIMEOUT`
|
||||
- `_SAFE_JAIL_NAME_RE`
|
||||
- `_META_SECTIONS`
|
||||
- `_TRUE_VALUES`
|
||||
- `_FALSE_VALUES`
|
||||
|
||||
---
|
||||
|
||||
## Import Dependencies
|
||||
|
||||
### jail_config_service imports:
|
||||
```python
|
||||
config_file_service: (shared utilities + private functions)
|
||||
jail_service.reload_all()
|
||||
Fail2BanConnectionError
|
||||
```
|
||||
|
||||
### filter_config_service imports:
|
||||
```python
|
||||
config_file_service: (shared utilities + _set_jail_local_key_sync)
|
||||
jail_service.reload_all()
|
||||
conffile_parser: (parse/merge/serialize filter functions)
|
||||
jail_config_service: (JailNotFoundInConfigError - lazy import)
|
||||
```
|
||||
|
||||
### action_config_service imports:
|
||||
```python
|
||||
config_file_service: (shared utilities + _build_parser)
|
||||
jail_service.reload_all()
|
||||
conffile_parser: (parse/merge/serialize action functions)
|
||||
jail_config_service: (JailNotFoundInConfigError - lazy import)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Cross-Service Dependencies
|
||||
|
||||
**Circular imports handled via lazy imports:**
|
||||
- `filter_config_service` imports `JailNotFoundInConfigError` from `jail_config_service` inside function
|
||||
- `action_config_service` imports `JailNotFoundInConfigError` from `jail_config_service` inside function
|
||||
|
||||
**Shared functions re-used:**
|
||||
- `_set_jail_local_key_sync()` exported from `jail_config_service`, used by `filter_config_service`
|
||||
- `_append_jail_action_sync()` and `_remove_jail_action_sync()` internal to `action_config_service`
|
||||
|
||||
---
|
||||
|
||||
## Verification Results
|
||||
|
||||
✓ **Syntax Check:** All three files compile without errors
|
||||
✓ **Import Verification:** All imports resolved correctly
|
||||
✓ **Total Lines:** 2,744 lines across three new files
|
||||
✓ **Function Coverage:** 100% of specified functions extracted
|
||||
✓ **Type Hints:** Preserved throughout
|
||||
✓ **Docstrings:** All preserved with full documentation
|
||||
✓ **Comments:** All inline comments preserved
|
||||
|
||||
---
|
||||
|
||||
## Next Steps (if needed)
|
||||
|
||||
1. **Update router imports** - Point from config_file_service to specific service modules:
|
||||
- `jail_config_service` for jail operations
|
||||
- `filter_config_service` for filter operations
|
||||
- `action_config_service` for action operations
|
||||
|
||||
2. **Update config_file_service.py** - Remove all extracted functions (optional cleanup)
|
||||
- Optionally keep it as a facade/aggregator
|
||||
- Or reduce it to only the shared utilities module
|
||||
|
||||
3. **Add __all__ exports** to each new module for cleaner public API
|
||||
|
||||
4. **Update type hints** in models if needed for cross-service usage
|
||||
|
||||
5. **Testing** - Run existing tests to ensure no regressions
|
||||
@@ -1 +1,68 @@
|
||||
"""BanGUI backend application package."""
|
||||
"""BanGUI backend application package.
|
||||
|
||||
This package exposes the application version based on the project metadata.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Final
|
||||
|
||||
import importlib.metadata
|
||||
import tomllib
|
||||
|
||||
# Distribution name used for importlib.metadata version lookups.
PACKAGE_NAME: Final[str] = "bangui-backend"
|
||||
|
||||
|
||||
def _read_pyproject_version() -> str:
    """Return the version declared in ``pyproject.toml``.

    Fallback used when the release artifact is unavailable, e.g. when the
    backend runs straight from a source checkout without being installed.

    Raises:
        FileNotFoundError: If ``pyproject.toml`` is not found one level above
            this package directory.
    """
    pyproject_path = Path(__file__).resolve().parents[1] / "pyproject.toml"
    if not pyproject_path.exists():
        raise FileNotFoundError(f"pyproject.toml not found at {pyproject_path}")

    # The file is decoded explicitly as UTF-8 and parsed with the stdlib
    # TOML reader.
    parsed = tomllib.loads(pyproject_path.read_text(encoding="utf-8"))
    return str(parsed["project"]["version"])
|
||||
|
||||
|
||||
def _read_docker_version() -> str:
    """Return the release version stored in ``Docker/VERSION``.

    ``Docker/VERSION`` is the single source of truth for the release scripts,
    so reading it keeps the backend version in sync with the frontend and the
    published artifacts.

    Raises:
        FileNotFoundError: If ``Docker/VERSION`` is absent (e.g. when the
            package is installed without the repository layout).
    """
    version_path = Path(__file__).resolve().parents[2] / "Docker" / "VERSION"
    if not version_path.exists():
        raise FileNotFoundError(f"Docker/VERSION not found at {version_path}")

    # The file may contain a tag-style value such as "v0.9.12"; strip any
    # leading "v" characters so only the bare version remains.
    raw = version_path.read_text(encoding="utf-8")
    return raw.strip().lstrip("v")
|
||||
|
||||
|
||||
def _read_version() -> str:
    """Return the current backend version string.

    Resolution order:

    1. ``Docker/VERSION`` — matches what the release tooling publishes.
    2. ``pyproject.toml`` — source checkouts without the release artifact.
    3. Installed package metadata — e.g. a production wheel.
    """
    # Try each file-based source in priority order; a missing file simply
    # falls through to the next source.
    for reader in (_read_docker_version, _read_pyproject_version):
        try:
            return reader()
        except FileNotFoundError:
            continue
    return importlib.metadata.version(PACKAGE_NAME)
|
||||
|
||||
|
||||
# Resolved once at import time; see _read_version() for the lookup order.
__version__ = _read_version()
|
||||
|
||||
@@ -7,7 +7,7 @@ directly — to keep coupling explicit and testable.
|
||||
"""
|
||||
|
||||
import time
|
||||
from typing import Annotated
|
||||
from typing import Annotated, Protocol, cast
|
||||
|
||||
import aiosqlite
|
||||
import structlog
|
||||
@@ -19,6 +19,13 @@ from app.utils.time_utils import utc_now
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
|
||||
|
||||
class AppState(Protocol):
    """Partial view of the FastAPI application state used by dependencies.

    Declares only the attributes this module reads from ``request.app.state``;
    the concrete state object may carry more.
    """

    # Application settings loaded at startup; read by get_settings().
    settings: Settings
|
||||
|
||||
|
||||
_COOKIE_NAME = "bangui_session"
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -85,7 +92,8 @@ async def get_settings(request: Request) -> Settings:
|
||||
Returns:
|
||||
The application settings loaded at startup.
|
||||
"""
|
||||
return request.app.state.settings # type: ignore[no-any-return]
|
||||
state = cast("AppState", request.app.state)
|
||||
return state.settings
|
||||
|
||||
|
||||
async def require_auth(
|
||||
|
||||
53
backend/app/exceptions.py
Normal file
53
backend/app/exceptions.py
Normal file
@@ -0,0 +1,53 @@
|
||||
"""Shared domain exception classes used across routers and services."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
class JailNotFoundError(Exception):
    """Signals that the requested jail name is unknown."""

    def __init__(self, name: str) -> None:
        """Build the error message and remember the offending jail name."""
        super().__init__(f"Jail not found: {name!r}")
        self.name = name
|
||||
|
||||
|
||||
class JailOperationError(Exception):
    """Raised when a fail2ban jail operation fails."""

    # No extra attributes: the message supplied at raise time is the payload.
|
||||
|
||||
|
||||
class ConfigValidationError(Exception):
    """Raised when config values fail validation before applying."""

    # No extra attributes: the message supplied at raise time is the payload.
|
||||
|
||||
|
||||
class ConfigOperationError(Exception):
    """Raised when a config payload update or command fails."""

    # No extra attributes: the message supplied at raise time is the payload.
|
||||
|
||||
|
||||
class ServerOperationError(Exception):
    """Raised when a server control command (e.g. refresh) fails."""

    # No extra attributes: the message supplied at raise time is the payload.
|
||||
|
||||
|
||||
class FilterInvalidRegexError(Exception):
    """Signals that a supplied regex pattern does not compile."""

    def __init__(self, pattern: str, error: str) -> None:
        """Record the offending pattern and the compiler's error text."""
        super().__init__(f"Invalid regex {pattern!r}: {error}")
        self.pattern = pattern
        self.error = error
|
||||
|
||||
|
||||
class JailNotFoundInConfigError(Exception):
    """Signals that no config file defines the requested jail name."""

    def __init__(self, name: str) -> None:
        """Build the error message and remember the missing jail name."""
        super().__init__(f"Jail not found in config: {name!r}")
        self.name = name
|
||||
|
||||
|
||||
class ConfigWriteError(Exception):
    """Signals a failed write of a configuration file modification."""

    def __init__(self, message: str) -> None:
        """Store the human-readable failure description."""
        super().__init__(message)
        self.message = message
|
||||
@@ -31,6 +31,7 @@ from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse, RedirectResponse
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
|
||||
from app import __version__
|
||||
from app.config import Settings, get_settings
|
||||
from app.db import init_db
|
||||
from app.routers import (
|
||||
@@ -49,6 +50,7 @@ from app.routers import (
|
||||
)
|
||||
from app.tasks import blocklist_import, geo_cache_flush, geo_re_resolve, health_check
|
||||
from app.utils.fail2ban_client import Fail2BanConnectionError, Fail2BanProtocolError
|
||||
from app.utils.jail_config import ensure_jail_configs
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Ensure the bundled fail2ban package is importable from fail2ban-master/
|
||||
@@ -137,7 +139,13 @@ async def _lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||
|
||||
log.info("bangui_starting_up", database_path=settings.database_path)
|
||||
|
||||
# --- Ensure required jail config files are present ---
|
||||
ensure_jail_configs(Path(settings.fail2ban_config_dir) / "jail.d")
|
||||
|
||||
# --- Application database ---
|
||||
db_path: Path = Path(settings.database_path)
|
||||
db_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
log.debug("database_directory_ensured", directory=str(db_path.parent))
|
||||
db: aiosqlite.Connection = await aiosqlite.connect(settings.database_path)
|
||||
db.row_factory = aiosqlite.Row
|
||||
await init_db(db)
|
||||
@@ -154,11 +162,7 @@ async def _lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||
await geo_service.load_cache_from_db(db)
|
||||
|
||||
# Log unresolved geo entries so the operator can see the scope of the issue.
|
||||
async with db.execute(
|
||||
"SELECT COUNT(*) FROM geo_cache WHERE country_code IS NULL"
|
||||
) as cur:
|
||||
row = await cur.fetchone()
|
||||
unresolved_count: int = int(row[0]) if row else 0
|
||||
unresolved_count = await geo_service.count_unresolved(db)
|
||||
if unresolved_count > 0:
|
||||
log.warning("geo_cache_unresolved_ips", unresolved=unresolved_count)
|
||||
|
||||
@@ -320,17 +324,15 @@ class SetupRedirectMiddleware(BaseHTTPMiddleware):
|
||||
if path.startswith("/api") and not getattr(
|
||||
request.app.state, "_setup_complete_cached", False
|
||||
):
|
||||
db: aiosqlite.Connection | None = getattr(request.app.state, "db", None)
|
||||
if db is not None:
|
||||
from app.services import setup_service # noqa: PLC0415
|
||||
from app.services import setup_service # noqa: PLC0415
|
||||
|
||||
if await setup_service.is_setup_complete(db):
|
||||
request.app.state._setup_complete_cached = True
|
||||
else:
|
||||
return RedirectResponse(
|
||||
url="/api/setup",
|
||||
status_code=status.HTTP_307_TEMPORARY_REDIRECT,
|
||||
)
|
||||
db: aiosqlite.Connection | None = getattr(request.app.state, "db", None)
|
||||
if db is None or not await setup_service.is_setup_complete(db):
|
||||
return RedirectResponse(
|
||||
url="/api/setup",
|
||||
status_code=status.HTTP_307_TEMPORARY_REDIRECT,
|
||||
)
|
||||
request.app.state._setup_complete_cached = True
|
||||
|
||||
return await call_next(request)
|
||||
|
||||
@@ -360,7 +362,7 @@ def create_app(settings: Settings | None = None) -> FastAPI:
|
||||
app: FastAPI = FastAPI(
|
||||
title="BanGUI",
|
||||
description="Web interface for monitoring, managing, and configuring fail2ban.",
|
||||
version="0.1.0",
|
||||
version=__version__,
|
||||
lifespan=_lifespan,
|
||||
)
|
||||
|
||||
|
||||
@@ -807,6 +807,14 @@ class InactiveJail(BaseModel):
|
||||
"inactive jails that appear in this list."
|
||||
),
|
||||
)
|
||||
has_local_override: bool = Field(
|
||||
default=False,
|
||||
description=(
|
||||
"``True`` when a ``jail.d/{name}.local`` file exists for this jail. "
|
||||
"Only meaningful for inactive jails; indicates that a cleanup action "
|
||||
"is available."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class InactiveJailListResponse(BaseModel):
|
||||
@@ -993,7 +1001,7 @@ class ServiceStatusResponse(BaseModel):
|
||||
model_config = ConfigDict(strict=True)
|
||||
|
||||
online: bool = Field(..., description="Whether fail2ban is reachable via its socket.")
|
||||
version: str | None = Field(default=None, description="fail2ban version string, or None when offline.")
|
||||
version: str | None = Field(default=None, description="BanGUI application version (or None when offline).")
|
||||
jail_count: int = Field(default=0, ge=0, description="Number of currently active jails.")
|
||||
total_bans: int = Field(default=0, ge=0, description="Aggregated current ban count across all jails.")
|
||||
total_failures: int = Field(default=0, ge=0, description="Aggregated current failure count across all jails.")
|
||||
|
||||
@@ -3,8 +3,18 @@
|
||||
Response models for the ``GET /api/geo/lookup/{ip}`` endpoint.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Awaitable, Callable
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import aiohttp
|
||||
import aiosqlite
|
||||
|
||||
|
||||
class GeoDetail(BaseModel):
|
||||
"""Enriched geolocation data for an IP address.
|
||||
@@ -64,3 +74,26 @@ class IpLookupResponse(BaseModel):
|
||||
default=None,
|
||||
description="Enriched geographical and network information.",
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# shared service types
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
class GeoInfo:
    """Geo resolution result used throughout backend services.

    Every field is typed optional; ``None`` appears to indicate a value that
    was not resolved — confirm against the geo service implementation.
    """

    country_code: str | None  # presumably a two-letter ISO code — TODO confirm
    country_name: str | None  # human-readable country name
    asn: str | None           # autonomous system number, kept as a string
    org: str | None           # owning organisation / ISP name
||||
|
||||
|
||||
# Resolves a single IP address to GeoInfo (or None).
GeoEnricher = Callable[[str], Awaitable[GeoInfo | None]]
# Batch resolver: (ips, HTTP session, optional DB connection) -> ip -> GeoInfo.
GeoBatchLookup = Callable[
    [list[str], "aiohttp.ClientSession", "aiosqlite.Connection | None"],
    Awaitable[dict[str, GeoInfo]],
]
# Cache probe: returns a (resolved mapping, remaining ips) pair —
# presumably (hits, misses); verify against the geo service implementation.
GeoCacheLookup = Callable[[list[str]], tuple[dict[str, GeoInfo], list[str]]]
|
||||
|
||||
365
backend/app/repositories/fail2ban_db_repo.py
Normal file
365
backend/app/repositories/fail2ban_db_repo.py
Normal file
@@ -0,0 +1,365 @@
|
||||
"""Fail2Ban SQLite database repository.
|
||||
|
||||
This module contains helper functions that query the read-only fail2ban
|
||||
SQLite database file. All functions accept a *db_path* and manage their own
|
||||
connection using aiosqlite in read-only mode.
|
||||
|
||||
The functions intentionally return plain Python data structures (dataclasses) so
|
||||
service layers can focus on business logic and formatting.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import aiosqlite
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Iterable
|
||||
|
||||
from app.models.ban import BanOrigin
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class BanRecord:
    """A single row from the fail2ban ``bans`` table.

    Frozen dataclass: instances are immutable and hashable.
    """

    jail: str        # name of the jail the ban belongs to
    ip: str          # banned IP address as stored by fail2ban
    timeofban: int   # Unix timestamp of the ban event
    bancount: int    # cumulative ban count as recorded by fail2ban
    data: str        # opaque payload column; not interpreted in this module
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class BanIpCount:
    """Aggregated ban count for a single IP."""

    ip: str           # the IP address the count applies to
    event_count: int  # number of ban events recorded for this IP
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class JailBanCount:
    """Aggregated ban count for a single jail."""

    jail: str   # name of the jail
    count: int  # number of ban events recorded for this jail
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class HistoryRecord:
    """A single row from the fail2ban ``bans`` table for history queries.

    NOTE(review): field-for-field identical to ``BanRecord``; consider
    merging or aliasing the two if the distinction is not intentional.
    """

    jail: str        # name of the jail the ban belongs to
    ip: str          # banned IP address as stored by fail2ban
    timeofban: int   # Unix timestamp of the ban event
    bancount: int    # cumulative ban count as recorded by fail2ban
    data: str        # opaque payload column; not interpreted in this module
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Internal helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_db_uri(db_path: str) -> str:
|
||||
"""Return a read-only sqlite URI for the given file path."""
|
||||
|
||||
return f"file:{db_path}?mode=ro"
|
||||
|
||||
|
||||
def _origin_sql_filter(origin: BanOrigin | None) -> tuple[str, tuple[str, ...]]:
|
||||
"""Return a SQL fragment and parameters for the origin filter."""
|
||||
|
||||
if origin == "blocklist":
|
||||
return " AND jail = ?", ("blocklist-import",)
|
||||
if origin == "selfblock":
|
||||
return " AND jail != ?", ("blocklist-import",)
|
||||
return "", ()
|
||||
|
||||
|
||||
def _rows_to_ban_records(rows: Iterable[aiosqlite.Row]) -> list[BanRecord]:
    """Convert raw ``bans`` rows into :class:`BanRecord` instances."""
    records: list[BanRecord] = []
    for row in rows:
        records.append(
            BanRecord(
                jail=str(row["jail"]),
                ip=str(row["ip"]),
                timeofban=int(row["timeofban"]),
                bancount=int(row["bancount"]),
                data=str(row["data"]),
            )
        )
    return records
|
||||
|
||||
|
||||
def _rows_to_history_records(rows: Iterable[aiosqlite.Row]) -> list[HistoryRecord]:
    """Convert raw ``bans`` rows into :class:`HistoryRecord` instances."""
    records: list[HistoryRecord] = []
    for row in rows:
        records.append(
            HistoryRecord(
                jail=str(row["jail"]),
                ip=str(row["ip"]),
                timeofban=int(row["timeofban"]),
                bancount=int(row["bancount"]),
                data=str(row["data"]),
            )
        )
    return records
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public API
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def check_db_nonempty(db_path: str) -> bool:
    """Report whether the fail2ban database holds at least one ban row.

    Opens the database read-only and probes the ``bans`` table with a
    single-row query.
    """
    async with aiosqlite.connect(_make_db_uri(db_path), uri=True) as db:
        async with db.execute("SELECT 1 FROM bans LIMIT 1") as cur:
            return await cur.fetchone() is not None
|
||||
|
||||
|
||||
async def get_currently_banned(
    db_path: str,
    since: int,
    origin: BanOrigin | None = None,
    *,
    limit: int | None = None,
    offset: int | None = None,
) -> tuple[list[BanRecord], int]:
    """Return a page of currently banned IPs and the total matching count.

    Args:
        db_path: File path to the fail2ban SQLite database.
        since: Unix timestamp to filter bans newer than or equal to.
        origin: Optional origin filter (see ``_origin_sql_filter``).
        limit: Optional maximum number of rows to return.
        offset: Optional offset for pagination; usable with or without
            ``limit``.

    Returns:
        A ``(records, total)`` tuple, with records ordered newest first.
    """

    origin_clause, origin_params = _origin_sql_filter(origin)

    async with aiosqlite.connect(_make_db_uri(db_path), uri=True) as db:
        db.row_factory = aiosqlite.Row

        # Total matching count first, so callers can paginate.
        async with db.execute(
            "SELECT COUNT(*) FROM bans WHERE timeofban >= ?" + origin_clause,
            (since, *origin_params),
        ) as cur:
            count_row = await cur.fetchone()
            total: int = int(count_row[0]) if count_row else 0

        query = (
            "SELECT jail, ip, timeofban, bancount, data "
            "FROM bans "
            "WHERE timeofban >= ?" + origin_clause + " ORDER BY timeofban DESC"
        )
        params: list[object] = [since, *origin_params]
        if limit is not None or offset is not None:
            # SQLite rejects OFFSET without a LIMIT clause; LIMIT -1 means
            # "no limit", so an offset-only request is still valid SQL.
            query += " LIMIT ?"
            params.append(limit if limit is not None else -1)
        if offset is not None:
            query += " OFFSET ?"
            params.append(offset)

        async with db.execute(query, params) as cur:
            rows = await cur.fetchall()

        return _rows_to_ban_records(rows), total
|
||||
|
||||
|
||||
async def get_ban_counts_by_bucket(
    db_path: str,
    since: int,
    bucket_secs: int,
    num_buckets: int,
    origin: BanOrigin | None = None,
) -> list[int]:
    """Return ban counts aggregated into equal-width time buckets.

    Args:
        db_path: File path to the fail2ban SQLite database.
        since: Unix timestamp marking the start of bucket 0.
        bucket_secs: Width of each bucket in seconds.
            NOTE(review): assumed > 0 — SQLite yields NULL for division by
            zero, which would make ``int(row["bucket_idx"])`` raise; confirm
            callers never pass 0.
        num_buckets: Length of the returned list.
        origin: Optional origin filter (see ``_origin_sql_filter``).

    Returns:
        A list of ``num_buckets`` integers; index ``i`` counts bans whose
        ``timeofban`` falls in the i-th ``bucket_secs``-wide window after
        ``since``.
    """

    origin_clause, origin_params = _origin_sql_filter(origin)

    async with aiosqlite.connect(_make_db_uri(db_path), uri=True) as db:
        db.row_factory = aiosqlite.Row
        # Bucket index computed in SQL: integer division of the offset from
        # `since` by the bucket width.
        async with db.execute(
            "SELECT CAST((timeofban - ?) / ? AS INTEGER) AS bucket_idx, "
            "COUNT(*) AS cnt "
            "FROM bans "
            "WHERE timeofban >= ?" + origin_clause + " GROUP BY bucket_idx "
            "ORDER BY bucket_idx",
            (since, bucket_secs, since, *origin_params),
        ) as cur:
            rows = await cur.fetchall()

        # Dense result: buckets with no bans stay 0. Indices past the last
        # bucket (rows newer than the covered window) are dropped by the
        # range check.
        counts: list[int] = [0] * num_buckets
        for row in rows:
            idx: int = int(row["bucket_idx"])
            if 0 <= idx < num_buckets:
                counts[idx] = int(row["cnt"])

        return counts
|
||||
|
||||
|
||||
async def get_ban_event_counts(
    db_path: str,
    since: int,
    origin: BanOrigin | None = None,
) -> list[BanIpCount]:
    """Count ban events per unique IP within the window starting at *since*."""

    origin_clause, origin_params = _origin_sql_filter(origin)

    async with aiosqlite.connect(_make_db_uri(db_path), uri=True) as db:
        db.row_factory = aiosqlite.Row
        async with db.execute(
            "SELECT ip, COUNT(*) AS event_count "
            "FROM bans "
            "WHERE timeofban >= ?" + origin_clause + " GROUP BY ip",
            (since, *origin_params),
        ) as cur:
            grouped = await cur.fetchall()

    # Materialize the aggregation into frozen result records.
    results: list[BanIpCount] = []
    for entry in grouped:
        results.append(
            BanIpCount(ip=str(entry["ip"]), event_count=int(entry["event_count"]))
        )
    return results
|
||||
|
||||
|
||||
async def get_bans_by_jail(
    db_path: str,
    since: int,
    origin: BanOrigin | None = None,
) -> tuple[int, list[JailBanCount]]:
    """Return per-jail ban counts and the total ban count.

    Args:
        db_path: File path to the fail2ban SQLite database.
        since: Unix timestamp; only bans at or after this time are counted.
        origin: Optional origin filter (see ``_origin_sql_filter``).

    Returns:
        ``(total, per_jail)`` where ``per_jail`` is sorted by count,
        descending.
    """

    origin_clause, origin_params = _origin_sql_filter(origin)

    async with aiosqlite.connect(_make_db_uri(db_path), uri=True) as db:
        db.row_factory = aiosqlite.Row

        # Overall count across all jails in the window.
        async with db.execute(
            "SELECT COUNT(*) FROM bans WHERE timeofban >= ?" + origin_clause,
            (since, *origin_params),
        ) as cur:
            count_row = await cur.fetchone()
            total: int = int(count_row[0]) if count_row else 0

        # Per-jail breakdown, busiest jail first.
        async with db.execute(
            "SELECT jail, COUNT(*) AS cnt "
            "FROM bans "
            "WHERE timeofban >= ?" + origin_clause + " GROUP BY jail ORDER BY cnt DESC",
            (since, *origin_params),
        ) as cur:
            rows = await cur.fetchall()

        return total, [
            JailBanCount(jail=str(r["jail"]), count=int(r["cnt"])) for r in rows
        ]
|
||||
|
||||
|
||||
async def get_bans_table_summary(
|
||||
db_path: str,
|
||||
) -> tuple[int, int | None, int | None]:
|
||||
"""Return basic summary stats for the ``bans`` table.
|
||||
|
||||
Returns:
|
||||
A tuple ``(row_count, min_timeofban, max_timeofban)``. If the table is
|
||||
empty the min/max values will be ``None``.
|
||||
"""
|
||||
|
||||
async with aiosqlite.connect(_make_db_uri(db_path), uri=True) as db:
|
||||
db.row_factory = aiosqlite.Row
|
||||
async with db.execute(
|
||||
"SELECT COUNT(*), MIN(timeofban), MAX(timeofban) FROM bans"
|
||||
) as cur:
|
||||
row = await cur.fetchone()
|
||||
|
||||
if row is None:
|
||||
return 0, None, None
|
||||
|
||||
return (
|
||||
int(row[0]),
|
||||
int(row[1]) if row[1] is not None else None,
|
||||
int(row[2]) if row[2] is not None else None,
|
||||
)
|
||||
|
||||
|
||||
async def get_history_page(
|
||||
db_path: str,
|
||||
since: int | None = None,
|
||||
jail: str | None = None,
|
||||
ip_filter: str | None = None,
|
||||
origin: BanOrigin | None = None,
|
||||
page: int = 1,
|
||||
page_size: int = 100,
|
||||
) -> tuple[list[HistoryRecord], int]:
|
||||
"""Return a paginated list of history records with total count."""
|
||||
|
||||
wheres: list[str] = []
|
||||
params: list[object] = []
|
||||
|
||||
if since is not None:
|
||||
wheres.append("timeofban >= ?")
|
||||
params.append(since)
|
||||
|
||||
if jail is not None:
|
||||
wheres.append("jail = ?")
|
||||
params.append(jail)
|
||||
|
||||
if ip_filter is not None:
|
||||
wheres.append("ip LIKE ?")
|
||||
params.append(f"{ip_filter}%")
|
||||
|
||||
origin_clause, origin_params = _origin_sql_filter(origin)
|
||||
if origin_clause:
|
||||
origin_clause_clean = origin_clause.removeprefix(" AND ")
|
||||
wheres.append(origin_clause_clean)
|
||||
params.extend(origin_params)
|
||||
|
||||
where_sql: str = ("WHERE " + " AND ".join(wheres)) if wheres else ""
|
||||
|
||||
effective_page_size: int = page_size
|
||||
offset: int = (page - 1) * effective_page_size
|
||||
|
||||
async with aiosqlite.connect(_make_db_uri(db_path), uri=True) as db:
|
||||
db.row_factory = aiosqlite.Row
|
||||
|
||||
async with db.execute(
|
||||
f"SELECT COUNT(*) FROM bans {where_sql}", # noqa: S608
|
||||
params,
|
||||
) as cur:
|
||||
count_row = await cur.fetchone()
|
||||
total: int = int(count_row[0]) if count_row else 0
|
||||
|
||||
async with db.execute(
|
||||
f"SELECT jail, ip, timeofban, bancount, data "
|
||||
f"FROM bans {where_sql} "
|
||||
"ORDER BY timeofban DESC "
|
||||
"LIMIT ? OFFSET ?",
|
||||
[*params, effective_page_size, offset],
|
||||
) as cur:
|
||||
rows = await cur.fetchall()
|
||||
|
||||
return _rows_to_history_records(rows), total
|
||||
|
||||
|
||||
async def get_history_for_ip(db_path: str, ip: str) -> list[HistoryRecord]:
|
||||
"""Return the full ban timeline for a specific IP."""
|
||||
|
||||
async with aiosqlite.connect(_make_db_uri(db_path), uri=True) as db:
|
||||
db.row_factory = aiosqlite.Row
|
||||
async with db.execute(
|
||||
"SELECT jail, ip, timeofban, bancount, data "
|
||||
"FROM bans "
|
||||
"WHERE ip = ? "
|
||||
"ORDER BY timeofban DESC",
|
||||
(ip,),
|
||||
) as cur:
|
||||
rows = await cur.fetchall()
|
||||
|
||||
return _rows_to_history_records(rows)
|
||||
148
backend/app/repositories/geo_cache_repo.py
Normal file
148
backend/app/repositories/geo_cache_repo.py
Normal file
@@ -0,0 +1,148 @@
|
||||
"""Repository for the geo cache persistent store.
|
||||
|
||||
This module provides typed, async helpers for querying and mutating the
|
||||
``geo_cache`` table in the BanGUI application database.
|
||||
|
||||
All functions accept an open :class:`aiosqlite.Connection` and do not manage
|
||||
connection lifetimes.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, TypedDict
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
|
||||
import aiosqlite
|
||||
|
||||
|
||||
class GeoCacheRow(TypedDict):
|
||||
"""A single row from the ``geo_cache`` table."""
|
||||
|
||||
ip: str
|
||||
country_code: str | None
|
||||
country_name: str | None
|
||||
asn: str | None
|
||||
org: str | None
|
||||
|
||||
|
||||
async def load_all(db: aiosqlite.Connection) -> list[GeoCacheRow]:
|
||||
"""Load all geo cache rows from the database.
|
||||
|
||||
Args:
|
||||
db: Open BanGUI application database connection.
|
||||
|
||||
Returns:
|
||||
List of rows from the ``geo_cache`` table.
|
||||
"""
|
||||
rows: list[GeoCacheRow] = []
|
||||
async with db.execute(
|
||||
"SELECT ip, country_code, country_name, asn, org FROM geo_cache"
|
||||
) as cur:
|
||||
async for row in cur:
|
||||
rows.append(
|
||||
GeoCacheRow(
|
||||
ip=str(row[0]),
|
||||
country_code=row[1],
|
||||
country_name=row[2],
|
||||
asn=row[3],
|
||||
org=row[4],
|
||||
)
|
||||
)
|
||||
return rows
|
||||
|
||||
|
||||
async def get_unresolved_ips(db: aiosqlite.Connection) -> list[str]:
|
||||
"""Return all IPs in ``geo_cache`` where ``country_code`` is NULL.
|
||||
|
||||
Args:
|
||||
db: Open BanGUI application database connection.
|
||||
|
||||
Returns:
|
||||
List of IPv4/IPv6 strings that need geo resolution.
|
||||
"""
|
||||
ips: list[str] = []
|
||||
async with db.execute(
|
||||
"SELECT ip FROM geo_cache WHERE country_code IS NULL"
|
||||
) as cur:
|
||||
async for row in cur:
|
||||
ips.append(str(row[0]))
|
||||
return ips
|
||||
|
||||
|
||||
async def count_unresolved(db: aiosqlite.Connection) -> int:
|
||||
"""Return the number of unresolved rows (country_code IS NULL)."""
|
||||
async with db.execute(
|
||||
"SELECT COUNT(*) FROM geo_cache WHERE country_code IS NULL"
|
||||
) as cur:
|
||||
row = await cur.fetchone()
|
||||
return int(row[0]) if row else 0
|
||||
|
||||
|
||||
async def upsert_entry(
|
||||
db: aiosqlite.Connection,
|
||||
ip: str,
|
||||
country_code: str | None,
|
||||
country_name: str | None,
|
||||
asn: str | None,
|
||||
org: str | None,
|
||||
) -> None:
|
||||
"""Insert or update a resolved geo cache entry."""
|
||||
await db.execute(
|
||||
"""
|
||||
INSERT INTO geo_cache (ip, country_code, country_name, asn, org)
|
||||
VALUES (?, ?, ?, ?, ?)
|
||||
ON CONFLICT(ip) DO UPDATE SET
|
||||
country_code = excluded.country_code,
|
||||
country_name = excluded.country_name,
|
||||
asn = excluded.asn,
|
||||
org = excluded.org,
|
||||
cached_at = strftime('%Y-%m-%dT%H:%M:%fZ', 'now')
|
||||
""",
|
||||
(ip, country_code, country_name, asn, org),
|
||||
)
|
||||
|
||||
|
||||
async def upsert_neg_entry(db: aiosqlite.Connection, ip: str) -> None:
|
||||
"""Record a failed lookup attempt as a negative entry."""
|
||||
await db.execute(
|
||||
"INSERT OR IGNORE INTO geo_cache (ip) VALUES (?)",
|
||||
(ip,),
|
||||
)
|
||||
|
||||
|
||||
async def bulk_upsert_entries(
|
||||
db: aiosqlite.Connection,
|
||||
rows: Sequence[tuple[str, str | None, str | None, str | None, str | None]],
|
||||
) -> int:
|
||||
"""Bulk insert or update multiple geo cache entries."""
|
||||
if not rows:
|
||||
return 0
|
||||
|
||||
await db.executemany(
|
||||
"""
|
||||
INSERT INTO geo_cache (ip, country_code, country_name, asn, org)
|
||||
VALUES (?, ?, ?, ?, ?)
|
||||
ON CONFLICT(ip) DO UPDATE SET
|
||||
country_code = excluded.country_code,
|
||||
country_name = excluded.country_name,
|
||||
asn = excluded.asn,
|
||||
org = excluded.org,
|
||||
cached_at = strftime('%Y-%m-%dT%H:%M:%fZ', 'now')
|
||||
""",
|
||||
rows,
|
||||
)
|
||||
return len(rows)
|
||||
|
||||
|
||||
async def bulk_upsert_neg_entries(db: aiosqlite.Connection, ips: list[str]) -> int:
|
||||
"""Bulk insert negative lookup entries."""
|
||||
if not ips:
|
||||
return 0
|
||||
|
||||
await db.executemany(
|
||||
"INSERT OR IGNORE INTO geo_cache (ip) VALUES (?)",
|
||||
[(ip,) for ip in ips],
|
||||
)
|
||||
return len(ips)
|
||||
@@ -8,12 +8,26 @@ table. All methods are plain async functions that accept a
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from typing import TYPE_CHECKING, TypedDict, cast
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Mapping
|
||||
|
||||
import aiosqlite
|
||||
|
||||
|
||||
class ImportLogRow(TypedDict):
|
||||
"""Row shape returned by queries on the import_log table."""
|
||||
|
||||
id: int
|
||||
source_id: int | None
|
||||
source_url: str
|
||||
timestamp: str
|
||||
ips_imported: int
|
||||
ips_skipped: int
|
||||
errors: str | None
|
||||
|
||||
|
||||
async def add_log(
|
||||
db: aiosqlite.Connection,
|
||||
*,
|
||||
@@ -54,7 +68,7 @@ async def list_logs(
|
||||
source_id: int | None = None,
|
||||
page: int = 1,
|
||||
page_size: int = 50,
|
||||
) -> tuple[list[dict[str, Any]], int]:
|
||||
) -> tuple[list[ImportLogRow], int]:
|
||||
"""Return a paginated list of import log entries.
|
||||
|
||||
Args:
|
||||
@@ -68,8 +82,8 @@ async def list_logs(
|
||||
*total* is the count of all matching rows (ignoring pagination).
|
||||
"""
|
||||
where = ""
|
||||
params_count: list[Any] = []
|
||||
params_rows: list[Any] = []
|
||||
params_count: list[object] = []
|
||||
params_rows: list[object] = []
|
||||
|
||||
if source_id is not None:
|
||||
where = " WHERE source_id = ?"
|
||||
@@ -102,7 +116,7 @@ async def list_logs(
|
||||
return items, total
|
||||
|
||||
|
||||
async def get_last_log(db: aiosqlite.Connection) -> dict[str, Any] | None:
|
||||
async def get_last_log(db: aiosqlite.Connection) -> ImportLogRow | None:
|
||||
"""Return the most recent import log entry across all sources.
|
||||
|
||||
Args:
|
||||
@@ -143,13 +157,14 @@ def compute_total_pages(total: int, page_size: int) -> int:
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _row_to_dict(row: Any) -> dict[str, Any]:
|
||||
def _row_to_dict(row: object) -> ImportLogRow:
|
||||
"""Convert an aiosqlite row to a plain Python dict.
|
||||
|
||||
Args:
|
||||
row: An :class:`aiosqlite.Row` or sequence returned by a cursor.
|
||||
row: An :class:`aiosqlite.Row` or similar mapping returned by a cursor.
|
||||
|
||||
Returns:
|
||||
Dict mapping column names to Python values.
|
||||
"""
|
||||
return dict(row)
|
||||
mapping = cast("Mapping[str, object]", row)
|
||||
return cast("ImportLogRow", dict(mapping))
|
||||
|
||||
@@ -20,8 +20,8 @@ from fastapi import APIRouter, HTTPException, Request, status
|
||||
from app.dependencies import AuthDep
|
||||
from app.models.ban import ActiveBanListResponse, BanRequest, UnbanAllResponse, UnbanRequest
|
||||
from app.models.jail import JailCommandResponse
|
||||
from app.services import jail_service
|
||||
from app.services.jail_service import JailNotFoundError, JailOperationError
|
||||
from app.services import geo_service, jail_service
|
||||
from app.exceptions import JailNotFoundError, JailOperationError
|
||||
from app.utils.fail2ban_client import Fail2BanConnectionError
|
||||
|
||||
router: APIRouter = APIRouter(prefix="/api/bans", tags=["Bans"])
|
||||
@@ -73,6 +73,7 @@ async def get_active_bans(
|
||||
try:
|
||||
return await jail_service.get_active_bans(
|
||||
socket_path,
|
||||
geo_batch_lookup=geo_service.lookup_batch,
|
||||
http_session=http_session,
|
||||
app_db=app_db,
|
||||
)
|
||||
|
||||
@@ -42,8 +42,7 @@ from app.models.blocklist import (
|
||||
ScheduleConfig,
|
||||
ScheduleInfo,
|
||||
)
|
||||
from app.repositories import import_log_repo
|
||||
from app.services import blocklist_service
|
||||
from app.services import blocklist_service, geo_service
|
||||
from app.tasks import blocklist_import as blocklist_import_task
|
||||
|
||||
router: APIRouter = APIRouter(prefix="/api/blocklists", tags=["Blocklists"])
|
||||
@@ -132,7 +131,15 @@ async def run_import_now(
|
||||
"""
|
||||
http_session: aiohttp.ClientSession = request.app.state.http_session
|
||||
socket_path: str = request.app.state.settings.fail2ban_socket
|
||||
return await blocklist_service.import_all(db, http_session, socket_path)
|
||||
from app.services import jail_service
|
||||
|
||||
return await blocklist_service.import_all(
|
||||
db,
|
||||
http_session,
|
||||
socket_path,
|
||||
geo_is_cached=geo_service.is_cached,
|
||||
geo_batch_lookup=geo_service.lookup_batch,
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
@@ -225,19 +232,9 @@ async def get_import_log(
|
||||
Returns:
|
||||
:class:`~app.models.blocklist.ImportLogListResponse`.
|
||||
"""
|
||||
items, total = await import_log_repo.list_logs(
|
||||
return await blocklist_service.list_import_logs(
|
||||
db, source_id=source_id, page=page, page_size=page_size
|
||||
)
|
||||
total_pages = import_log_repo.compute_total_pages(total, page_size)
|
||||
from app.models.blocklist import ImportLogEntry # noqa: PLC0415
|
||||
|
||||
return ImportLogListResponse(
|
||||
items=[ImportLogEntry.model_validate(i) for i in items],
|
||||
total=total,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
total_pages=total_pages,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -40,6 +40,7 @@ from __future__ import annotations
|
||||
import datetime
|
||||
from typing import Annotated
|
||||
|
||||
import structlog
|
||||
from fastapi import APIRouter, HTTPException, Path, Query, Request, status
|
||||
|
||||
from app.dependencies import AuthDep
|
||||
@@ -75,31 +76,39 @@ from app.models.config import (
|
||||
RollbackResponse,
|
||||
ServiceStatusResponse,
|
||||
)
|
||||
from app.services import config_file_service, config_service, jail_service
|
||||
from app.services.config_file_service import (
|
||||
from app.services import config_service, jail_service, log_service
|
||||
from app.services import (
|
||||
action_config_service,
|
||||
config_file_service,
|
||||
filter_config_service,
|
||||
jail_config_service,
|
||||
)
|
||||
from app.services.action_config_service import (
|
||||
ActionAlreadyExistsError,
|
||||
ActionNameError,
|
||||
ActionNotFoundError,
|
||||
ActionReadonlyError,
|
||||
ConfigWriteError,
|
||||
)
|
||||
from app.services.filter_config_service import (
|
||||
FilterAlreadyExistsError,
|
||||
FilterInvalidRegexError,
|
||||
FilterNameError,
|
||||
FilterNotFoundError,
|
||||
FilterReadonlyError,
|
||||
)
|
||||
from app.services.jail_config_service import (
|
||||
JailAlreadyActiveError,
|
||||
JailAlreadyInactiveError,
|
||||
JailNameError,
|
||||
JailNotFoundInConfigError,
|
||||
)
|
||||
from app.services.config_service import (
|
||||
ConfigOperationError,
|
||||
ConfigValidationError,
|
||||
JailNotFoundError,
|
||||
)
|
||||
from app.exceptions import ConfigOperationError, ConfigValidationError, JailNotFoundError, JailOperationError
|
||||
from app.tasks.health_check import _run_probe
|
||||
from app.utils.fail2ban_client import Fail2BanConnectionError
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
|
||||
router: APIRouter = APIRouter(prefix="/api/config", tags=["Config"])
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -194,7 +203,7 @@ async def get_inactive_jails(
|
||||
"""
|
||||
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
||||
socket_path: str = request.app.state.settings.fail2ban_socket
|
||||
return await config_file_service.list_inactive_jails(config_dir, socket_path)
|
||||
return await jail_config_service.list_inactive_jails(config_dir, socket_path)
|
||||
|
||||
|
||||
@router.get(
|
||||
@@ -357,11 +366,17 @@ async def reload_fail2ban(
|
||||
_auth: Validated session.
|
||||
|
||||
Raises:
|
||||
HTTPException: 409 when fail2ban reports the reload failed.
|
||||
HTTPException: 502 when fail2ban is unreachable.
|
||||
"""
|
||||
socket_path: str = request.app.state.settings.fail2ban_socket
|
||||
try:
|
||||
await jail_service.reload_all(socket_path)
|
||||
except JailOperationError as exc:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail=f"fail2ban reload failed: {exc}",
|
||||
) from exc
|
||||
except Fail2BanConnectionError as exc:
|
||||
raise _bad_gateway(exc) from exc
|
||||
|
||||
@@ -381,24 +396,55 @@ async def restart_fail2ban(
|
||||
) -> None:
|
||||
"""Trigger a full fail2ban service restart.
|
||||
|
||||
The fail2ban daemon is completely stopped and then started again,
|
||||
re-reading all configuration files in the process.
|
||||
Stops the fail2ban daemon via the Unix domain socket, then starts it
|
||||
again using the configured ``fail2ban_start_command``. After starting,
|
||||
probes the socket for up to 10 seconds to confirm the daemon came back
|
||||
online.
|
||||
|
||||
Args:
|
||||
request: Incoming request.
|
||||
_auth: Validated session.
|
||||
|
||||
Raises:
|
||||
HTTPException: 502 when fail2ban is unreachable.
|
||||
HTTPException: 409 when fail2ban reports the stop command failed.
|
||||
HTTPException: 502 when fail2ban is unreachable for the stop command.
|
||||
HTTPException: 503 when fail2ban does not come back online within
|
||||
10 seconds after being started. Check the fail2ban log for
|
||||
initialisation errors. Use
|
||||
``POST /api/config/jails/{name}/rollback`` if a specific jail
|
||||
is suspect.
|
||||
"""
|
||||
socket_path: str = request.app.state.settings.fail2ban_socket
|
||||
start_cmd: str = request.app.state.settings.fail2ban_start_command
|
||||
start_cmd_parts: list[str] = start_cmd.split()
|
||||
|
||||
# Step 1: stop the daemon via socket.
|
||||
try:
|
||||
# Perform restart by sending the restart command via the fail2ban socket.
|
||||
# If fail2ban is not running, this will raise an exception, and we return 502.
|
||||
await jail_service.restart(socket_path)
|
||||
except JailOperationError as exc:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail=f"fail2ban stop command failed: {exc}",
|
||||
) from exc
|
||||
except Fail2BanConnectionError as exc:
|
||||
raise _bad_gateway(exc) from exc
|
||||
|
||||
# Step 2: start the daemon via subprocess.
|
||||
await config_file_service.start_daemon(start_cmd_parts)
|
||||
|
||||
# Step 3: probe the socket until fail2ban is responsive or the budget expires.
|
||||
fail2ban_running: bool = await config_file_service.wait_for_fail2ban(socket_path, max_wait_seconds=10.0)
|
||||
if not fail2ban_running:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail=(
|
||||
"fail2ban was stopped but did not come back online within 10 seconds. "
|
||||
"Check the fail2ban log for initialisation errors. "
|
||||
"Use POST /api/config/jails/{name}/rollback if a specific jail is suspect."
|
||||
),
|
||||
)
|
||||
log.info("fail2ban_restarted")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Regex tester (stateless)
|
||||
@@ -426,7 +472,7 @@ async def regex_test(
|
||||
Returns:
|
||||
:class:`~app.models.config.RegexTestResponse` with match result and groups.
|
||||
"""
|
||||
return config_service.test_regex(body)
|
||||
return log_service.test_regex(body)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -532,7 +578,7 @@ async def preview_log(
|
||||
Returns:
|
||||
:class:`~app.models.config.LogPreviewResponse` with per-line results.
|
||||
"""
|
||||
return await config_service.preview_log(body)
|
||||
return await log_service.preview_log(body)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -561,9 +607,7 @@ async def get_map_color_thresholds(
|
||||
"""
|
||||
from app.services import setup_service
|
||||
|
||||
high, medium, low = await setup_service.get_map_color_thresholds(
|
||||
request.app.state.db
|
||||
)
|
||||
high, medium, low = await setup_service.get_map_color_thresholds(request.app.state.db)
|
||||
return MapColorThresholdsResponse(
|
||||
threshold_high=high,
|
||||
threshold_medium=medium,
|
||||
@@ -653,9 +697,7 @@ async def activate_jail(
|
||||
req = body if body is not None else ActivateJailRequest()
|
||||
|
||||
try:
|
||||
result = await config_file_service.activate_jail(
|
||||
config_dir, socket_path, name, req
|
||||
)
|
||||
result = await jail_config_service.activate_jail(config_dir, socket_path, name, req)
|
||||
except JailNameError as exc:
|
||||
raise _bad_request(str(exc)) from exc
|
||||
except JailNotFoundInConfigError:
|
||||
@@ -729,7 +771,7 @@ async def deactivate_jail(
|
||||
socket_path: str = request.app.state.settings.fail2ban_socket
|
||||
|
||||
try:
|
||||
result = await config_file_service.deactivate_jail(config_dir, socket_path, name)
|
||||
result = await jail_config_service.deactivate_jail(config_dir, socket_path, name)
|
||||
except JailNameError as exc:
|
||||
raise _bad_request(str(exc)) from exc
|
||||
except JailNotFoundInConfigError:
|
||||
@@ -755,6 +797,58 @@ async def deactivate_jail(
|
||||
return result
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/jails/{name}/local",
|
||||
status_code=status.HTTP_204_NO_CONTENT,
|
||||
summary="Delete the jail.d override file for an inactive jail",
|
||||
)
|
||||
async def delete_jail_local_override(
|
||||
request: Request,
|
||||
_auth: AuthDep,
|
||||
name: _NamePath,
|
||||
) -> None:
|
||||
"""Remove the ``jail.d/{name}.local`` override file for an inactive jail.
|
||||
|
||||
This endpoint is the clean-up action for inactive jails that still carry
|
||||
a ``.local`` override file (e.g. one written with ``enabled = false`` by a
|
||||
previous deactivation). The file is deleted without modifying fail2ban's
|
||||
running state, since the jail is already inactive.
|
||||
|
||||
Args:
|
||||
request: FastAPI request object.
|
||||
_auth: Validated session.
|
||||
name: Name of the jail whose ``.local`` file should be removed.
|
||||
|
||||
Raises:
|
||||
HTTPException: 400 if *name* contains invalid characters.
|
||||
HTTPException: 404 if *name* is not found in any config file.
|
||||
HTTPException: 409 if the jail is currently active.
|
||||
HTTPException: 500 if the file cannot be deleted.
|
||||
HTTPException: 502 if fail2ban is unreachable.
|
||||
"""
|
||||
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
||||
socket_path: str = request.app.state.settings.fail2ban_socket
|
||||
|
||||
try:
|
||||
await jail_config_service.delete_jail_local_override(config_dir, socket_path, name)
|
||||
except JailNameError as exc:
|
||||
raise _bad_request(str(exc)) from exc
|
||||
except JailNotFoundInConfigError:
|
||||
raise _not_found(name) from None
|
||||
except JailAlreadyActiveError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail=f"Jail {name!r} is currently active; deactivate it first.",
|
||||
) from None
|
||||
except ConfigWriteError as exc:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to delete config override: {exc}",
|
||||
) from exc
|
||||
except Fail2BanConnectionError as exc:
|
||||
raise _bad_gateway(exc) from exc
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Jail validation & rollback endpoints (Task 3)
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -789,7 +883,7 @@ async def validate_jail(
|
||||
"""
|
||||
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
||||
try:
|
||||
return await config_file_service.validate_jail_config(config_dir, name)
|
||||
return await jail_config_service.validate_jail_config(config_dir, name)
|
||||
except JailNameError as exc:
|
||||
raise _bad_request(str(exc)) from exc
|
||||
|
||||
@@ -855,9 +949,7 @@ async def rollback_jail(
|
||||
start_cmd_parts: list[str] = start_cmd.split()
|
||||
|
||||
try:
|
||||
result = await config_file_service.rollback_jail(
|
||||
config_dir, socket_path, name, start_cmd_parts
|
||||
)
|
||||
result = await jail_config_service.rollback_jail(config_dir, socket_path, name, start_cmd_parts)
|
||||
except JailNameError as exc:
|
||||
raise _bad_request(str(exc)) from exc
|
||||
except ConfigWriteError as exc:
|
||||
@@ -909,7 +1001,7 @@ async def list_filters(
|
||||
"""
|
||||
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
||||
socket_path: str = request.app.state.settings.fail2ban_socket
|
||||
result = await config_file_service.list_filters(config_dir, socket_path)
|
||||
result = await filter_config_service.list_filters(config_dir, socket_path)
|
||||
# Sort: active first (by name), then inactive (by name).
|
||||
result.filters.sort(key=lambda f: (not f.active, f.name.lower()))
|
||||
return result
|
||||
@@ -946,7 +1038,7 @@ async def get_filter(
|
||||
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
||||
socket_path: str = request.app.state.settings.fail2ban_socket
|
||||
try:
|
||||
return await config_file_service.get_filter(config_dir, socket_path, name)
|
||||
return await filter_config_service.get_filter(config_dir, socket_path, name)
|
||||
except FilterNotFoundError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
@@ -1010,9 +1102,7 @@ async def update_filter(
|
||||
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
||||
socket_path: str = request.app.state.settings.fail2ban_socket
|
||||
try:
|
||||
return await config_file_service.update_filter(
|
||||
config_dir, socket_path, name, body, do_reload=reload
|
||||
)
|
||||
return await filter_config_service.update_filter(config_dir, socket_path, name, body, do_reload=reload)
|
||||
except FilterNameError as exc:
|
||||
raise _bad_request(str(exc)) from exc
|
||||
except FilterNotFoundError:
|
||||
@@ -1062,9 +1152,7 @@ async def create_filter(
|
||||
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
||||
socket_path: str = request.app.state.settings.fail2ban_socket
|
||||
try:
|
||||
return await config_file_service.create_filter(
|
||||
config_dir, socket_path, body, do_reload=reload
|
||||
)
|
||||
return await filter_config_service.create_filter(config_dir, socket_path, body, do_reload=reload)
|
||||
except FilterNameError as exc:
|
||||
raise _bad_request(str(exc)) from exc
|
||||
except FilterAlreadyExistsError as exc:
|
||||
@@ -1111,7 +1199,7 @@ async def delete_filter(
|
||||
"""
|
||||
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
||||
try:
|
||||
await config_file_service.delete_filter(config_dir, name)
|
||||
await filter_config_service.delete_filter(config_dir, name)
|
||||
except FilterNameError as exc:
|
||||
raise _bad_request(str(exc)) from exc
|
||||
except FilterNotFoundError:
|
||||
@@ -1160,9 +1248,7 @@ async def assign_filter_to_jail(
|
||||
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
||||
socket_path: str = request.app.state.settings.fail2ban_socket
|
||||
try:
|
||||
await config_file_service.assign_filter_to_jail(
|
||||
config_dir, socket_path, name, body, do_reload=reload
|
||||
)
|
||||
await filter_config_service.assign_filter_to_jail(config_dir, socket_path, name, body, do_reload=reload)
|
||||
except (JailNameError, FilterNameError) as exc:
|
||||
raise _bad_request(str(exc)) from exc
|
||||
except JailNotFoundInConfigError:
|
||||
@@ -1226,7 +1312,7 @@ async def list_actions(
|
||||
"""
|
||||
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
||||
socket_path: str = request.app.state.settings.fail2ban_socket
|
||||
result = await config_file_service.list_actions(config_dir, socket_path)
|
||||
result = await action_config_service.list_actions(config_dir, socket_path)
|
||||
result.actions.sort(key=lambda a: (not a.active, a.name.lower()))
|
||||
return result
|
||||
|
||||
@@ -1261,7 +1347,7 @@ async def get_action(
|
||||
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
||||
socket_path: str = request.app.state.settings.fail2ban_socket
|
||||
try:
|
||||
return await config_file_service.get_action(config_dir, socket_path, name)
|
||||
return await action_config_service.get_action(config_dir, socket_path, name)
|
||||
except ActionNotFoundError:
|
||||
raise _action_not_found(name) from None
|
||||
|
||||
@@ -1306,9 +1392,7 @@ async def update_action(
|
||||
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
||||
socket_path: str = request.app.state.settings.fail2ban_socket
|
||||
try:
|
||||
return await config_file_service.update_action(
|
||||
config_dir, socket_path, name, body, do_reload=reload
|
||||
)
|
||||
return await action_config_service.update_action(config_dir, socket_path, name, body, do_reload=reload)
|
||||
except ActionNameError as exc:
|
||||
raise _bad_request(str(exc)) from exc
|
||||
except ActionNotFoundError:
|
||||
@@ -1354,9 +1438,7 @@ async def create_action(
|
||||
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
||||
socket_path: str = request.app.state.settings.fail2ban_socket
|
||||
try:
|
||||
return await config_file_service.create_action(
|
||||
config_dir, socket_path, body, do_reload=reload
|
||||
)
|
||||
return await action_config_service.create_action(config_dir, socket_path, body, do_reload=reload)
|
||||
except ActionNameError as exc:
|
||||
raise _bad_request(str(exc)) from exc
|
||||
except ActionAlreadyExistsError as exc:
|
||||
@@ -1399,7 +1481,7 @@ async def delete_action(
|
||||
"""
|
||||
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
||||
try:
|
||||
await config_file_service.delete_action(config_dir, name)
|
||||
await action_config_service.delete_action(config_dir, name)
|
||||
except ActionNameError as exc:
|
||||
raise _bad_request(str(exc)) from exc
|
||||
except ActionNotFoundError:
|
||||
@@ -1449,9 +1531,7 @@ async def assign_action_to_jail(
|
||||
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
||||
socket_path: str = request.app.state.settings.fail2ban_socket
|
||||
try:
|
||||
await config_file_service.assign_action_to_jail(
|
||||
config_dir, socket_path, name, body, do_reload=reload
|
||||
)
|
||||
await action_config_service.assign_action_to_jail(config_dir, socket_path, name, body, do_reload=reload)
|
||||
except (JailNameError, ActionNameError) as exc:
|
||||
raise _bad_request(str(exc)) from exc
|
||||
except JailNotFoundInConfigError:
|
||||
@@ -1500,9 +1580,7 @@ async def remove_action_from_jail(
|
||||
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
||||
socket_path: str = request.app.state.settings.fail2ban_socket
|
||||
try:
|
||||
await config_file_service.remove_action_from_jail(
|
||||
config_dir, socket_path, name, action_name, do_reload=reload
|
||||
)
|
||||
await action_config_service.remove_action_from_jail(config_dir, socket_path, name, action_name, do_reload=reload)
|
||||
except (JailNameError, ActionNameError) as exc:
|
||||
raise _bad_request(str(exc)) from exc
|
||||
except JailNotFoundInConfigError:
|
||||
@@ -1588,8 +1666,12 @@ async def get_service_status(
|
||||
handles this gracefully and returns ``online=False``).
|
||||
"""
|
||||
socket_path: str = request.app.state.settings.fail2ban_socket
|
||||
from app.services import health_service
|
||||
|
||||
try:
|
||||
return await config_service.get_service_status(socket_path)
|
||||
return await config_service.get_service_status(
|
||||
socket_path,
|
||||
probe_fn=health_service.probe,
|
||||
)
|
||||
except Fail2BanConnectionError as exc:
|
||||
raise _bad_gateway(exc) from exc
|
||||
|
||||
|
||||
@@ -19,6 +19,7 @@ if TYPE_CHECKING:
|
||||
|
||||
from fastapi import APIRouter, Query, Request
|
||||
|
||||
from app import __version__
|
||||
from app.dependencies import AuthDep
|
||||
from app.models.ban import (
|
||||
BanOrigin,
|
||||
@@ -29,7 +30,7 @@ from app.models.ban import (
|
||||
TimeRange,
|
||||
)
|
||||
from app.models.server import ServerStatus, ServerStatusResponse
|
||||
from app.services import ban_service
|
||||
from app.services import ban_service, geo_service
|
||||
|
||||
router: APIRouter = APIRouter(prefix="/api/dashboard", tags=["Dashboard"])
|
||||
|
||||
@@ -69,6 +70,7 @@ async def get_server_status(
|
||||
"server_status",
|
||||
ServerStatus(online=False),
|
||||
)
|
||||
cached.version = __version__
|
||||
return ServerStatusResponse(status=cached)
|
||||
|
||||
|
||||
@@ -119,6 +121,7 @@ async def get_dashboard_bans(
|
||||
page_size=page_size,
|
||||
http_session=http_session,
|
||||
app_db=None,
|
||||
geo_batch_lookup=geo_service.lookup_batch,
|
||||
origin=origin,
|
||||
)
|
||||
|
||||
@@ -162,6 +165,8 @@ async def get_bans_by_country(
|
||||
socket_path,
|
||||
range,
|
||||
http_session=http_session,
|
||||
geo_cache_lookup=geo_service.lookup_cached_only,
|
||||
geo_batch_lookup=geo_service.lookup_batch,
|
||||
app_db=None,
|
||||
origin=origin,
|
||||
)
|
||||
|
||||
@@ -14,8 +14,8 @@ Endpoints:
|
||||
* ``GET /api/config/filters/{name}/parsed`` — parse a filter file into a structured model
|
||||
* ``PUT /api/config/filters/{name}/parsed`` — update a filter file from a structured model
|
||||
* ``GET /api/config/actions`` — list all action files
|
||||
* ``GET /api/config/actions/{name}`` — get one action file (with content)
|
||||
* ``PUT /api/config/actions/{name}`` — update an action file
|
||||
* ``GET /api/config/actions/{name}/raw`` — get one action file (raw content)
|
||||
* ``PUT /api/config/actions/{name}/raw`` — update an action file (raw content)
|
||||
* ``POST /api/config/actions`` — create a new action file
|
||||
* ``GET /api/config/actions/{name}/parsed`` — parse an action file into a structured model
|
||||
* ``PUT /api/config/actions/{name}/parsed`` — update an action file from a structured model
|
||||
@@ -51,8 +51,8 @@ from app.models.file_config import (
|
||||
JailConfigFileEnabledUpdate,
|
||||
JailConfigFilesResponse,
|
||||
)
|
||||
from app.services import file_config_service
|
||||
from app.services.file_config_service import (
|
||||
from app.services import raw_config_io_service
|
||||
from app.services.raw_config_io_service import (
|
||||
ConfigDirError,
|
||||
ConfigFileExistsError,
|
||||
ConfigFileNameError,
|
||||
@@ -134,7 +134,7 @@ async def list_jail_config_files(
|
||||
"""
|
||||
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
||||
try:
|
||||
return await file_config_service.list_jail_config_files(config_dir)
|
||||
return await raw_config_io_service.list_jail_config_files(config_dir)
|
||||
except ConfigDirError as exc:
|
||||
raise _service_unavailable(str(exc)) from exc
|
||||
|
||||
@@ -166,7 +166,7 @@ async def get_jail_config_file(
|
||||
"""
|
||||
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
||||
try:
|
||||
return await file_config_service.get_jail_config_file(config_dir, filename)
|
||||
return await raw_config_io_service.get_jail_config_file(config_dir, filename)
|
||||
except ConfigFileNameError as exc:
|
||||
raise _bad_request(str(exc)) from exc
|
||||
except ConfigFileNotFoundError:
|
||||
@@ -204,7 +204,7 @@ async def write_jail_config_file(
|
||||
"""
|
||||
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
||||
try:
|
||||
await file_config_service.write_jail_config_file(config_dir, filename, body)
|
||||
await raw_config_io_service.write_jail_config_file(config_dir, filename, body)
|
||||
except ConfigFileNameError as exc:
|
||||
raise _bad_request(str(exc)) from exc
|
||||
except ConfigFileNotFoundError:
|
||||
@@ -244,7 +244,7 @@ async def set_jail_config_file_enabled(
|
||||
"""
|
||||
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
||||
try:
|
||||
await file_config_service.set_jail_config_enabled(
|
||||
await raw_config_io_service.set_jail_config_enabled(
|
||||
config_dir, filename, body.enabled
|
||||
)
|
||||
except ConfigFileNameError as exc:
|
||||
@@ -285,7 +285,7 @@ async def create_jail_config_file(
|
||||
"""
|
||||
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
||||
try:
|
||||
filename = await file_config_service.create_jail_config_file(config_dir, body)
|
||||
filename = await raw_config_io_service.create_jail_config_file(config_dir, body)
|
||||
except ConfigFileNameError as exc:
|
||||
raise _bad_request(str(exc)) from exc
|
||||
except ConfigFileExistsError:
|
||||
@@ -338,7 +338,7 @@ async def get_filter_file_raw(
|
||||
"""
|
||||
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
||||
try:
|
||||
return await file_config_service.get_filter_file(config_dir, name)
|
||||
return await raw_config_io_service.get_filter_file(config_dir, name)
|
||||
except ConfigFileNameError as exc:
|
||||
raise _bad_request(str(exc)) from exc
|
||||
except ConfigFileNotFoundError:
|
||||
@@ -373,7 +373,7 @@ async def write_filter_file(
|
||||
"""
|
||||
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
||||
try:
|
||||
await file_config_service.write_filter_file(config_dir, name, body)
|
||||
await raw_config_io_service.write_filter_file(config_dir, name, body)
|
||||
except ConfigFileNameError as exc:
|
||||
raise _bad_request(str(exc)) from exc
|
||||
except ConfigFileNotFoundError:
|
||||
@@ -412,7 +412,7 @@ async def create_filter_file(
|
||||
"""
|
||||
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
||||
try:
|
||||
filename = await file_config_service.create_filter_file(config_dir, body)
|
||||
filename = await raw_config_io_service.create_filter_file(config_dir, body)
|
||||
except ConfigFileNameError as exc:
|
||||
raise _bad_request(str(exc)) from exc
|
||||
except ConfigFileExistsError:
|
||||
@@ -454,13 +454,13 @@ async def list_action_files(
|
||||
"""
|
||||
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
||||
try:
|
||||
return await file_config_service.list_action_files(config_dir)
|
||||
return await raw_config_io_service.list_action_files(config_dir)
|
||||
except ConfigDirError as exc:
|
||||
raise _service_unavailable(str(exc)) from exc
|
||||
|
||||
|
||||
@router.get(
|
||||
"/actions/{name}",
|
||||
"/actions/{name}/raw",
|
||||
response_model=ConfFileContent,
|
||||
summary="Return an action definition file with its content",
|
||||
)
|
||||
@@ -486,7 +486,7 @@ async def get_action_file(
|
||||
"""
|
||||
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
||||
try:
|
||||
return await file_config_service.get_action_file(config_dir, name)
|
||||
return await raw_config_io_service.get_action_file(config_dir, name)
|
||||
except ConfigFileNameError as exc:
|
||||
raise _bad_request(str(exc)) from exc
|
||||
except ConfigFileNotFoundError:
|
||||
@@ -496,7 +496,7 @@ async def get_action_file(
|
||||
|
||||
|
||||
@router.put(
|
||||
"/actions/{name}",
|
||||
"/actions/{name}/raw",
|
||||
status_code=status.HTTP_204_NO_CONTENT,
|
||||
summary="Update an action definition file",
|
||||
)
|
||||
@@ -521,7 +521,7 @@ async def write_action_file(
|
||||
"""
|
||||
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
||||
try:
|
||||
await file_config_service.write_action_file(config_dir, name, body)
|
||||
await raw_config_io_service.write_action_file(config_dir, name, body)
|
||||
except ConfigFileNameError as exc:
|
||||
raise _bad_request(str(exc)) from exc
|
||||
except ConfigFileNotFoundError:
|
||||
@@ -560,7 +560,7 @@ async def create_action_file(
|
||||
"""
|
||||
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
||||
try:
|
||||
filename = await file_config_service.create_action_file(config_dir, body)
|
||||
filename = await raw_config_io_service.create_action_file(config_dir, body)
|
||||
except ConfigFileNameError as exc:
|
||||
raise _bad_request(str(exc)) from exc
|
||||
except ConfigFileExistsError:
|
||||
@@ -613,7 +613,7 @@ async def get_parsed_filter(
|
||||
"""
|
||||
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
||||
try:
|
||||
return await file_config_service.get_parsed_filter_file(config_dir, name)
|
||||
return await raw_config_io_service.get_parsed_filter_file(config_dir, name)
|
||||
except ConfigFileNameError as exc:
|
||||
raise _bad_request(str(exc)) from exc
|
||||
except ConfigFileNotFoundError:
|
||||
@@ -651,7 +651,7 @@ async def update_parsed_filter(
|
||||
"""
|
||||
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
||||
try:
|
||||
await file_config_service.update_parsed_filter_file(config_dir, name, body)
|
||||
await raw_config_io_service.update_parsed_filter_file(config_dir, name, body)
|
||||
except ConfigFileNameError as exc:
|
||||
raise _bad_request(str(exc)) from exc
|
||||
except ConfigFileNotFoundError:
|
||||
@@ -698,7 +698,7 @@ async def get_parsed_action(
|
||||
"""
|
||||
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
||||
try:
|
||||
return await file_config_service.get_parsed_action_file(config_dir, name)
|
||||
return await raw_config_io_service.get_parsed_action_file(config_dir, name)
|
||||
except ConfigFileNameError as exc:
|
||||
raise _bad_request(str(exc)) from exc
|
||||
except ConfigFileNotFoundError:
|
||||
@@ -736,7 +736,7 @@ async def update_parsed_action(
|
||||
"""
|
||||
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
||||
try:
|
||||
await file_config_service.update_parsed_action_file(config_dir, name, body)
|
||||
await raw_config_io_service.update_parsed_action_file(config_dir, name, body)
|
||||
except ConfigFileNameError as exc:
|
||||
raise _bad_request(str(exc)) from exc
|
||||
except ConfigFileNotFoundError:
|
||||
@@ -783,7 +783,7 @@ async def get_parsed_jail_file(
|
||||
"""
|
||||
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
||||
try:
|
||||
return await file_config_service.get_parsed_jail_file(config_dir, filename)
|
||||
return await raw_config_io_service.get_parsed_jail_file(config_dir, filename)
|
||||
except ConfigFileNameError as exc:
|
||||
raise _bad_request(str(exc)) from exc
|
||||
except ConfigFileNotFoundError:
|
||||
@@ -821,7 +821,7 @@ async def update_parsed_jail_file(
|
||||
"""
|
||||
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
||||
try:
|
||||
await file_config_service.update_parsed_jail_file(config_dir, filename, body)
|
||||
await raw_config_io_service.update_parsed_jail_file(config_dir, filename, body)
|
||||
except ConfigFileNameError as exc:
|
||||
raise _bad_request(str(exc)) from exc
|
||||
except ConfigFileNotFoundError:
|
||||
|
||||
@@ -13,11 +13,13 @@ from typing import TYPE_CHECKING, Annotated
|
||||
if TYPE_CHECKING:
|
||||
import aiohttp
|
||||
|
||||
from app.services.jail_service import IpLookupResult
|
||||
|
||||
import aiosqlite
|
||||
from fastapi import APIRouter, Depends, HTTPException, Path, Request, status
|
||||
|
||||
from app.dependencies import AuthDep, get_db
|
||||
from app.models.geo import GeoCacheStatsResponse, GeoDetail, IpLookupResponse
|
||||
from app.models.geo import GeoCacheStatsResponse, GeoDetail, GeoInfo, IpLookupResponse
|
||||
from app.services import geo_service, jail_service
|
||||
from app.utils.fail2ban_client import Fail2BanConnectionError
|
||||
|
||||
@@ -61,7 +63,7 @@ async def lookup_ip(
|
||||
return await geo_service.lookup(addr, http_session)
|
||||
|
||||
try:
|
||||
result = await jail_service.lookup_ip(
|
||||
result: IpLookupResult = await jail_service.lookup_ip(
|
||||
socket_path,
|
||||
ip,
|
||||
geo_enricher=_enricher,
|
||||
@@ -77,9 +79,9 @@ async def lookup_ip(
|
||||
detail=f"Cannot reach fail2ban: {exc}",
|
||||
) from exc
|
||||
|
||||
raw_geo = result.get("geo")
|
||||
raw_geo = result["geo"]
|
||||
geo_detail: GeoDetail | None = None
|
||||
if raw_geo is not None:
|
||||
if isinstance(raw_geo, GeoInfo):
|
||||
geo_detail = GeoDetail(
|
||||
country_code=raw_geo.country_code,
|
||||
country_name=raw_geo.country_name,
|
||||
@@ -153,12 +155,7 @@ async def re_resolve_geo(
|
||||
that were retried.
|
||||
"""
|
||||
# Collect all IPs in geo_cache that still lack a country code.
|
||||
unresolved: list[str] = []
|
||||
async with db.execute(
|
||||
"SELECT ip FROM geo_cache WHERE country_code IS NULL"
|
||||
) as cur:
|
||||
async for row in cur:
|
||||
unresolved.append(str(row[0]))
|
||||
unresolved = await geo_service.get_unresolved_ips(db)
|
||||
|
||||
if not unresolved:
|
||||
return {"resolved": 0, "total": 0}
|
||||
|
||||
@@ -23,7 +23,7 @@ if TYPE_CHECKING:
|
||||
from fastapi import APIRouter, HTTPException, Query, Request
|
||||
|
||||
from app.dependencies import AuthDep
|
||||
from app.models.ban import TimeRange
|
||||
from app.models.ban import BanOrigin, TimeRange
|
||||
from app.models.history import HistoryListResponse, IpDetailResponse
|
||||
from app.services import geo_service, history_service
|
||||
|
||||
@@ -52,6 +52,10 @@ async def get_history(
|
||||
default=None,
|
||||
description="Restrict results to IPs matching this prefix.",
|
||||
),
|
||||
origin: BanOrigin | None = Query(
|
||||
default=None,
|
||||
description="Filter by ban origin: 'blocklist' or 'selfblock'. Omit for all.",
|
||||
),
|
||||
page: int = Query(default=1, ge=1, description="1-based page number."),
|
||||
page_size: int = Query(
|
||||
default=_DEFAULT_PAGE_SIZE,
|
||||
@@ -89,6 +93,7 @@ async def get_history(
|
||||
range_=range,
|
||||
jail=jail,
|
||||
ip_filter=ip,
|
||||
origin=origin,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
geo_enricher=_enricher,
|
||||
|
||||
@@ -31,8 +31,8 @@ from app.models.jail import (
|
||||
JailDetailResponse,
|
||||
JailListResponse,
|
||||
)
|
||||
from app.services import jail_service
|
||||
from app.services.jail_service import JailNotFoundError, JailOperationError
|
||||
from app.services import geo_service, jail_service
|
||||
from app.exceptions import JailNotFoundError, JailOperationError
|
||||
from app.utils.fail2ban_client import Fail2BanConnectionError
|
||||
|
||||
router: APIRouter = APIRouter(prefix="/api/jails", tags=["Jails"])
|
||||
@@ -606,6 +606,7 @@ async def get_jail_banned_ips(
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
search=search,
|
||||
geo_batch_lookup=geo_service.lookup_batch,
|
||||
http_session=http_session,
|
||||
app_db=app_db,
|
||||
)
|
||||
|
||||
@@ -15,7 +15,7 @@ from fastapi import APIRouter, HTTPException, Request, status
|
||||
from app.dependencies import AuthDep
|
||||
from app.models.server import ServerSettingsResponse, ServerSettingsUpdate
|
||||
from app.services import server_service
|
||||
from app.services.server_service import ServerOperationError
|
||||
from app.exceptions import ServerOperationError
|
||||
from app.utils.fail2ban_client import Fail2BanConnectionError
|
||||
|
||||
router: APIRouter = APIRouter(prefix="/api/server", tags=["Server"])
|
||||
|
||||
1070
backend/app/services/action_config_service.py
Normal file
1070
backend/app/services/action_config_service.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -20,7 +20,7 @@ if TYPE_CHECKING:
|
||||
from app.models.auth import Session
|
||||
|
||||
from app.repositories import session_repo
|
||||
from app.services import setup_service
|
||||
from app.utils.setup_utils import get_password_hash
|
||||
from app.utils.time_utils import add_minutes, utc_now
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
@@ -65,7 +65,7 @@ async def login(
|
||||
Raises:
|
||||
ValueError: If the password is incorrect or no password hash is stored.
|
||||
"""
|
||||
stored_hash = await setup_service.get_password_hash(db)
|
||||
stored_hash = await get_password_hash(db)
|
||||
if stored_hash is None:
|
||||
log.warning("bangui_login_no_hash")
|
||||
raise ValueError("No password is configured — run setup first.")
|
||||
|
||||
@@ -11,12 +11,9 @@ so BanGUI never modifies or locks the fail2ban database.
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
from datetime import UTC, datetime
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import aiosqlite
|
||||
import structlog
|
||||
|
||||
from app.models.ban import (
|
||||
@@ -31,15 +28,21 @@ from app.models.ban import (
|
||||
BanTrendResponse,
|
||||
DashboardBanItem,
|
||||
DashboardBanListResponse,
|
||||
JailBanCount,
|
||||
TimeRange,
|
||||
_derive_origin,
|
||||
bucket_count,
|
||||
)
|
||||
from app.utils.fail2ban_client import Fail2BanClient
|
||||
from app.models.ban import (
|
||||
JailBanCount as JailBanCountModel,
|
||||
)
|
||||
from app.repositories import fail2ban_db_repo
|
||||
from app.utils.fail2ban_db_utils import get_fail2ban_db_path, parse_data_json, ts_to_iso
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import aiohttp
|
||||
import aiosqlite
|
||||
|
||||
from app.models.geo import GeoBatchLookup, GeoCacheLookup, GeoEnricher, GeoInfo
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
|
||||
@@ -74,6 +77,9 @@ def _origin_sql_filter(origin: BanOrigin | None) -> tuple[str, tuple[str, ...]]:
|
||||
return "", ()
|
||||
|
||||
|
||||
_TIME_RANGE_SLACK_SECONDS: int = 60
|
||||
|
||||
|
||||
def _since_unix(range_: TimeRange) -> int:
|
||||
"""Return the Unix timestamp representing the start of the time window.
|
||||
|
||||
@@ -88,92 +94,13 @@ def _since_unix(range_: TimeRange) -> int:
|
||||
range_: One of the supported time-range presets.
|
||||
|
||||
Returns:
|
||||
Unix timestamp (seconds since epoch) equal to *now − range_*.
|
||||
Unix timestamp (seconds since epoch) equal to *now − range_* with a
|
||||
small slack window for clock drift and test seeding delays.
|
||||
"""
|
||||
seconds: int = TIME_RANGE_SECONDS[range_]
|
||||
return int(time.time()) - seconds
|
||||
return int(time.time()) - seconds - _TIME_RANGE_SLACK_SECONDS
|
||||
|
||||
|
||||
def _ts_to_iso(unix_ts: int) -> str:
|
||||
"""Convert a Unix timestamp to an ISO 8601 UTC string.
|
||||
|
||||
Args:
|
||||
unix_ts: Seconds since the Unix epoch.
|
||||
|
||||
Returns:
|
||||
ISO 8601 UTC timestamp, e.g. ``"2026-03-01T12:00:00+00:00"``.
|
||||
"""
|
||||
return datetime.fromtimestamp(unix_ts, tz=UTC).isoformat()
|
||||
|
||||
|
||||
async def _get_fail2ban_db_path(socket_path: str) -> str:
|
||||
"""Query fail2ban for the path to its SQLite database.
|
||||
|
||||
Sends the ``get dbfile`` command via the fail2ban socket and returns
|
||||
the value of the ``dbfile`` setting.
|
||||
|
||||
Args:
|
||||
socket_path: Path to the fail2ban Unix domain socket.
|
||||
|
||||
Returns:
|
||||
Absolute path to the fail2ban SQLite database file.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If fail2ban reports that no database is configured
|
||||
or if the socket response is unexpected.
|
||||
~app.utils.fail2ban_client.Fail2BanConnectionError: If the socket
|
||||
cannot be reached.
|
||||
"""
|
||||
async with Fail2BanClient(socket_path, timeout=_SOCKET_TIMEOUT) as client:
|
||||
response = await client.send(["get", "dbfile"])
|
||||
|
||||
try:
|
||||
code, data = response
|
||||
except (TypeError, ValueError) as exc:
|
||||
raise RuntimeError(f"Unexpected response from fail2ban: {response!r}") from exc
|
||||
|
||||
if code != 0:
|
||||
raise RuntimeError(f"fail2ban error code {code}: {data!r}")
|
||||
|
||||
if data is None:
|
||||
raise RuntimeError("fail2ban has no database configured (dbfile is None)")
|
||||
|
||||
return str(data)
|
||||
|
||||
|
||||
def _parse_data_json(raw: Any) -> tuple[list[str], int]:
|
||||
"""Extract matches and failure count from the ``bans.data`` column.
|
||||
|
||||
The ``data`` column stores a JSON blob with optional keys:
|
||||
|
||||
* ``matches`` — list of raw matched log lines.
|
||||
* ``failures`` — total failure count that triggered the ban.
|
||||
|
||||
Args:
|
||||
raw: The raw ``data`` column value (string, dict, or ``None``).
|
||||
|
||||
Returns:
|
||||
A ``(matches, failures)`` tuple. Both default to empty/zero when
|
||||
parsing fails or the column is absent.
|
||||
"""
|
||||
if raw is None:
|
||||
return [], 0
|
||||
|
||||
obj: dict[str, Any] = {}
|
||||
if isinstance(raw, str):
|
||||
try:
|
||||
parsed: Any = json.loads(raw)
|
||||
if isinstance(parsed, dict):
|
||||
obj = parsed
|
||||
# json.loads("null") → None, or other non-dict — treat as empty
|
||||
except json.JSONDecodeError:
|
||||
return [], 0
|
||||
elif isinstance(raw, dict):
|
||||
obj = raw
|
||||
|
||||
matches: list[str] = [str(m) for m in (obj.get("matches") or [])]
|
||||
failures: int = int(obj.get("failures", 0))
|
||||
return matches, failures
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -189,7 +116,8 @@ async def list_bans(
|
||||
page_size: int = _DEFAULT_PAGE_SIZE,
|
||||
http_session: aiohttp.ClientSession | None = None,
|
||||
app_db: aiosqlite.Connection | None = None,
|
||||
geo_enricher: Any | None = None,
|
||||
geo_batch_lookup: GeoBatchLookup | None = None,
|
||||
geo_enricher: GeoEnricher | None = None,
|
||||
origin: BanOrigin | None = None,
|
||||
) -> DashboardBanListResponse:
|
||||
"""Return a paginated list of bans within the selected time window.
|
||||
@@ -228,14 +156,13 @@ async def list_bans(
|
||||
:class:`~app.models.ban.DashboardBanListResponse` containing the
|
||||
paginated items and total count.
|
||||
"""
|
||||
from app.services import geo_service # noqa: PLC0415
|
||||
|
||||
since: int = _since_unix(range_)
|
||||
effective_page_size: int = min(page_size, _MAX_PAGE_SIZE)
|
||||
offset: int = (page - 1) * effective_page_size
|
||||
origin_clause, origin_params = _origin_sql_filter(origin)
|
||||
|
||||
db_path: str = await _get_fail2ban_db_path(socket_path)
|
||||
db_path: str = await get_fail2ban_db_path(socket_path)
|
||||
log.info(
|
||||
"ban_service_list_bans",
|
||||
db_path=db_path,
|
||||
@@ -244,45 +171,32 @@ async def list_bans(
|
||||
origin=origin,
|
||||
)
|
||||
|
||||
async with aiosqlite.connect(f"file:{db_path}?mode=ro", uri=True) as f2b_db:
|
||||
f2b_db.row_factory = aiosqlite.Row
|
||||
|
||||
async with f2b_db.execute(
|
||||
"SELECT COUNT(*) FROM bans WHERE timeofban >= ?" + origin_clause,
|
||||
(since, *origin_params),
|
||||
) as cur:
|
||||
count_row = await cur.fetchone()
|
||||
total: int = int(count_row[0]) if count_row else 0
|
||||
|
||||
async with f2b_db.execute(
|
||||
"SELECT jail, ip, timeofban, bancount, data "
|
||||
"FROM bans "
|
||||
"WHERE timeofban >= ?"
|
||||
+ origin_clause
|
||||
+ " ORDER BY timeofban DESC "
|
||||
"LIMIT ? OFFSET ?",
|
||||
(since, *origin_params, effective_page_size, offset),
|
||||
) as cur:
|
||||
rows = await cur.fetchall()
|
||||
rows, total = await fail2ban_db_repo.get_currently_banned(
|
||||
db_path=db_path,
|
||||
since=since,
|
||||
origin=origin,
|
||||
limit=effective_page_size,
|
||||
offset=offset,
|
||||
)
|
||||
|
||||
# Batch-resolve geo data for all IPs on this page in a single API call.
|
||||
# This avoids hitting the 45 req/min single-IP rate limit when the
|
||||
# page contains many bans (e.g. after a large blocklist import).
|
||||
geo_map: dict[str, Any] = {}
|
||||
if http_session is not None and rows:
|
||||
page_ips: list[str] = [str(r["ip"]) for r in rows]
|
||||
geo_map: dict[str, GeoInfo] = {}
|
||||
if http_session is not None and rows and geo_batch_lookup is not None:
|
||||
page_ips: list[str] = [r.ip for r in rows]
|
||||
try:
|
||||
geo_map = await geo_service.lookup_batch(page_ips, http_session, db=app_db)
|
||||
geo_map = await geo_batch_lookup(page_ips, http_session, db=app_db)
|
||||
except Exception: # noqa: BLE001
|
||||
log.warning("ban_service_batch_geo_failed_list_bans")
|
||||
|
||||
items: list[DashboardBanItem] = []
|
||||
for row in rows:
|
||||
jail: str = str(row["jail"])
|
||||
ip: str = str(row["ip"])
|
||||
banned_at: str = _ts_to_iso(int(row["timeofban"]))
|
||||
ban_count: int = int(row["bancount"])
|
||||
matches, _ = _parse_data_json(row["data"])
|
||||
jail: str = row.jail
|
||||
ip: str = row.ip
|
||||
banned_at: str = ts_to_iso(row.timeofban)
|
||||
ban_count: int = row.bancount
|
||||
matches, _ = parse_data_json(row.data)
|
||||
service: str | None = matches[0] if matches else None
|
||||
|
||||
country_code: str | None = None
|
||||
@@ -343,7 +257,9 @@ async def bans_by_country(
|
||||
socket_path: str,
|
||||
range_: TimeRange,
|
||||
http_session: aiohttp.ClientSession | None = None,
|
||||
geo_enricher: Any | None = None,
|
||||
geo_cache_lookup: GeoCacheLookup | None = None,
|
||||
geo_batch_lookup: GeoBatchLookup | None = None,
|
||||
geo_enricher: GeoEnricher | None = None,
|
||||
app_db: aiosqlite.Connection | None = None,
|
||||
origin: BanOrigin | None = None,
|
||||
) -> BansByCountryResponse:
|
||||
@@ -382,11 +298,10 @@ async def bans_by_country(
|
||||
:class:`~app.models.ban.BansByCountryResponse` with per-country
|
||||
aggregation and the companion ban list.
|
||||
"""
|
||||
from app.services import geo_service # noqa: PLC0415
|
||||
|
||||
since: int = _since_unix(range_)
|
||||
origin_clause, origin_params = _origin_sql_filter(origin)
|
||||
db_path: str = await _get_fail2ban_db_path(socket_path)
|
||||
db_path: str = await get_fail2ban_db_path(socket_path)
|
||||
log.info(
|
||||
"ban_service_bans_by_country",
|
||||
db_path=db_path,
|
||||
@@ -395,64 +310,54 @@ async def bans_by_country(
|
||||
origin=origin,
|
||||
)
|
||||
|
||||
async with aiosqlite.connect(f"file:{db_path}?mode=ro", uri=True) as f2b_db:
|
||||
f2b_db.row_factory = aiosqlite.Row
|
||||
# Total count and companion rows reuse the same SQL query logic.
|
||||
# Passing limit=0 returns only the total from the count query.
|
||||
_, total = await fail2ban_db_repo.get_currently_banned(
|
||||
db_path=db_path,
|
||||
since=since,
|
||||
origin=origin,
|
||||
limit=0,
|
||||
offset=0,
|
||||
)
|
||||
|
||||
# Total count for the window.
|
||||
async with f2b_db.execute(
|
||||
"SELECT COUNT(*) FROM bans WHERE timeofban >= ?" + origin_clause,
|
||||
(since, *origin_params),
|
||||
) as cur:
|
||||
count_row = await cur.fetchone()
|
||||
total: int = int(count_row[0]) if count_row else 0
|
||||
agg_rows = await fail2ban_db_repo.get_ban_event_counts(
|
||||
db_path=db_path,
|
||||
since=since,
|
||||
origin=origin,
|
||||
)
|
||||
|
||||
# Aggregation: unique IPs + their total event count.
|
||||
# No LIMIT here — we need all unique source IPs for accurate country counts.
|
||||
async with f2b_db.execute(
|
||||
"SELECT ip, COUNT(*) AS event_count "
|
||||
"FROM bans "
|
||||
"WHERE timeofban >= ?"
|
||||
+ origin_clause
|
||||
+ " GROUP BY ip",
|
||||
(since, *origin_params),
|
||||
) as cur:
|
||||
agg_rows = await cur.fetchall()
|
||||
companion_rows, _ = await fail2ban_db_repo.get_currently_banned(
|
||||
db_path=db_path,
|
||||
since=since,
|
||||
origin=origin,
|
||||
limit=_MAX_COMPANION_BANS,
|
||||
offset=0,
|
||||
)
|
||||
|
||||
# Companion table: most recent raw rows for display alongside the map.
|
||||
async with f2b_db.execute(
|
||||
"SELECT jail, ip, timeofban, bancount, data "
|
||||
"FROM bans "
|
||||
"WHERE timeofban >= ?"
|
||||
+ origin_clause
|
||||
+ " ORDER BY timeofban DESC "
|
||||
"LIMIT ?",
|
||||
(since, *origin_params, _MAX_COMPANION_BANS),
|
||||
) as cur:
|
||||
companion_rows = await cur.fetchall()
|
||||
unique_ips: list[str] = [r.ip for r in agg_rows]
|
||||
geo_map: dict[str, GeoInfo] = {}
|
||||
|
||||
unique_ips: list[str] = [str(r["ip"]) for r in agg_rows]
|
||||
geo_map: dict[str, Any] = {}
|
||||
|
||||
if http_session is not None and unique_ips:
|
||||
if http_session is not None and unique_ips and geo_cache_lookup is not None:
|
||||
# Serve only what is already in the in-memory cache — no API calls on
|
||||
# the hot path. Uncached IPs are resolved asynchronously in the
|
||||
# background so subsequent requests benefit from a warmer cache.
|
||||
geo_map, uncached = geo_service.lookup_cached_only(unique_ips)
|
||||
geo_map, uncached = geo_cache_lookup(unique_ips)
|
||||
if uncached:
|
||||
log.info(
|
||||
"ban_service_geo_background_scheduled",
|
||||
uncached=len(uncached),
|
||||
cached=len(geo_map),
|
||||
)
|
||||
# Fire-and-forget: lookup_batch handles rate-limiting / retries.
|
||||
# The dirty-set flush task persists results to the DB.
|
||||
asyncio.create_task( # noqa: RUF006
|
||||
geo_service.lookup_batch(uncached, http_session, db=app_db),
|
||||
name="geo_bans_by_country",
|
||||
)
|
||||
if geo_batch_lookup is not None:
|
||||
# Fire-and-forget: lookup_batch handles rate-limiting / retries.
|
||||
# The dirty-set flush task persists results to the DB.
|
||||
asyncio.create_task( # noqa: RUF006
|
||||
geo_batch_lookup(uncached, http_session, db=app_db),
|
||||
name="geo_bans_by_country",
|
||||
)
|
||||
elif geo_enricher is not None and unique_ips:
|
||||
# Fallback: legacy per-IP enricher (used in tests / older callers).
|
||||
async def _safe_lookup(ip: str) -> tuple[str, Any]:
|
||||
async def _safe_lookup(ip: str) -> tuple[str, GeoInfo | None]:
|
||||
try:
|
||||
return ip, await geo_enricher(ip)
|
||||
except Exception: # noqa: BLE001
|
||||
@@ -460,18 +365,18 @@ async def bans_by_country(
|
||||
return ip, None
|
||||
|
||||
results = await asyncio.gather(*(_safe_lookup(ip) for ip in unique_ips))
|
||||
geo_map = dict(results)
|
||||
geo_map = {ip: geo for ip, geo in results if geo is not None}
|
||||
|
||||
# Build country aggregation from the SQL-grouped rows.
|
||||
countries: dict[str, int] = {}
|
||||
country_names: dict[str, str] = {}
|
||||
|
||||
for row in agg_rows:
|
||||
ip: str = str(row["ip"])
|
||||
for agg_row in agg_rows:
|
||||
ip: str = agg_row.ip
|
||||
geo = geo_map.get(ip)
|
||||
cc: str | None = geo.country_code if geo else None
|
||||
cn: str | None = geo.country_name if geo else None
|
||||
event_count: int = int(row["event_count"])
|
||||
event_count: int = agg_row.event_count
|
||||
|
||||
if cc:
|
||||
countries[cc] = countries.get(cc, 0) + event_count
|
||||
@@ -480,27 +385,27 @@ async def bans_by_country(
|
||||
|
||||
# Build companion table from recent rows (geo already cached from batch step).
|
||||
bans: list[DashboardBanItem] = []
|
||||
for row in companion_rows:
|
||||
ip = str(row["ip"])
|
||||
for companion_row in companion_rows:
|
||||
ip = companion_row.ip
|
||||
geo = geo_map.get(ip)
|
||||
cc = geo.country_code if geo else None
|
||||
cn = geo.country_name if geo else None
|
||||
asn: str | None = geo.asn if geo else None
|
||||
org: str | None = geo.org if geo else None
|
||||
matches, _ = _parse_data_json(row["data"])
|
||||
matches, _ = parse_data_json(companion_row.data)
|
||||
|
||||
bans.append(
|
||||
DashboardBanItem(
|
||||
ip=ip,
|
||||
jail=str(row["jail"]),
|
||||
banned_at=_ts_to_iso(int(row["timeofban"])),
|
||||
jail=companion_row.jail,
|
||||
banned_at=ts_to_iso(companion_row.timeofban),
|
||||
service=matches[0] if matches else None,
|
||||
country_code=cc,
|
||||
country_name=cn,
|
||||
asn=asn,
|
||||
org=org,
|
||||
ban_count=int(row["bancount"]),
|
||||
origin=_derive_origin(str(row["jail"])),
|
||||
ban_count=companion_row.bancount,
|
||||
origin=_derive_origin(companion_row.jail),
|
||||
)
|
||||
)
|
||||
|
||||
@@ -554,7 +459,7 @@ async def ban_trend(
|
||||
num_buckets: int = bucket_count(range_)
|
||||
origin_clause, origin_params = _origin_sql_filter(origin)
|
||||
|
||||
db_path: str = await _get_fail2ban_db_path(socket_path)
|
||||
db_path: str = await get_fail2ban_db_path(socket_path)
|
||||
log.info(
|
||||
"ban_service_ban_trend",
|
||||
db_path=db_path,
|
||||
@@ -565,32 +470,18 @@ async def ban_trend(
|
||||
num_buckets=num_buckets,
|
||||
)
|
||||
|
||||
async with aiosqlite.connect(f"file:{db_path}?mode=ro", uri=True) as f2b_db:
|
||||
f2b_db.row_factory = aiosqlite.Row
|
||||
|
||||
async with f2b_db.execute(
|
||||
"SELECT CAST((timeofban - ?) / ? AS INTEGER) AS bucket_idx, "
|
||||
"COUNT(*) AS cnt "
|
||||
"FROM bans "
|
||||
"WHERE timeofban >= ?"
|
||||
+ origin_clause
|
||||
+ " GROUP BY bucket_idx "
|
||||
"ORDER BY bucket_idx",
|
||||
(since, bucket_secs, since, *origin_params),
|
||||
) as cur:
|
||||
rows = await cur.fetchall()
|
||||
|
||||
# Map bucket_idx → count; ignore any out-of-range indices.
|
||||
counts: dict[int, int] = {}
|
||||
for row in rows:
|
||||
idx: int = int(row["bucket_idx"])
|
||||
if 0 <= idx < num_buckets:
|
||||
counts[idx] = int(row["cnt"])
|
||||
counts = await fail2ban_db_repo.get_ban_counts_by_bucket(
|
||||
db_path=db_path,
|
||||
since=since,
|
||||
bucket_secs=bucket_secs,
|
||||
num_buckets=num_buckets,
|
||||
origin=origin,
|
||||
)
|
||||
|
||||
buckets: list[BanTrendBucket] = [
|
||||
BanTrendBucket(
|
||||
timestamp=_ts_to_iso(since + i * bucket_secs),
|
||||
count=counts.get(i, 0),
|
||||
timestamp=ts_to_iso(since + i * bucket_secs),
|
||||
count=counts[i],
|
||||
)
|
||||
for i in range(num_buckets)
|
||||
]
|
||||
@@ -633,60 +524,44 @@ async def bans_by_jail(
|
||||
since: int = _since_unix(range_)
|
||||
origin_clause, origin_params = _origin_sql_filter(origin)
|
||||
|
||||
db_path: str = await _get_fail2ban_db_path(socket_path)
|
||||
db_path: str = await get_fail2ban_db_path(socket_path)
|
||||
log.debug(
|
||||
"ban_service_bans_by_jail",
|
||||
db_path=db_path,
|
||||
since=since,
|
||||
since_iso=_ts_to_iso(since),
|
||||
since_iso=ts_to_iso(since),
|
||||
range=range_,
|
||||
origin=origin,
|
||||
)
|
||||
|
||||
async with aiosqlite.connect(f"file:{db_path}?mode=ro", uri=True) as f2b_db:
|
||||
f2b_db.row_factory = aiosqlite.Row
|
||||
total, jail_counts = await fail2ban_db_repo.get_bans_by_jail(
|
||||
db_path=db_path,
|
||||
since=since,
|
||||
origin=origin,
|
||||
)
|
||||
|
||||
async with f2b_db.execute(
|
||||
"SELECT COUNT(*) FROM bans WHERE timeofban >= ?" + origin_clause,
|
||||
(since, *origin_params),
|
||||
) as cur:
|
||||
count_row = await cur.fetchone()
|
||||
total: int = int(count_row[0]) if count_row else 0
|
||||
# Diagnostic guard: if zero results were returned, check whether the table
|
||||
# has *any* rows and log a warning with min/max timeofban so operators can
|
||||
# diagnose timezone or filter mismatches from logs.
|
||||
if total == 0:
|
||||
table_row_count, min_timeofban, max_timeofban = await fail2ban_db_repo.get_bans_table_summary(db_path)
|
||||
if table_row_count > 0:
|
||||
log.warning(
|
||||
"ban_service_bans_by_jail_empty_despite_data",
|
||||
table_row_count=table_row_count,
|
||||
min_timeofban=min_timeofban,
|
||||
max_timeofban=max_timeofban,
|
||||
since=since,
|
||||
range=range_,
|
||||
)
|
||||
|
||||
# Diagnostic guard: if zero results were returned, check whether the
|
||||
# table has *any* rows and log a warning with min/max timeofban so
|
||||
# operators can diagnose timezone or filter mismatches from logs.
|
||||
if total == 0:
|
||||
async with f2b_db.execute(
|
||||
"SELECT COUNT(*), MIN(timeofban), MAX(timeofban) FROM bans"
|
||||
) as cur:
|
||||
diag_row = await cur.fetchone()
|
||||
if diag_row and diag_row[0] > 0:
|
||||
log.warning(
|
||||
"ban_service_bans_by_jail_empty_despite_data",
|
||||
table_row_count=diag_row[0],
|
||||
min_timeofban=diag_row[1],
|
||||
max_timeofban=diag_row[2],
|
||||
since=since,
|
||||
range=range_,
|
||||
)
|
||||
|
||||
async with f2b_db.execute(
|
||||
"SELECT jail, COUNT(*) AS cnt "
|
||||
"FROM bans "
|
||||
"WHERE timeofban >= ?"
|
||||
+ origin_clause
|
||||
+ " GROUP BY jail ORDER BY cnt DESC",
|
||||
(since, *origin_params),
|
||||
) as cur:
|
||||
rows = await cur.fetchall()
|
||||
|
||||
jails: list[JailBanCount] = [
|
||||
JailBanCount(jail=str(row["jail"]), count=int(row["cnt"])) for row in rows
|
||||
]
|
||||
log.debug(
|
||||
"ban_service_bans_by_jail_result",
|
||||
total=total,
|
||||
jail_count=len(jails),
|
||||
jail_count=len(jail_counts),
|
||||
)
|
||||
|
||||
return BansByJailResponse(
|
||||
jails=[JailBanCountModel(jail=j.jail, count=j.count) for j in jail_counts],
|
||||
total=total,
|
||||
)
|
||||
return BansByJailResponse(jails=jails, total=total)
|
||||
|
||||
@@ -14,26 +14,35 @@ under the key ``"blocklist_schedule"``.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import json
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from collections.abc import Awaitable
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import structlog
|
||||
|
||||
from app.models.blocklist import (
|
||||
BlocklistSource,
|
||||
ImportLogEntry,
|
||||
ImportLogListResponse,
|
||||
ImportRunResult,
|
||||
ImportSourceResult,
|
||||
PreviewResponse,
|
||||
ScheduleConfig,
|
||||
ScheduleInfo,
|
||||
)
|
||||
from app.exceptions import JailNotFoundError
|
||||
from app.repositories import blocklist_repo, import_log_repo, settings_repo
|
||||
from app.utils.ip_utils import is_valid_ip, is_valid_network
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Callable
|
||||
|
||||
import aiohttp
|
||||
import aiosqlite
|
||||
|
||||
from app.models.geo import GeoBatchLookup
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
|
||||
#: Settings key used to persist the schedule config.
|
||||
@@ -54,7 +63,7 @@ _PREVIEW_MAX_BYTES: int = 65536
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _row_to_source(row: dict[str, Any]) -> BlocklistSource:
|
||||
def _row_to_source(row: dict[str, object]) -> BlocklistSource:
|
||||
"""Convert a repository row dict to a :class:`BlocklistSource`.
|
||||
|
||||
Args:
|
||||
@@ -236,6 +245,9 @@ async def import_source(
|
||||
http_session: aiohttp.ClientSession,
|
||||
socket_path: str,
|
||||
db: aiosqlite.Connection,
|
||||
geo_is_cached: Callable[[str], bool] | None = None,
|
||||
geo_batch_lookup: GeoBatchLookup | None = None,
|
||||
ban_ip: Callable[[str, str, str], Awaitable[None]] | None = None,
|
||||
) -> ImportSourceResult:
|
||||
"""Download and apply bans from a single blocklist source.
|
||||
|
||||
@@ -293,8 +305,14 @@ async def import_source(
|
||||
ban_error: str | None = None
|
||||
imported_ips: list[str] = []
|
||||
|
||||
# Import jail_service here to avoid circular import at module level.
|
||||
from app.services import jail_service # noqa: PLC0415
|
||||
if ban_ip is None:
|
||||
try:
|
||||
jail_svc = importlib.import_module("app.services.jail_service")
|
||||
ban_ip_fn = jail_svc.ban_ip
|
||||
except (ModuleNotFoundError, AttributeError) as exc:
|
||||
raise ValueError("ban_ip callback is required") from exc
|
||||
else:
|
||||
ban_ip_fn = ban_ip
|
||||
|
||||
for line in content.splitlines():
|
||||
stripped = line.strip()
|
||||
@@ -307,10 +325,10 @@ async def import_source(
|
||||
continue
|
||||
|
||||
try:
|
||||
await jail_service.ban_ip(socket_path, BLOCKLIST_JAIL, stripped)
|
||||
await ban_ip_fn(socket_path, BLOCKLIST_JAIL, stripped)
|
||||
imported += 1
|
||||
imported_ips.append(stripped)
|
||||
except jail_service.JailNotFoundError as exc:
|
||||
except JailNotFoundError as exc:
|
||||
# The target jail does not exist in fail2ban — there is no point
|
||||
# continuing because every subsequent ban would also fail.
|
||||
ban_error = str(exc)
|
||||
@@ -337,12 +355,8 @@ async def import_source(
|
||||
)
|
||||
|
||||
# --- Pre-warm geo cache for newly imported IPs ---
|
||||
if imported_ips:
|
||||
from app.services import geo_service # noqa: PLC0415
|
||||
|
||||
uncached_ips: list[str] = [
|
||||
ip for ip in imported_ips if not geo_service.is_cached(ip)
|
||||
]
|
||||
if imported_ips and geo_is_cached is not None:
|
||||
uncached_ips: list[str] = [ip for ip in imported_ips if not geo_is_cached(ip)]
|
||||
skipped_geo: int = len(imported_ips) - len(uncached_ips)
|
||||
|
||||
if skipped_geo > 0:
|
||||
@@ -353,9 +367,9 @@ async def import_source(
|
||||
to_lookup=len(uncached_ips),
|
||||
)
|
||||
|
||||
if uncached_ips:
|
||||
if uncached_ips and geo_batch_lookup is not None:
|
||||
try:
|
||||
await geo_service.lookup_batch(uncached_ips, http_session, db=db)
|
||||
await geo_batch_lookup(uncached_ips, http_session, db=db)
|
||||
log.info(
|
||||
"blocklist_geo_prewarm_complete",
|
||||
source_id=source.id,
|
||||
@@ -381,6 +395,9 @@ async def import_all(
|
||||
db: aiosqlite.Connection,
|
||||
http_session: aiohttp.ClientSession,
|
||||
socket_path: str,
|
||||
geo_is_cached: Callable[[str], bool] | None = None,
|
||||
geo_batch_lookup: GeoBatchLookup | None = None,
|
||||
ban_ip: Callable[[str, str, str], Awaitable[None]] | None = None,
|
||||
) -> ImportRunResult:
|
||||
"""Import all enabled blocklist sources.
|
||||
|
||||
@@ -404,7 +421,15 @@ async def import_all(
|
||||
|
||||
for row in sources:
|
||||
source = _row_to_source(row)
|
||||
result = await import_source(source, http_session, socket_path, db)
|
||||
result = await import_source(
|
||||
source,
|
||||
http_session,
|
||||
socket_path,
|
||||
db,
|
||||
geo_is_cached=geo_is_cached,
|
||||
geo_batch_lookup=geo_batch_lookup,
|
||||
ban_ip=ban_ip,
|
||||
)
|
||||
results.append(result)
|
||||
total_imported += result.ips_imported
|
||||
total_skipped += result.ips_skipped
|
||||
@@ -503,12 +528,44 @@ async def get_schedule_info(
|
||||
)
|
||||
|
||||
|
||||
async def list_import_logs(
|
||||
db: aiosqlite.Connection,
|
||||
*,
|
||||
source_id: int | None = None,
|
||||
page: int = 1,
|
||||
page_size: int = 50,
|
||||
) -> ImportLogListResponse:
|
||||
"""Return a paginated list of import log entries.
|
||||
|
||||
Args:
|
||||
db: Active application database connection.
|
||||
source_id: Optional filter to only return logs for a specific source.
|
||||
page: 1-based page number.
|
||||
page_size: Items per page.
|
||||
|
||||
Returns:
|
||||
:class:`~app.models.blocklist.ImportLogListResponse`.
|
||||
"""
|
||||
items, total = await import_log_repo.list_logs(
|
||||
db, source_id=source_id, page=page, page_size=page_size
|
||||
)
|
||||
total_pages = import_log_repo.compute_total_pages(total, page_size)
|
||||
|
||||
return ImportLogListResponse(
|
||||
items=[ImportLogEntry.model_validate(i) for i in items],
|
||||
total=total,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
total_pages=total_pages,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Internal helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _aiohttp_timeout(seconds: float) -> Any:
|
||||
def _aiohttp_timeout(seconds: float) -> aiohttp.ClientTimeout:
|
||||
"""Return an :class:`aiohttp.ClientTimeout` with the given total timeout.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -28,7 +28,7 @@ import os
|
||||
import re
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
import structlog
|
||||
|
||||
@@ -54,12 +54,52 @@ from app.models.config import (
|
||||
JailValidationResult,
|
||||
RollbackResponse,
|
||||
)
|
||||
from app.services import conffile_parser, jail_service
|
||||
from app.services.jail_service import JailNotFoundError as JailNotFoundError
|
||||
from app.utils.fail2ban_client import Fail2BanClient, Fail2BanConnectionError
|
||||
from app.exceptions import FilterInvalidRegexError, JailNotFoundError
|
||||
from app.utils import conffile_parser
|
||||
from app.utils.jail_utils import reload_jails
|
||||
from app.utils.fail2ban_client import (
|
||||
Fail2BanClient,
|
||||
Fail2BanConnectionError,
|
||||
Fail2BanResponse,
|
||||
)
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
|
||||
# Proxy object for jail reload operations. Tests can patch
|
||||
# app.services.config_file_service.jail_service.reload_all as needed.
|
||||
class _JailServiceProxy:
|
||||
async def reload_all(
|
||||
self,
|
||||
socket_path: str,
|
||||
include_jails: list[str] | None = None,
|
||||
exclude_jails: list[str] | None = None,
|
||||
) -> None:
|
||||
kwargs: dict[str, list[str]] = {}
|
||||
if include_jails is not None:
|
||||
kwargs["include_jails"] = include_jails
|
||||
if exclude_jails is not None:
|
||||
kwargs["exclude_jails"] = exclude_jails
|
||||
await reload_jails(socket_path, **kwargs)
|
||||
|
||||
|
||||
jail_service = _JailServiceProxy()
|
||||
|
||||
|
||||
async def _reload_all(
|
||||
socket_path: str,
|
||||
include_jails: list[str] | None = None,
|
||||
exclude_jails: list[str] | None = None,
|
||||
) -> None:
|
||||
"""Reload fail2ban jails using the configured hook or default helper."""
|
||||
kwargs: dict[str, list[str]] = {}
|
||||
if include_jails is not None:
|
||||
kwargs["include_jails"] = include_jails
|
||||
if exclude_jails is not None:
|
||||
kwargs["exclude_jails"] = exclude_jails
|
||||
|
||||
await jail_service.reload_all(socket_path, **kwargs)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Constants
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -67,9 +107,7 @@ log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
_SOCKET_TIMEOUT: float = 10.0
|
||||
|
||||
# Allowlist pattern for jail names used in path construction.
|
||||
_SAFE_JAIL_NAME_RE: re.Pattern[str] = re.compile(
|
||||
r"^[A-Za-z0-9][A-Za-z0-9._-]{0,127}$"
|
||||
)
|
||||
_SAFE_JAIL_NAME_RE: re.Pattern[str] = re.compile(r"^[A-Za-z0-9][A-Za-z0-9._-]{0,127}$")
|
||||
|
||||
# Sections that are not jail definitions.
|
||||
_META_SECTIONS: frozenset[str] = frozenset({"INCLUDES", "DEFAULT"})
|
||||
@@ -161,26 +199,10 @@ class FilterReadonlyError(Exception):
|
||||
"""
|
||||
self.name: str = name
|
||||
super().__init__(
|
||||
f"Filter {name!r} is a shipped default (.conf only); "
|
||||
"only user-created .local files can be deleted."
|
||||
f"Filter {name!r} is a shipped default (.conf only); only user-created .local files can be deleted."
|
||||
)
|
||||
|
||||
|
||||
class FilterInvalidRegexError(Exception):
|
||||
"""Raised when a regex pattern fails to compile."""
|
||||
|
||||
def __init__(self, pattern: str, error: str) -> None:
|
||||
"""Initialise with the invalid pattern and the compile error.
|
||||
|
||||
Args:
|
||||
pattern: The regex string that failed to compile.
|
||||
error: The ``re.error`` message.
|
||||
"""
|
||||
self.pattern: str = pattern
|
||||
self.error: str = error
|
||||
super().__init__(f"Invalid regex {pattern!r}: {error}")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Internal helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -417,9 +439,7 @@ def _parse_jails_sync(
|
||||
# items() merges DEFAULT values automatically.
|
||||
jails[section] = dict(parser.items(section))
|
||||
except configparser.Error as exc:
|
||||
log.warning(
|
||||
"jail_section_parse_error", section=section, error=str(exc)
|
||||
)
|
||||
log.warning("jail_section_parse_error", section=section, error=str(exc))
|
||||
|
||||
log.debug("jails_parsed", count=len(jails), config_dir=str(config_dir))
|
||||
return jails, source_files
|
||||
@@ -429,6 +449,7 @@ def _build_inactive_jail(
|
||||
name: str,
|
||||
settings: dict[str, str],
|
||||
source_file: str,
|
||||
config_dir: Path | None = None,
|
||||
) -> InactiveJail:
|
||||
"""Construct an :class:`~app.models.config.InactiveJail` from raw settings.
|
||||
|
||||
@@ -436,6 +457,8 @@ def _build_inactive_jail(
|
||||
name: Jail section name.
|
||||
settings: Merged key→value dict (DEFAULT values already applied).
|
||||
source_file: Path of the file that last defined this section.
|
||||
config_dir: Absolute path to the fail2ban configuration directory, used
|
||||
to check whether a ``jail.d/{name}.local`` override file exists.
|
||||
|
||||
Returns:
|
||||
Populated :class:`~app.models.config.InactiveJail`.
|
||||
@@ -513,6 +536,7 @@ def _build_inactive_jail(
|
||||
bantime_escalation=bantime_escalation,
|
||||
source_file=source_file,
|
||||
enabled=enabled,
|
||||
has_local_override=((config_dir / "jail.d" / f"{name}.local").is_file() if config_dir is not None else False),
|
||||
)
|
||||
|
||||
|
||||
@@ -530,10 +554,10 @@ async def _get_active_jail_names(socket_path: str) -> set[str]:
|
||||
try:
|
||||
client = Fail2BanClient(socket_path=socket_path, timeout=_SOCKET_TIMEOUT)
|
||||
|
||||
def _to_dict_inner(pairs: Any) -> dict[str, Any]:
|
||||
def _to_dict_inner(pairs: object) -> dict[str, object]:
|
||||
if not isinstance(pairs, (list, tuple)):
|
||||
return {}
|
||||
result: dict[str, Any] = {}
|
||||
result: dict[str, object] = {}
|
||||
for item in pairs:
|
||||
try:
|
||||
k, v = item
|
||||
@@ -542,8 +566,8 @@ async def _get_active_jail_names(socket_path: str) -> set[str]:
|
||||
pass
|
||||
return result
|
||||
|
||||
def _ok(response: Any) -> Any:
|
||||
code, data = response
|
||||
def _ok(response: object) -> object:
|
||||
code, data = cast("Fail2BanResponse", response)
|
||||
if code != 0:
|
||||
raise ValueError(f"fail2ban error {code}: {data!r}")
|
||||
return data
|
||||
@@ -558,9 +582,7 @@ async def _get_active_jail_names(socket_path: str) -> set[str]:
|
||||
log.warning("fail2ban_unreachable_during_inactive_list")
|
||||
return set()
|
||||
except Exception as exc: # noqa: BLE001
|
||||
log.warning(
|
||||
"fail2ban_status_error_during_inactive_list", error=str(exc)
|
||||
)
|
||||
log.warning("fail2ban_status_error_during_inactive_list", error=str(exc))
|
||||
return set()
|
||||
|
||||
|
||||
@@ -648,10 +670,7 @@ def _validate_jail_config_sync(
|
||||
issues.append(
|
||||
JailValidationIssue(
|
||||
field="filter",
|
||||
message=(
|
||||
f"Filter file not found: filter.d/{base_filter}.conf"
|
||||
" (or .local)"
|
||||
),
|
||||
message=(f"Filter file not found: filter.d/{base_filter}.conf (or .local)"),
|
||||
)
|
||||
)
|
||||
|
||||
@@ -667,10 +686,7 @@ def _validate_jail_config_sync(
|
||||
issues.append(
|
||||
JailValidationIssue(
|
||||
field="action",
|
||||
message=(
|
||||
f"Action file not found: action.d/{action_name}.conf"
|
||||
" (or .local)"
|
||||
),
|
||||
message=(f"Action file not found: action.d/{action_name}.conf (or .local)"),
|
||||
)
|
||||
)
|
||||
|
||||
@@ -740,7 +756,7 @@ async def _probe_fail2ban_running(socket_path: str) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
async def _wait_for_fail2ban(
|
||||
async def wait_for_fail2ban(
|
||||
socket_path: str,
|
||||
max_wait_seconds: float = 10.0,
|
||||
poll_interval: float = 2.0,
|
||||
@@ -764,7 +780,7 @@ async def _wait_for_fail2ban(
|
||||
return False
|
||||
|
||||
|
||||
async def _start_daemon(start_cmd_parts: list[str]) -> bool:
|
||||
async def start_daemon(start_cmd_parts: list[str]) -> bool:
|
||||
"""Start the fail2ban daemon using *start_cmd_parts*.
|
||||
|
||||
Uses :func:`asyncio.create_subprocess_exec` (no shell interpretation)
|
||||
@@ -804,7 +820,7 @@ def _write_local_override_sync(
|
||||
config_dir: Path,
|
||||
jail_name: str,
|
||||
enabled: bool,
|
||||
overrides: dict[str, Any],
|
||||
overrides: dict[str, object],
|
||||
) -> None:
|
||||
"""Write a ``jail.d/{name}.local`` file atomically.
|
||||
|
||||
@@ -826,9 +842,7 @@ def _write_local_override_sync(
|
||||
try:
|
||||
jail_d.mkdir(parents=True, exist_ok=True)
|
||||
except OSError as exc:
|
||||
raise ConfigWriteError(
|
||||
f"Cannot create jail.d directory: {exc}"
|
||||
) from exc
|
||||
raise ConfigWriteError(f"Cannot create jail.d directory: {exc}") from exc
|
||||
|
||||
local_path = jail_d / f"{jail_name}.local"
|
||||
|
||||
@@ -853,7 +867,7 @@ def _write_local_override_sync(
|
||||
if overrides.get("port") is not None:
|
||||
lines.append(f"port = {overrides['port']}")
|
||||
if overrides.get("logpath"):
|
||||
paths: list[str] = overrides["logpath"]
|
||||
paths: list[str] = cast("list[str]", overrides["logpath"])
|
||||
if paths:
|
||||
lines.append(f"logpath = {paths[0]}")
|
||||
for p in paths[1:]:
|
||||
@@ -876,9 +890,7 @@ def _write_local_override_sync(
|
||||
# Clean up temp file if rename failed.
|
||||
with contextlib.suppress(OSError):
|
||||
os.unlink(tmp_name) # noqa: F821 — only reachable when tmp_name is set
|
||||
raise ConfigWriteError(
|
||||
f"Failed to write {local_path}: {exc}"
|
||||
) from exc
|
||||
raise ConfigWriteError(f"Failed to write {local_path}: {exc}") from exc
|
||||
|
||||
log.info(
|
||||
"jail_local_written",
|
||||
@@ -907,9 +919,7 @@ def _restore_local_file_sync(local_path: Path, original_content: bytes | None) -
|
||||
try:
|
||||
local_path.unlink(missing_ok=True)
|
||||
except OSError as exc:
|
||||
raise ConfigWriteError(
|
||||
f"Failed to delete {local_path} during rollback: {exc}"
|
||||
) from exc
|
||||
raise ConfigWriteError(f"Failed to delete {local_path} during rollback: {exc}") from exc
|
||||
return
|
||||
|
||||
tmp_name: str | None = None
|
||||
@@ -927,9 +937,7 @@ def _restore_local_file_sync(local_path: Path, original_content: bytes | None) -
|
||||
with contextlib.suppress(OSError):
|
||||
if tmp_name is not None:
|
||||
os.unlink(tmp_name)
|
||||
raise ConfigWriteError(
|
||||
f"Failed to restore {local_path} during rollback: {exc}"
|
||||
) from exc
|
||||
raise ConfigWriteError(f"Failed to restore {local_path} during rollback: {exc}") from exc
|
||||
|
||||
|
||||
def _validate_regex_patterns(patterns: list[str]) -> None:
|
||||
@@ -965,9 +973,7 @@ def _write_filter_local_sync(filter_d: Path, name: str, content: str) -> None:
|
||||
try:
|
||||
filter_d.mkdir(parents=True, exist_ok=True)
|
||||
except OSError as exc:
|
||||
raise ConfigWriteError(
|
||||
f"Cannot create filter.d directory: {exc}"
|
||||
) from exc
|
||||
raise ConfigWriteError(f"Cannot create filter.d directory: {exc}") from exc
|
||||
|
||||
local_path = filter_d / f"{name}.local"
|
||||
try:
|
||||
@@ -984,9 +990,7 @@ def _write_filter_local_sync(filter_d: Path, name: str, content: str) -> None:
|
||||
except OSError as exc:
|
||||
with contextlib.suppress(OSError):
|
||||
os.unlink(tmp_name) # noqa: F821
|
||||
raise ConfigWriteError(
|
||||
f"Failed to write {local_path}: {exc}"
|
||||
) from exc
|
||||
raise ConfigWriteError(f"Failed to write {local_path}: {exc}") from exc
|
||||
|
||||
log.info("filter_local_written", filter=name, path=str(local_path))
|
||||
|
||||
@@ -1017,9 +1021,7 @@ def _set_jail_local_key_sync(
|
||||
try:
|
||||
jail_d.mkdir(parents=True, exist_ok=True)
|
||||
except OSError as exc:
|
||||
raise ConfigWriteError(
|
||||
f"Cannot create jail.d directory: {exc}"
|
||||
) from exc
|
||||
raise ConfigWriteError(f"Cannot create jail.d directory: {exc}") from exc
|
||||
|
||||
local_path = jail_d / f"{jail_name}.local"
|
||||
|
||||
@@ -1058,9 +1060,7 @@ def _set_jail_local_key_sync(
|
||||
except OSError as exc:
|
||||
with contextlib.suppress(OSError):
|
||||
os.unlink(tmp_name) # noqa: F821
|
||||
raise ConfigWriteError(
|
||||
f"Failed to write {local_path}: {exc}"
|
||||
) from exc
|
||||
raise ConfigWriteError(f"Failed to write {local_path}: {exc}") from exc
|
||||
|
||||
log.info(
|
||||
"jail_local_key_set",
|
||||
@@ -1098,8 +1098,8 @@ async def list_inactive_jails(
|
||||
inactive jails.
|
||||
"""
|
||||
loop = asyncio.get_event_loop()
|
||||
parsed_result: tuple[dict[str, dict[str, str]], dict[str, str]] = (
|
||||
await loop.run_in_executor(None, _parse_jails_sync, Path(config_dir))
|
||||
parsed_result: tuple[dict[str, dict[str, str]], dict[str, str]] = await loop.run_in_executor(
|
||||
None, _parse_jails_sync, Path(config_dir)
|
||||
)
|
||||
all_jails, source_files = parsed_result
|
||||
active_names: set[str] = await _get_active_jail_names(socket_path)
|
||||
@@ -1111,7 +1111,7 @@ async def list_inactive_jails(
|
||||
continue
|
||||
|
||||
source = source_files.get(jail_name, config_dir)
|
||||
inactive.append(_build_inactive_jail(jail_name, settings, source))
|
||||
inactive.append(_build_inactive_jail(jail_name, settings, source, Path(config_dir)))
|
||||
|
||||
log.info(
|
||||
"inactive_jails_listed",
|
||||
@@ -1156,9 +1156,7 @@ async def activate_jail(
|
||||
_safe_jail_name(name)
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
all_jails, _source_files = await loop.run_in_executor(
|
||||
None, _parse_jails_sync, Path(config_dir)
|
||||
)
|
||||
all_jails, _source_files = await loop.run_in_executor(None, _parse_jails_sync, Path(config_dir))
|
||||
|
||||
if name not in all_jails:
|
||||
raise JailNotFoundInConfigError(name)
|
||||
@@ -1194,13 +1192,10 @@ async def activate_jail(
|
||||
active=False,
|
||||
fail2ban_running=True,
|
||||
validation_warnings=warnings,
|
||||
message=(
|
||||
f"Jail {name!r} cannot be activated: "
|
||||
+ "; ".join(i.message for i in blocking)
|
||||
),
|
||||
message=(f"Jail {name!r} cannot be activated: " + "; ".join(i.message for i in blocking)),
|
||||
)
|
||||
|
||||
overrides: dict[str, Any] = {
|
||||
overrides: dict[str, object] = {
|
||||
"bantime": req.bantime,
|
||||
"findtime": req.findtime,
|
||||
"maxretry": req.maxretry,
|
||||
@@ -1231,7 +1226,7 @@ async def activate_jail(
|
||||
# Activation reload — if it fails, roll back immediately #
|
||||
# ---------------------------------------------------------------------- #
|
||||
try:
|
||||
await jail_service.reload_all(socket_path, include_jails=[name])
|
||||
await _reload_all(socket_path, include_jails=[name])
|
||||
except JailNotFoundError as exc:
|
||||
# Jail configuration is invalid (e.g. missing logpath that prevents
|
||||
# fail2ban from loading the jail). Roll back and provide a specific error.
|
||||
@@ -1240,9 +1235,7 @@ async def activate_jail(
|
||||
jail=name,
|
||||
error=str(exc),
|
||||
)
|
||||
recovered = await _rollback_activation_async(
|
||||
config_dir, name, socket_path, original_content
|
||||
)
|
||||
recovered = await _rollback_activation_async(config_dir, name, socket_path, original_content)
|
||||
return JailActivationResponse(
|
||||
name=name,
|
||||
active=False,
|
||||
@@ -1258,9 +1251,7 @@ async def activate_jail(
|
||||
)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
log.warning("reload_after_activate_failed", jail=name, error=str(exc))
|
||||
recovered = await _rollback_activation_async(
|
||||
config_dir, name, socket_path, original_content
|
||||
)
|
||||
recovered = await _rollback_activation_async(config_dir, name, socket_path, original_content)
|
||||
return JailActivationResponse(
|
||||
name=name,
|
||||
active=False,
|
||||
@@ -1291,9 +1282,7 @@ async def activate_jail(
|
||||
jail=name,
|
||||
message="fail2ban socket unreachable after reload — initiating rollback.",
|
||||
)
|
||||
recovered = await _rollback_activation_async(
|
||||
config_dir, name, socket_path, original_content
|
||||
)
|
||||
recovered = await _rollback_activation_async(config_dir, name, socket_path, original_content)
|
||||
return JailActivationResponse(
|
||||
name=name,
|
||||
active=False,
|
||||
@@ -1316,9 +1305,7 @@ async def activate_jail(
|
||||
jail=name,
|
||||
message="Jail did not appear in running jails — initiating rollback.",
|
||||
)
|
||||
recovered = await _rollback_activation_async(
|
||||
config_dir, name, socket_path, original_content
|
||||
)
|
||||
recovered = await _rollback_activation_async(config_dir, name, socket_path, original_content)
|
||||
return JailActivationResponse(
|
||||
name=name,
|
||||
active=False,
|
||||
@@ -1374,24 +1361,18 @@ async def _rollback_activation_async(
|
||||
|
||||
# Step 1 — restore original file (or delete it).
|
||||
try:
|
||||
await loop.run_in_executor(
|
||||
None, _restore_local_file_sync, local_path, original_content
|
||||
)
|
||||
await loop.run_in_executor(None, _restore_local_file_sync, local_path, original_content)
|
||||
log.info("jail_activation_rollback_file_restored", jail=name)
|
||||
except ConfigWriteError as exc:
|
||||
log.error(
|
||||
"jail_activation_rollback_restore_failed", jail=name, error=str(exc)
|
||||
)
|
||||
log.error("jail_activation_rollback_restore_failed", jail=name, error=str(exc))
|
||||
return False
|
||||
|
||||
# Step 2 — reload fail2ban with the restored config.
|
||||
try:
|
||||
await jail_service.reload_all(socket_path)
|
||||
await _reload_all(socket_path)
|
||||
log.info("jail_activation_rollback_reload_ok", jail=name)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
log.warning(
|
||||
"jail_activation_rollback_reload_failed", jail=name, error=str(exc)
|
||||
)
|
||||
log.warning("jail_activation_rollback_reload_failed", jail=name, error=str(exc))
|
||||
return False
|
||||
|
||||
# Step 3 — wait for fail2ban to come back.
|
||||
@@ -1436,9 +1417,7 @@ async def deactivate_jail(
|
||||
_safe_jail_name(name)
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
all_jails, _source_files = await loop.run_in_executor(
|
||||
None, _parse_jails_sync, Path(config_dir)
|
||||
)
|
||||
all_jails, _source_files = await loop.run_in_executor(None, _parse_jails_sync, Path(config_dir))
|
||||
|
||||
if name not in all_jails:
|
||||
raise JailNotFoundInConfigError(name)
|
||||
@@ -1457,7 +1436,7 @@ async def deactivate_jail(
|
||||
)
|
||||
|
||||
try:
|
||||
await jail_service.reload_all(socket_path, exclude_jails=[name])
|
||||
await _reload_all(socket_path, exclude_jails=[name])
|
||||
except Exception as exc: # noqa: BLE001
|
||||
log.warning("reload_after_deactivate_failed", jail=name, error=str(exc))
|
||||
|
||||
@@ -1469,6 +1448,51 @@ async def deactivate_jail(
|
||||
)
|
||||
|
||||
|
||||
async def delete_jail_local_override(
|
||||
config_dir: str,
|
||||
socket_path: str,
|
||||
name: str,
|
||||
) -> None:
|
||||
"""Delete the ``jail.d/{name}.local`` override file for an inactive jail.
|
||||
|
||||
This is the clean-up action shown in the config UI when an inactive jail
|
||||
still has a ``.local`` override file (e.g. ``enabled = false``). The
|
||||
file is deleted outright; no fail2ban reload is required because the jail
|
||||
is already inactive.
|
||||
|
||||
Args:
|
||||
config_dir: Absolute path to the fail2ban configuration directory.
|
||||
socket_path: Path to the fail2ban Unix domain socket.
|
||||
name: Name of the jail whose ``.local`` file should be removed.
|
||||
|
||||
Raises:
|
||||
JailNameError: If *name* contains invalid characters.
|
||||
JailNotFoundInConfigError: If *name* is not defined in any config file.
|
||||
JailAlreadyActiveError: If the jail is currently active (refusing to
|
||||
delete the live config file).
|
||||
ConfigWriteError: If the file cannot be deleted.
|
||||
"""
|
||||
_safe_jail_name(name)
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
all_jails, _source_files = await loop.run_in_executor(None, _parse_jails_sync, Path(config_dir))
|
||||
|
||||
if name not in all_jails:
|
||||
raise JailNotFoundInConfigError(name)
|
||||
|
||||
active_names = await _get_active_jail_names(socket_path)
|
||||
if name in active_names:
|
||||
raise JailAlreadyActiveError(name)
|
||||
|
||||
local_path = Path(config_dir) / "jail.d" / f"{name}.local"
|
||||
try:
|
||||
await loop.run_in_executor(None, lambda: local_path.unlink(missing_ok=True))
|
||||
except OSError as exc:
|
||||
raise ConfigWriteError(f"Failed to delete {local_path}: {exc}") from exc
|
||||
|
||||
log.info("jail_local_override_deleted", jail=name, path=str(local_path))
|
||||
|
||||
|
||||
async def validate_jail_config(
|
||||
config_dir: str,
|
||||
name: str,
|
||||
@@ -1541,13 +1565,11 @@ async def rollback_jail(
|
||||
log.info("jail_rolled_back_disabled", jail=name)
|
||||
|
||||
# Attempt to start the daemon.
|
||||
started = await _start_daemon(start_cmd_parts)
|
||||
started = await start_daemon(start_cmd_parts)
|
||||
log.info("jail_rollback_start_attempted", jail=name, start_ok=started)
|
||||
|
||||
# Wait for the socket to come back.
|
||||
fail2ban_running = await _wait_for_fail2ban(
|
||||
socket_path, max_wait_seconds=10.0, poll_interval=2.0
|
||||
)
|
||||
fail2ban_running = await wait_for_fail2ban(socket_path, max_wait_seconds=10.0, poll_interval=2.0)
|
||||
|
||||
active_jails = 0
|
||||
if fail2ban_running:
|
||||
@@ -1561,10 +1583,7 @@ async def rollback_jail(
|
||||
disabled=True,
|
||||
fail2ban_running=True,
|
||||
active_jails=active_jails,
|
||||
message=(
|
||||
f"Jail {name!r} disabled and fail2ban restarted successfully "
|
||||
f"with {active_jails} active jail(s)."
|
||||
),
|
||||
message=(f"Jail {name!r} disabled and fail2ban restarted successfully with {active_jails} active jail(s)."),
|
||||
)
|
||||
|
||||
log.warning("jail_rollback_fail2ban_still_down", jail=name)
|
||||
@@ -1585,9 +1604,7 @@ async def rollback_jail(
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Allowlist pattern for filter names used in path construction.
|
||||
_SAFE_FILTER_NAME_RE: re.Pattern[str] = re.compile(
|
||||
r"^[A-Za-z0-9][A-Za-z0-9._-]{0,127}$"
|
||||
)
|
||||
_SAFE_FILTER_NAME_RE: re.Pattern[str] = re.compile(r"^[A-Za-z0-9][A-Za-z0-9._-]{0,127}$")
|
||||
|
||||
|
||||
class FilterNotFoundError(Exception):
|
||||
@@ -1699,9 +1716,7 @@ def _parse_filters_sync(
|
||||
try:
|
||||
content = conf_path.read_text(encoding="utf-8")
|
||||
except OSError as exc:
|
||||
log.warning(
|
||||
"filter_read_error", name=name, path=str(conf_path), error=str(exc)
|
||||
)
|
||||
log.warning("filter_read_error", name=name, path=str(conf_path), error=str(exc))
|
||||
continue
|
||||
|
||||
if has_local:
|
||||
@@ -1777,9 +1792,7 @@ async def list_filters(
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
# Run the synchronous scan in a thread-pool executor.
|
||||
raw_filters: list[tuple[str, str, str, bool, str]] = await loop.run_in_executor(
|
||||
None, _parse_filters_sync, filter_d
|
||||
)
|
||||
raw_filters: list[tuple[str, str, str, bool, str]] = await loop.run_in_executor(None, _parse_filters_sync, filter_d)
|
||||
|
||||
# Fetch active jail names and their configs concurrently.
|
||||
all_jails_result, active_names = await asyncio.gather(
|
||||
@@ -1792,9 +1805,7 @@ async def list_filters(
|
||||
|
||||
filters: list[FilterConfig] = []
|
||||
for name, filename, content, has_local, source_path in raw_filters:
|
||||
cfg = conffile_parser.parse_filter_file(
|
||||
content, name=name, filename=filename
|
||||
)
|
||||
cfg = conffile_parser.parse_filter_file(content, name=name, filename=filename)
|
||||
used_by = sorted(filter_to_jails.get(name, []))
|
||||
filters.append(
|
||||
FilterConfig(
|
||||
@@ -1882,9 +1893,7 @@ async def get_filter(
|
||||
|
||||
content, has_local, source_path = await loop.run_in_executor(None, _read)
|
||||
|
||||
cfg = conffile_parser.parse_filter_file(
|
||||
content, name=base_name, filename=f"{base_name}.conf"
|
||||
)
|
||||
cfg = conffile_parser.parse_filter_file(content, name=base_name, filename=f"{base_name}.conf")
|
||||
|
||||
all_jails_result, active_names = await asyncio.gather(
|
||||
loop.run_in_executor(None, _parse_jails_sync, Path(config_dir)),
|
||||
@@ -1983,7 +1992,7 @@ async def update_filter(
|
||||
|
||||
if do_reload:
|
||||
try:
|
||||
await jail_service.reload_all(socket_path)
|
||||
await _reload_all(socket_path)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
log.warning(
|
||||
"reload_after_filter_update_failed",
|
||||
@@ -2058,7 +2067,7 @@ async def create_filter(
|
||||
|
||||
if do_reload:
|
||||
try:
|
||||
await jail_service.reload_all(socket_path)
|
||||
await _reload_all(socket_path)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
log.warning(
|
||||
"reload_after_filter_create_failed",
|
||||
@@ -2117,9 +2126,7 @@ async def delete_filter(
|
||||
try:
|
||||
local_path.unlink()
|
||||
except OSError as exc:
|
||||
raise ConfigWriteError(
|
||||
f"Failed to delete {local_path}: {exc}"
|
||||
) from exc
|
||||
raise ConfigWriteError(f"Failed to delete {local_path}: {exc}") from exc
|
||||
|
||||
log.info("filter_local_deleted", filter=base_name, path=str(local_path))
|
||||
|
||||
@@ -2161,9 +2168,7 @@ async def assign_filter_to_jail(
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
# Verify the jail exists in config.
|
||||
all_jails, _src = await loop.run_in_executor(
|
||||
None, _parse_jails_sync, Path(config_dir)
|
||||
)
|
||||
all_jails, _src = await loop.run_in_executor(None, _parse_jails_sync, Path(config_dir))
|
||||
if jail_name not in all_jails:
|
||||
raise JailNotFoundInConfigError(jail_name)
|
||||
|
||||
@@ -2189,7 +2194,7 @@ async def assign_filter_to_jail(
|
||||
|
||||
if do_reload:
|
||||
try:
|
||||
await jail_service.reload_all(socket_path)
|
||||
await _reload_all(socket_path)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
log.warning(
|
||||
"reload_after_assign_filter_failed",
|
||||
@@ -2211,9 +2216,7 @@ async def assign_filter_to_jail(
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Allowlist pattern for action names used in path construction.
|
||||
_SAFE_ACTION_NAME_RE: re.Pattern[str] = re.compile(
|
||||
r"^[A-Za-z0-9][A-Za-z0-9._-]{0,127}$"
|
||||
)
|
||||
_SAFE_ACTION_NAME_RE: re.Pattern[str] = re.compile(r"^[A-Za-z0-9][A-Za-z0-9._-]{0,127}$")
|
||||
|
||||
|
||||
class ActionNotFoundError(Exception):
|
||||
@@ -2253,8 +2256,7 @@ class ActionReadonlyError(Exception):
|
||||
"""
|
||||
self.name: str = name
|
||||
super().__init__(
|
||||
f"Action {name!r} is a shipped default (.conf only); "
|
||||
"only user-created .local files can be deleted."
|
||||
f"Action {name!r} is a shipped default (.conf only); only user-created .local files can be deleted."
|
||||
)
|
||||
|
||||
|
||||
@@ -2363,9 +2365,7 @@ def _parse_actions_sync(
|
||||
try:
|
||||
content = conf_path.read_text(encoding="utf-8")
|
||||
except OSError as exc:
|
||||
log.warning(
|
||||
"action_read_error", name=name, path=str(conf_path), error=str(exc)
|
||||
)
|
||||
log.warning("action_read_error", name=name, path=str(conf_path), error=str(exc))
|
||||
continue
|
||||
|
||||
if has_local:
|
||||
@@ -2430,9 +2430,7 @@ def _append_jail_action_sync(
|
||||
try:
|
||||
jail_d.mkdir(parents=True, exist_ok=True)
|
||||
except OSError as exc:
|
||||
raise ConfigWriteError(
|
||||
f"Cannot create jail.d directory: {exc}"
|
||||
) from exc
|
||||
raise ConfigWriteError(f"Cannot create jail.d directory: {exc}") from exc
|
||||
|
||||
local_path = jail_d / f"{jail_name}.local"
|
||||
|
||||
@@ -2452,9 +2450,7 @@ def _append_jail_action_sync(
|
||||
|
||||
existing_raw = parser.get(jail_name, "action") if parser.has_option(jail_name, "action") else ""
|
||||
existing_lines = [
|
||||
line.strip()
|
||||
for line in existing_raw.splitlines()
|
||||
if line.strip() and not line.strip().startswith("#")
|
||||
line.strip() for line in existing_raw.splitlines() if line.strip() and not line.strip().startswith("#")
|
||||
]
|
||||
|
||||
# Extract base names from existing entries for duplicate checking.
|
||||
@@ -2468,9 +2464,7 @@ def _append_jail_action_sync(
|
||||
|
||||
if existing_lines:
|
||||
# configparser multi-line: continuation lines start with whitespace.
|
||||
new_value = existing_lines[0] + "".join(
|
||||
f"\n {line}" for line in existing_lines[1:]
|
||||
)
|
||||
new_value = existing_lines[0] + "".join(f"\n {line}" for line in existing_lines[1:])
|
||||
parser.set(jail_name, "action", new_value)
|
||||
else:
|
||||
parser.set(jail_name, "action", action_entry)
|
||||
@@ -2494,9 +2488,7 @@ def _append_jail_action_sync(
|
||||
except OSError as exc:
|
||||
with contextlib.suppress(OSError):
|
||||
os.unlink(tmp_name) # noqa: F821
|
||||
raise ConfigWriteError(
|
||||
f"Failed to write {local_path}: {exc}"
|
||||
) from exc
|
||||
raise ConfigWriteError(f"Failed to write {local_path}: {exc}") from exc
|
||||
|
||||
log.info(
|
||||
"jail_action_appended",
|
||||
@@ -2547,9 +2539,7 @@ def _remove_jail_action_sync(
|
||||
|
||||
existing_raw = parser.get(jail_name, "action")
|
||||
existing_lines = [
|
||||
line.strip()
|
||||
for line in existing_raw.splitlines()
|
||||
if line.strip() and not line.strip().startswith("#")
|
||||
line.strip() for line in existing_raw.splitlines() if line.strip() and not line.strip().startswith("#")
|
||||
]
|
||||
|
||||
def _base(entry: str) -> str:
|
||||
@@ -2563,9 +2553,7 @@ def _remove_jail_action_sync(
|
||||
return
|
||||
|
||||
if filtered:
|
||||
new_value = filtered[0] + "".join(
|
||||
f"\n {line}" for line in filtered[1:]
|
||||
)
|
||||
new_value = filtered[0] + "".join(f"\n {line}" for line in filtered[1:])
|
||||
parser.set(jail_name, "action", new_value)
|
||||
else:
|
||||
parser.remove_option(jail_name, "action")
|
||||
@@ -2589,9 +2577,7 @@ def _remove_jail_action_sync(
|
||||
except OSError as exc:
|
||||
with contextlib.suppress(OSError):
|
||||
os.unlink(tmp_name) # noqa: F821
|
||||
raise ConfigWriteError(
|
||||
f"Failed to write {local_path}: {exc}"
|
||||
) from exc
|
||||
raise ConfigWriteError(f"Failed to write {local_path}: {exc}") from exc
|
||||
|
||||
log.info(
|
||||
"jail_action_removed",
|
||||
@@ -2618,9 +2604,7 @@ def _write_action_local_sync(action_d: Path, name: str, content: str) -> None:
|
||||
try:
|
||||
action_d.mkdir(parents=True, exist_ok=True)
|
||||
except OSError as exc:
|
||||
raise ConfigWriteError(
|
||||
f"Cannot create action.d directory: {exc}"
|
||||
) from exc
|
||||
raise ConfigWriteError(f"Cannot create action.d directory: {exc}") from exc
|
||||
|
||||
local_path = action_d / f"{name}.local"
|
||||
try:
|
||||
@@ -2637,9 +2621,7 @@ def _write_action_local_sync(action_d: Path, name: str, content: str) -> None:
|
||||
except OSError as exc:
|
||||
with contextlib.suppress(OSError):
|
||||
os.unlink(tmp_name) # noqa: F821
|
||||
raise ConfigWriteError(
|
||||
f"Failed to write {local_path}: {exc}"
|
||||
) from exc
|
||||
raise ConfigWriteError(f"Failed to write {local_path}: {exc}") from exc
|
||||
|
||||
log.info("action_local_written", action=name, path=str(local_path))
|
||||
|
||||
@@ -2675,9 +2657,7 @@ async def list_actions(
|
||||
action_d = Path(config_dir) / "action.d"
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
raw_actions: list[tuple[str, str, str, bool, str]] = await loop.run_in_executor(
|
||||
None, _parse_actions_sync, action_d
|
||||
)
|
||||
raw_actions: list[tuple[str, str, str, bool, str]] = await loop.run_in_executor(None, _parse_actions_sync, action_d)
|
||||
|
||||
all_jails_result, active_names = await asyncio.gather(
|
||||
loop.run_in_executor(None, _parse_jails_sync, Path(config_dir)),
|
||||
@@ -2689,9 +2669,7 @@ async def list_actions(
|
||||
|
||||
actions: list[ActionConfig] = []
|
||||
for name, filename, content, has_local, source_path in raw_actions:
|
||||
cfg = conffile_parser.parse_action_file(
|
||||
content, name=name, filename=filename
|
||||
)
|
||||
cfg = conffile_parser.parse_action_file(content, name=name, filename=filename)
|
||||
used_by = sorted(action_to_jails.get(name, []))
|
||||
actions.append(
|
||||
ActionConfig(
|
||||
@@ -2778,9 +2756,7 @@ async def get_action(
|
||||
|
||||
content, has_local, source_path = await loop.run_in_executor(None, _read)
|
||||
|
||||
cfg = conffile_parser.parse_action_file(
|
||||
content, name=base_name, filename=f"{base_name}.conf"
|
||||
)
|
||||
cfg = conffile_parser.parse_action_file(content, name=base_name, filename=f"{base_name}.conf")
|
||||
|
||||
all_jails_result, active_names = await asyncio.gather(
|
||||
loop.run_in_executor(None, _parse_jails_sync, Path(config_dir)),
|
||||
@@ -2870,7 +2846,7 @@ async def update_action(
|
||||
|
||||
if do_reload:
|
||||
try:
|
||||
await jail_service.reload_all(socket_path)
|
||||
await _reload_all(socket_path)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
log.warning(
|
||||
"reload_after_action_update_failed",
|
||||
@@ -2939,7 +2915,7 @@ async def create_action(
|
||||
|
||||
if do_reload:
|
||||
try:
|
||||
await jail_service.reload_all(socket_path)
|
||||
await _reload_all(socket_path)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
log.warning(
|
||||
"reload_after_action_create_failed",
|
||||
@@ -2996,9 +2972,7 @@ async def delete_action(
|
||||
try:
|
||||
local_path.unlink()
|
||||
except OSError as exc:
|
||||
raise ConfigWriteError(
|
||||
f"Failed to delete {local_path}: {exc}"
|
||||
) from exc
|
||||
raise ConfigWriteError(f"Failed to delete {local_path}: {exc}") from exc
|
||||
|
||||
log.info("action_local_deleted", action=base_name, path=str(local_path))
|
||||
|
||||
@@ -3040,9 +3014,7 @@ async def assign_action_to_jail(
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
all_jails, _src = await loop.run_in_executor(
|
||||
None, _parse_jails_sync, Path(config_dir)
|
||||
)
|
||||
all_jails, _src = await loop.run_in_executor(None, _parse_jails_sync, Path(config_dir))
|
||||
if jail_name not in all_jails:
|
||||
raise JailNotFoundInConfigError(jail_name)
|
||||
|
||||
@@ -3074,7 +3046,7 @@ async def assign_action_to_jail(
|
||||
|
||||
if do_reload:
|
||||
try:
|
||||
await jail_service.reload_all(socket_path)
|
||||
await _reload_all(socket_path)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
log.warning(
|
||||
"reload_after_assign_action_failed",
|
||||
@@ -3122,9 +3094,7 @@ async def remove_action_from_jail(
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
all_jails, _src = await loop.run_in_executor(
|
||||
None, _parse_jails_sync, Path(config_dir)
|
||||
)
|
||||
all_jails, _src = await loop.run_in_executor(None, _parse_jails_sync, Path(config_dir))
|
||||
if jail_name not in all_jails:
|
||||
raise JailNotFoundInConfigError(jail_name)
|
||||
|
||||
@@ -3138,7 +3108,7 @@ async def remove_action_from_jail(
|
||||
|
||||
if do_reload:
|
||||
try:
|
||||
await jail_service.reload_all(socket_path)
|
||||
await _reload_all(socket_path)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
log.warning(
|
||||
"reload_after_remove_action_failed",
|
||||
@@ -3153,4 +3123,3 @@ async def remove_action_from_jail(
|
||||
action=action_name,
|
||||
reload=do_reload,
|
||||
)
|
||||
|
||||
|
||||
@@ -16,13 +16,19 @@ import asyncio
|
||||
import contextlib
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from typing import TYPE_CHECKING, TypeVar, cast
|
||||
|
||||
import structlog
|
||||
|
||||
from app.utils.fail2ban_client import Fail2BanCommand, Fail2BanResponse, Fail2BanToken
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Awaitable, Callable
|
||||
|
||||
import aiosqlite
|
||||
|
||||
from app import __version__
|
||||
from app.exceptions import ConfigOperationError, ConfigValidationError, JailNotFoundError
|
||||
from app.models.config import (
|
||||
AddLogPathRequest,
|
||||
BantimeEscalation,
|
||||
@@ -33,7 +39,6 @@ from app.models.config import (
|
||||
JailConfigListResponse,
|
||||
JailConfigResponse,
|
||||
JailConfigUpdate,
|
||||
LogPreviewLine,
|
||||
LogPreviewRequest,
|
||||
LogPreviewResponse,
|
||||
MapColorThresholdsResponse,
|
||||
@@ -42,8 +47,15 @@ from app.models.config import (
|
||||
RegexTestResponse,
|
||||
ServiceStatusResponse,
|
||||
)
|
||||
from app.services import setup_service
|
||||
from app.utils.fail2ban_client import Fail2BanClient
|
||||
from app.utils.log_utils import preview_log as util_preview_log
|
||||
from app.utils.log_utils import test_regex as util_test_regex
|
||||
from app.utils.setup_utils import (
|
||||
get_map_color_thresholds as util_get_map_color_thresholds,
|
||||
)
|
||||
from app.utils.setup_utils import (
|
||||
set_map_color_thresholds as util_set_map_color_thresholds,
|
||||
)
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
|
||||
@@ -53,26 +65,7 @@ _SOCKET_TIMEOUT: float = 10.0
|
||||
# Custom exceptions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class JailNotFoundError(Exception):
|
||||
"""Raised when a requested jail name does not exist in fail2ban."""
|
||||
|
||||
def __init__(self, name: str) -> None:
|
||||
"""Initialise with the jail name that was not found.
|
||||
|
||||
Args:
|
||||
name: The jail name that could not be located.
|
||||
"""
|
||||
self.name: str = name
|
||||
super().__init__(f"Jail not found: {name!r}")
|
||||
|
||||
|
||||
class ConfigValidationError(Exception):
|
||||
"""Raised when a configuration value fails validation before writing."""
|
||||
|
||||
|
||||
class ConfigOperationError(Exception):
|
||||
"""Raised when a configuration write command fails."""
|
||||
# (exceptions are now defined in app.exceptions and imported above)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -80,7 +73,7 @@ class ConfigOperationError(Exception):
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _ok(response: Any) -> Any:
|
||||
def _ok(response: object) -> object:
|
||||
"""Extract payload from a fail2ban ``(return_code, data)`` response.
|
||||
|
||||
Args:
|
||||
@@ -93,7 +86,7 @@ def _ok(response: Any) -> Any:
|
||||
ValueError: If the return code indicates an error.
|
||||
"""
|
||||
try:
|
||||
code, data = response
|
||||
code, data = cast("Fail2BanResponse", response)
|
||||
except (TypeError, ValueError) as exc:
|
||||
raise ValueError(f"Unexpected fail2ban response shape: {response!r}") from exc
|
||||
if code != 0:
|
||||
@@ -101,11 +94,11 @@ def _ok(response: Any) -> Any:
|
||||
return data
|
||||
|
||||
|
||||
def _to_dict(pairs: Any) -> dict[str, Any]:
|
||||
def _to_dict(pairs: object) -> dict[str, object]:
|
||||
"""Convert a list of ``(key, value)`` pairs to a plain dict."""
|
||||
if not isinstance(pairs, (list, tuple)):
|
||||
return {}
|
||||
result: dict[str, Any] = {}
|
||||
result: dict[str, object] = {}
|
||||
for item in pairs:
|
||||
try:
|
||||
k, v = item
|
||||
@@ -115,7 +108,7 @@ def _to_dict(pairs: Any) -> dict[str, Any]:
|
||||
return result
|
||||
|
||||
|
||||
def _ensure_list(value: Any) -> list[str]:
|
||||
def _ensure_list(value: object | None) -> list[str]:
|
||||
"""Coerce a fail2ban ``get`` result to a list of strings."""
|
||||
if value is None:
|
||||
return []
|
||||
@@ -126,11 +119,14 @@ def _ensure_list(value: Any) -> list[str]:
|
||||
return [str(value)]
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
async def _safe_get(
|
||||
client: Fail2BanClient,
|
||||
command: list[Any],
|
||||
default: Any = None,
|
||||
) -> Any:
|
||||
command: Fail2BanCommand,
|
||||
default: object | None = None,
|
||||
) -> object | None:
|
||||
"""Send a command and return *default* if it fails."""
|
||||
try:
|
||||
return _ok(await client.send(command))
|
||||
@@ -138,6 +134,15 @@ async def _safe_get(
|
||||
return default
|
||||
|
||||
|
||||
async def _safe_get_typed[T](
|
||||
client: Fail2BanClient,
|
||||
command: Fail2BanCommand,
|
||||
default: T,
|
||||
) -> T:
|
||||
"""Send a command and return the result typed as ``default``'s type."""
|
||||
return cast("T", await _safe_get(client, command, default))
|
||||
|
||||
|
||||
def _is_not_found_error(exc: Exception) -> bool:
|
||||
"""Return ``True`` if *exc* signals an unknown jail."""
|
||||
msg = str(exc).lower()
|
||||
@@ -192,47 +197,25 @@ async def get_jail_config(socket_path: str, name: str) -> JailConfigResponse:
|
||||
raise JailNotFoundError(name) from exc
|
||||
raise
|
||||
|
||||
(
|
||||
bantime_raw,
|
||||
findtime_raw,
|
||||
maxretry_raw,
|
||||
failregex_raw,
|
||||
ignoreregex_raw,
|
||||
logpath_raw,
|
||||
datepattern_raw,
|
||||
logencoding_raw,
|
||||
backend_raw,
|
||||
usedns_raw,
|
||||
prefregex_raw,
|
||||
actions_raw,
|
||||
bt_increment_raw,
|
||||
bt_factor_raw,
|
||||
bt_formula_raw,
|
||||
bt_multipliers_raw,
|
||||
bt_maxtime_raw,
|
||||
bt_rndtime_raw,
|
||||
bt_overalljails_raw,
|
||||
) = await asyncio.gather(
|
||||
_safe_get(client, ["get", name, "bantime"], 600),
|
||||
_safe_get(client, ["get", name, "findtime"], 600),
|
||||
_safe_get(client, ["get", name, "maxretry"], 5),
|
||||
_safe_get(client, ["get", name, "failregex"], []),
|
||||
_safe_get(client, ["get", name, "ignoreregex"], []),
|
||||
_safe_get(client, ["get", name, "logpath"], []),
|
||||
_safe_get(client, ["get", name, "datepattern"], None),
|
||||
_safe_get(client, ["get", name, "logencoding"], "UTF-8"),
|
||||
_safe_get(client, ["get", name, "backend"], "polling"),
|
||||
_safe_get(client, ["get", name, "usedns"], "warn"),
|
||||
_safe_get(client, ["get", name, "prefregex"], ""),
|
||||
_safe_get(client, ["get", name, "actions"], []),
|
||||
_safe_get(client, ["get", name, "bantime.increment"], False),
|
||||
_safe_get(client, ["get", name, "bantime.factor"], None),
|
||||
_safe_get(client, ["get", name, "bantime.formula"], None),
|
||||
_safe_get(client, ["get", name, "bantime.multipliers"], None),
|
||||
_safe_get(client, ["get", name, "bantime.maxtime"], None),
|
||||
_safe_get(client, ["get", name, "bantime.rndtime"], None),
|
||||
_safe_get(client, ["get", name, "bantime.overalljails"], False),
|
||||
)
|
||||
bantime_raw: int = await _safe_get_typed(client, ["get", name, "bantime"], 600)
|
||||
findtime_raw: int = await _safe_get_typed(client, ["get", name, "findtime"], 600)
|
||||
maxretry_raw: int = await _safe_get_typed(client, ["get", name, "maxretry"], 5)
|
||||
failregex_raw: list[object] = await _safe_get_typed(client, ["get", name, "failregex"], [])
|
||||
ignoreregex_raw: list[object] = await _safe_get_typed(client, ["get", name, "ignoreregex"], [])
|
||||
logpath_raw: list[object] = await _safe_get_typed(client, ["get", name, "logpath"], [])
|
||||
datepattern_raw: str | None = await _safe_get_typed(client, ["get", name, "datepattern"], None)
|
||||
logencoding_raw: str = await _safe_get_typed(client, ["get", name, "logencoding"], "UTF-8")
|
||||
backend_raw: str = await _safe_get_typed(client, ["get", name, "backend"], "polling")
|
||||
usedns_raw: str = await _safe_get_typed(client, ["get", name, "usedns"], "warn")
|
||||
prefregex_raw: str = await _safe_get_typed(client, ["get", name, "prefregex"], "")
|
||||
actions_raw: list[object] = await _safe_get_typed(client, ["get", name, "actions"], [])
|
||||
bt_increment_raw: bool = await _safe_get_typed(client, ["get", name, "bantime.increment"], False)
|
||||
bt_factor_raw: str | float | None = await _safe_get_typed(client, ["get", name, "bantime.factor"], None)
|
||||
bt_formula_raw: str | None = await _safe_get_typed(client, ["get", name, "bantime.formula"], None)
|
||||
bt_multipliers_raw: str | None = await _safe_get_typed(client, ["get", name, "bantime.multipliers"], None)
|
||||
bt_maxtime_raw: str | int | None = await _safe_get_typed(client, ["get", name, "bantime.maxtime"], None)
|
||||
bt_rndtime_raw: str | int | None = await _safe_get_typed(client, ["get", name, "bantime.rndtime"], None)
|
||||
bt_overalljails_raw: bool = await _safe_get_typed(client, ["get", name, "bantime.overalljails"], False)
|
||||
|
||||
bantime_escalation = BantimeEscalation(
|
||||
increment=bool(bt_increment_raw),
|
||||
@@ -352,7 +335,7 @@ async def update_jail_config(
|
||||
raise JailNotFoundError(name) from exc
|
||||
raise
|
||||
|
||||
async def _set(key: str, value: Any) -> None:
|
||||
async def _set(key: str, value: Fail2BanToken) -> None:
|
||||
try:
|
||||
_ok(await client.send(["set", name, key, value]))
|
||||
except ValueError as exc:
|
||||
@@ -422,7 +405,7 @@ async def _replace_regex_list(
|
||||
new_patterns: Replacement list (may be empty to clear).
|
||||
"""
|
||||
# Determine current count.
|
||||
current_raw = await _safe_get(client, ["get", jail, field], [])
|
||||
current_raw: list[object] = await _safe_get_typed(client, ["get", jail, field], [])
|
||||
current: list[str] = _ensure_list(current_raw)
|
||||
|
||||
del_cmd = f"del{field}"
|
||||
@@ -469,10 +452,10 @@ async def get_global_config(socket_path: str) -> GlobalConfigResponse:
|
||||
db_purge_age_raw,
|
||||
db_max_matches_raw,
|
||||
) = await asyncio.gather(
|
||||
_safe_get(client, ["get", "loglevel"], "INFO"),
|
||||
_safe_get(client, ["get", "logtarget"], "STDOUT"),
|
||||
_safe_get(client, ["get", "dbpurgeage"], 86400),
|
||||
_safe_get(client, ["get", "dbmaxmatches"], 10),
|
||||
_safe_get_typed(client, ["get", "loglevel"], "INFO"),
|
||||
_safe_get_typed(client, ["get", "logtarget"], "STDOUT"),
|
||||
_safe_get_typed(client, ["get", "dbpurgeage"], 86400),
|
||||
_safe_get_typed(client, ["get", "dbmaxmatches"], 10),
|
||||
)
|
||||
|
||||
return GlobalConfigResponse(
|
||||
@@ -496,7 +479,7 @@ async def update_global_config(socket_path: str, update: GlobalConfigUpdate) ->
|
||||
"""
|
||||
client = Fail2BanClient(socket_path=socket_path, timeout=_SOCKET_TIMEOUT)
|
||||
|
||||
async def _set_global(key: str, value: Any) -> None:
|
||||
async def _set_global(key: str, value: Fail2BanToken) -> None:
|
||||
try:
|
||||
_ok(await client.send(["set", key, value]))
|
||||
except ValueError as exc:
|
||||
@@ -520,27 +503,8 @@ async def update_global_config(socket_path: str, update: GlobalConfigUpdate) ->
|
||||
|
||||
|
||||
def test_regex(request: RegexTestRequest) -> RegexTestResponse:
|
||||
"""Test a regex pattern against a sample log line.
|
||||
|
||||
This is a pure in-process operation — no socket communication occurs.
|
||||
|
||||
Args:
|
||||
request: The :class:`~app.models.config.RegexTestRequest` payload.
|
||||
|
||||
Returns:
|
||||
:class:`~app.models.config.RegexTestResponse` with match result.
|
||||
"""
|
||||
try:
|
||||
compiled = re.compile(request.fail_regex)
|
||||
except re.error as exc:
|
||||
return RegexTestResponse(matched=False, groups=[], error=str(exc))
|
||||
|
||||
match = compiled.search(request.log_line)
|
||||
if match is None:
|
||||
return RegexTestResponse(matched=False)
|
||||
|
||||
groups: list[str] = list(match.groups() or [])
|
||||
return RegexTestResponse(matched=True, groups=[str(g) for g in groups if g is not None])
|
||||
"""Proxy to log utilities for regex test without service imports."""
|
||||
return util_test_regex(request)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -618,101 +582,14 @@ async def delete_log_path(
|
||||
raise ConfigOperationError(f"Failed to delete log path {log_path!r}: {exc}") from exc
|
||||
|
||||
|
||||
async def preview_log(req: LogPreviewRequest) -> LogPreviewResponse:
|
||||
"""Read the last *num_lines* of a log file and test *fail_regex* against each.
|
||||
|
||||
This operation reads from the local filesystem — no socket is used.
|
||||
|
||||
Args:
|
||||
req: :class:`~app.models.config.LogPreviewRequest`.
|
||||
|
||||
Returns:
|
||||
:class:`~app.models.config.LogPreviewResponse` with line-by-line results.
|
||||
"""
|
||||
# Validate the regex first.
|
||||
try:
|
||||
compiled = re.compile(req.fail_regex)
|
||||
except re.error as exc:
|
||||
return LogPreviewResponse(
|
||||
lines=[],
|
||||
total_lines=0,
|
||||
matched_count=0,
|
||||
regex_error=str(exc),
|
||||
)
|
||||
|
||||
path = Path(req.log_path)
|
||||
if not path.is_file():
|
||||
return LogPreviewResponse(
|
||||
lines=[],
|
||||
total_lines=0,
|
||||
matched_count=0,
|
||||
regex_error=f"File not found: {req.log_path!r}",
|
||||
)
|
||||
|
||||
# Read the last num_lines lines efficiently.
|
||||
try:
|
||||
raw_lines = await asyncio.get_event_loop().run_in_executor(
|
||||
None,
|
||||
_read_tail_lines,
|
||||
str(path),
|
||||
req.num_lines,
|
||||
)
|
||||
except OSError as exc:
|
||||
return LogPreviewResponse(
|
||||
lines=[],
|
||||
total_lines=0,
|
||||
matched_count=0,
|
||||
regex_error=f"Cannot read file: {exc}",
|
||||
)
|
||||
|
||||
result_lines: list[LogPreviewLine] = []
|
||||
matched_count = 0
|
||||
for line in raw_lines:
|
||||
m = compiled.search(line)
|
||||
groups = [str(g) for g in (m.groups() or []) if g is not None] if m else []
|
||||
result_lines.append(LogPreviewLine(line=line, matched=(m is not None), groups=groups))
|
||||
if m:
|
||||
matched_count += 1
|
||||
|
||||
return LogPreviewResponse(
|
||||
lines=result_lines,
|
||||
total_lines=len(result_lines),
|
||||
matched_count=matched_count,
|
||||
)
|
||||
|
||||
|
||||
def _read_tail_lines(file_path: str, num_lines: int) -> list[str]:
|
||||
"""Read the last *num_lines* from *file_path* synchronously.
|
||||
|
||||
Uses a memory-efficient approach that seeks from the end of the file.
|
||||
|
||||
Args:
|
||||
file_path: Absolute path to the log file.
|
||||
num_lines: Number of lines to return.
|
||||
|
||||
Returns:
|
||||
A list of stripped line strings.
|
||||
"""
|
||||
chunk_size = 8192
|
||||
raw_lines: list[bytes] = []
|
||||
with open(file_path, "rb") as fh:
|
||||
fh.seek(0, 2) # seek to end
|
||||
end_pos = fh.tell()
|
||||
if end_pos == 0:
|
||||
return []
|
||||
buf = b""
|
||||
pos = end_pos
|
||||
while len(raw_lines) <= num_lines and pos > 0:
|
||||
read_size = min(chunk_size, pos)
|
||||
pos -= read_size
|
||||
fh.seek(pos)
|
||||
chunk = fh.read(read_size)
|
||||
buf = chunk + buf
|
||||
raw_lines = buf.split(b"\n")
|
||||
# Strip incomplete leading line unless we've read the whole file.
|
||||
if pos > 0 and len(raw_lines) > 1:
|
||||
raw_lines = raw_lines[1:]
|
||||
return [ln.decode("utf-8", errors="replace").rstrip() for ln in raw_lines[-num_lines:] if ln.strip()]
|
||||
async def preview_log(
|
||||
req: LogPreviewRequest,
|
||||
preview_fn: Callable[[LogPreviewRequest], Awaitable[LogPreviewResponse]] | None = None,
|
||||
) -> LogPreviewResponse:
|
||||
"""Proxy to an injectable log preview function."""
|
||||
if preview_fn is None:
|
||||
preview_fn = util_preview_log
|
||||
return await preview_fn(req)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -729,7 +606,7 @@ async def get_map_color_thresholds(db: aiosqlite.Connection) -> MapColorThreshol
|
||||
Returns:
|
||||
A :class:`MapColorThresholdsResponse` containing the three threshold values.
|
||||
"""
|
||||
high, medium, low = await setup_service.get_map_color_thresholds(db)
|
||||
high, medium, low = await util_get_map_color_thresholds(db)
|
||||
return MapColorThresholdsResponse(
|
||||
threshold_high=high,
|
||||
threshold_medium=medium,
|
||||
@@ -750,7 +627,7 @@ async def update_map_color_thresholds(
|
||||
Raises:
|
||||
ValueError: If validation fails (thresholds must satisfy high > medium > low).
|
||||
"""
|
||||
await setup_service.set_map_color_thresholds(
|
||||
await util_set_map_color_thresholds(
|
||||
db,
|
||||
threshold_high=update.threshold_high,
|
||||
threshold_medium=update.threshold_medium,
|
||||
@@ -772,16 +649,7 @@ _SAFE_LOG_PREFIXES: tuple[str, ...] = ("/var/log", "/config/log")
|
||||
|
||||
|
||||
def _count_file_lines(file_path: str) -> int:
|
||||
"""Count the total number of lines in *file_path* synchronously.
|
||||
|
||||
Uses a memory-efficient buffered read to avoid loading the whole file.
|
||||
|
||||
Args:
|
||||
file_path: Absolute path to the file.
|
||||
|
||||
Returns:
|
||||
Total number of lines in the file.
|
||||
"""
|
||||
"""Count the total number of lines in *file_path* synchronously."""
|
||||
count = 0
|
||||
with open(file_path, "rb") as fh:
|
||||
for chunk in iter(lambda: fh.read(65536), b""):
|
||||
@@ -789,6 +657,32 @@ def _count_file_lines(file_path: str) -> int:
|
||||
return count
|
||||
|
||||
|
||||
def _read_tail_lines(file_path: str, num_lines: int) -> list[str]:
|
||||
"""Read the last *num_lines* from *file_path* in a memory-efficient way."""
|
||||
chunk_size = 8192
|
||||
raw_lines: list[bytes] = []
|
||||
with open(file_path, "rb") as fh:
|
||||
fh.seek(0, 2)
|
||||
end_pos = fh.tell()
|
||||
if end_pos == 0:
|
||||
return []
|
||||
|
||||
buf = b""
|
||||
pos = end_pos
|
||||
while len(raw_lines) <= num_lines and pos > 0:
|
||||
read_size = min(chunk_size, pos)
|
||||
pos -= read_size
|
||||
fh.seek(pos)
|
||||
chunk = fh.read(read_size)
|
||||
buf = chunk + buf
|
||||
raw_lines = buf.split(b"\n")
|
||||
|
||||
if pos > 0 and len(raw_lines) > 1:
|
||||
raw_lines = raw_lines[1:]
|
||||
|
||||
return [ln.decode("utf-8", errors="replace").rstrip() for ln in raw_lines[-num_lines:] if ln.strip()]
|
||||
|
||||
|
||||
async def read_fail2ban_log(
|
||||
socket_path: str,
|
||||
lines: int,
|
||||
@@ -821,8 +715,8 @@ async def read_fail2ban_log(
|
||||
client = Fail2BanClient(socket_path=socket_path, timeout=_SOCKET_TIMEOUT)
|
||||
|
||||
log_level_raw, log_target_raw = await asyncio.gather(
|
||||
_safe_get(client, ["get", "loglevel"], "INFO"),
|
||||
_safe_get(client, ["get", "logtarget"], "STDOUT"),
|
||||
_safe_get_typed(client, ["get", "loglevel"], "INFO"),
|
||||
_safe_get_typed(client, ["get", "logtarget"], "STDOUT"),
|
||||
)
|
||||
|
||||
log_level = str(log_level_raw or "INFO").upper()
|
||||
@@ -883,28 +777,33 @@ async def read_fail2ban_log(
|
||||
)
|
||||
|
||||
|
||||
async def get_service_status(socket_path: str) -> ServiceStatusResponse:
|
||||
async def get_service_status(
|
||||
socket_path: str,
|
||||
probe_fn: Callable[[str], Awaitable[ServiceStatusResponse]] | None = None,
|
||||
) -> ServiceStatusResponse:
|
||||
"""Return fail2ban service health status with log configuration.
|
||||
|
||||
Delegates to :func:`~app.services.health_service.probe` for the core
|
||||
health snapshot and augments it with the current log-level and log-target
|
||||
values from the socket.
|
||||
Delegates to an injectable *probe_fn* (defaults to
|
||||
:func:`~app.services.health_service.probe`). This avoids direct service-to-
|
||||
service imports inside this module.
|
||||
|
||||
Args:
|
||||
socket_path: Path to the fail2ban Unix domain socket.
|
||||
probe_fn: Optional probe function.
|
||||
|
||||
Returns:
|
||||
:class:`~app.models.config.ServiceStatusResponse`.
|
||||
"""
|
||||
from app.services.health_service import probe # lazy import avoids circular dep
|
||||
if probe_fn is None:
|
||||
raise ValueError("probe_fn is required to avoid service-to-service coupling")
|
||||
|
||||
server_status = await probe(socket_path)
|
||||
server_status = await probe_fn(socket_path)
|
||||
|
||||
if server_status.online:
|
||||
client = Fail2BanClient(socket_path=socket_path, timeout=_SOCKET_TIMEOUT)
|
||||
log_level_raw, log_target_raw = await asyncio.gather(
|
||||
_safe_get(client, ["get", "loglevel"], "INFO"),
|
||||
_safe_get(client, ["get", "logtarget"], "STDOUT"),
|
||||
_safe_get_typed(client, ["get", "loglevel"], "INFO"),
|
||||
_safe_get_typed(client, ["get", "logtarget"], "STDOUT"),
|
||||
)
|
||||
log_level = str(log_level_raw or "INFO").upper()
|
||||
log_target = str(log_target_raw or "STDOUT")
|
||||
@@ -920,7 +819,7 @@ async def get_service_status(socket_path: str) -> ServiceStatusResponse:
|
||||
|
||||
return ServiceStatusResponse(
|
||||
online=server_status.online,
|
||||
version=server_status.version,
|
||||
version=__version__,
|
||||
jail_count=server_status.active_jails,
|
||||
total_bans=server_status.total_bans,
|
||||
total_failures=server_status.total_failures,
|
||||
|
||||
926
backend/app/services/filter_config_service.py
Normal file
926
backend/app/services/filter_config_service.py
Normal file
@@ -0,0 +1,926 @@
|
||||
"""Filter configuration management for BanGUI.
|
||||
|
||||
Handles parsing, validation, and lifecycle operations (create/update/delete)
|
||||
for fail2ban filter configurations.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import configparser
|
||||
import contextlib
|
||||
import io
|
||||
import os
|
||||
import re
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import structlog
|
||||
|
||||
from app.exceptions import FilterInvalidRegexError
|
||||
from app.models.config import (
|
||||
AssignFilterRequest,
|
||||
FilterConfig,
|
||||
FilterConfigUpdate,
|
||||
FilterCreateRequest,
|
||||
FilterListResponse,
|
||||
FilterUpdateRequest,
|
||||
)
|
||||
from app.services.config_file_service import _TRUE_VALUES, ConfigWriteError, JailNotFoundInConfigError
|
||||
from app.utils import conffile_parser
|
||||
from app.utils.config_file_utils import (
|
||||
_get_active_jail_names,
|
||||
_parse_jails_sync,
|
||||
)
|
||||
from app.utils.jail_utils import reload_jails
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Custom exceptions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class FilterNotFoundError(Exception):
|
||||
"""Raised when the requested filter name is not found in ``filter.d/``."""
|
||||
|
||||
def __init__(self, name: str) -> None:
|
||||
"""Initialise with the filter name that was not found.
|
||||
|
||||
Args:
|
||||
name: The filter name that could not be located.
|
||||
"""
|
||||
self.name: str = name
|
||||
super().__init__(f"Filter not found: {name!r}")
|
||||
|
||||
|
||||
class FilterAlreadyExistsError(Exception):
|
||||
"""Raised when trying to create a filter whose ``.conf`` or ``.local`` already exists."""
|
||||
|
||||
def __init__(self, name: str) -> None:
|
||||
"""Initialise with the filter name that already exists.
|
||||
|
||||
Args:
|
||||
name: The filter name that already exists.
|
||||
"""
|
||||
self.name: str = name
|
||||
super().__init__(f"Filter already exists: {name!r}")
|
||||
|
||||
|
||||
class FilterReadonlyError(Exception):
|
||||
"""Raised when trying to delete a shipped ``.conf`` filter with no ``.local`` override."""
|
||||
|
||||
def __init__(self, name: str) -> None:
|
||||
"""Initialise with the filter name that cannot be deleted.
|
||||
|
||||
Args:
|
||||
name: The filter name that is read-only (shipped ``.conf`` only).
|
||||
"""
|
||||
self.name: str = name
|
||||
super().__init__(
|
||||
f"Filter {name!r} is a shipped default (.conf only); only user-created .local files can be deleted."
|
||||
)
|
||||
|
||||
|
||||
class FilterNameError(Exception):
|
||||
"""Raised when a filter name contains invalid characters."""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Additional helper functions for this service
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class JailNameError(Exception):
|
||||
"""Raised when a jail name contains invalid characters."""
|
||||
|
||||
|
||||
_SAFE_FILTER_NAME_RE: re.Pattern[str] = re.compile(r"^[A-Za-z0-9][A-Za-z0-9._-]{0,127}$")
|
||||
_SAFE_JAIL_NAME_RE: re.Pattern[str] = re.compile(r"^[A-Za-z0-9][A-Za-z0-9._-]{0,127}$")
|
||||
|
||||
|
||||
def _safe_filter_name(name: str) -> str:
|
||||
"""Validate *name* and return it unchanged or raise :class:`FilterNameError`.
|
||||
|
||||
Args:
|
||||
name: Proposed filter name (without extension).
|
||||
|
||||
Returns:
|
||||
The name unchanged if valid.
|
||||
|
||||
Raises:
|
||||
FilterNameError: If *name* contains unsafe characters.
|
||||
"""
|
||||
if not _SAFE_FILTER_NAME_RE.match(name):
|
||||
raise FilterNameError(
|
||||
f"Filter name {name!r} contains invalid characters. "
|
||||
"Only alphanumeric characters, hyphens, underscores, and dots are "
|
||||
"allowed; must start with an alphanumeric character."
|
||||
)
|
||||
return name
|
||||
|
||||
|
||||
def _safe_jail_name(name: str) -> str:
|
||||
"""Validate *name* and return it unchanged or raise :class:`JailNameError`.
|
||||
|
||||
Args:
|
||||
name: Proposed jail name.
|
||||
|
||||
Returns:
|
||||
The name unchanged if valid.
|
||||
|
||||
Raises:
|
||||
JailNameError: If *name* contains unsafe characters.
|
||||
"""
|
||||
if not _SAFE_JAIL_NAME_RE.match(name):
|
||||
raise JailNameError(
|
||||
f"Jail name {name!r} contains invalid characters. "
|
||||
"Only alphanumeric characters, hyphens, underscores, and dots are "
|
||||
"allowed; must start with an alphanumeric character."
|
||||
)
|
||||
return name
|
||||
|
||||
|
||||
def _build_parser() -> configparser.RawConfigParser:
|
||||
"""Create a :class:`configparser.RawConfigParser` for fail2ban configs.
|
||||
|
||||
Returns:
|
||||
Parser with interpolation disabled and case-sensitive option names.
|
||||
"""
|
||||
parser = configparser.RawConfigParser(interpolation=None, strict=False)
|
||||
# fail2ban keys are lowercase but preserve case to be safe.
|
||||
parser.optionxform = str # type: ignore[assignment]
|
||||
return parser
|
||||
|
||||
|
||||
def _is_truthy(value: str) -> bool:
|
||||
"""Return ``True`` if *value* is a fail2ban boolean true string.
|
||||
|
||||
Args:
|
||||
value: Raw string from config (e.g. ``"true"``, ``"yes"``, ``"1"``).
|
||||
|
||||
Returns:
|
||||
``True`` when the value represents enabled.
|
||||
"""
|
||||
return value.strip().lower() in _TRUE_VALUES
|
||||
|
||||
|
||||
def _parse_multiline(raw: str) -> list[str]:
|
||||
"""Split a multi-line INI value into individual non-blank lines.
|
||||
|
||||
Args:
|
||||
raw: Raw multi-line string from configparser.
|
||||
|
||||
Returns:
|
||||
List of stripped, non-empty, non-comment strings.
|
||||
"""
|
||||
result: list[str] = []
|
||||
for line in raw.splitlines():
|
||||
stripped = line.strip()
|
||||
if stripped and not stripped.startswith("#"):
|
||||
result.append(stripped)
|
||||
return result
|
||||
|
||||
|
||||
def _resolve_filter(raw_filter: str, jail_name: str, mode: str) -> str:
|
||||
"""Resolve fail2ban variable placeholders in a filter string.
|
||||
|
||||
Handles the common default ``%(__name__)s[mode=%(mode)s]`` pattern that
|
||||
fail2ban uses so the filter name displayed to the user is readable.
|
||||
|
||||
Args:
|
||||
raw_filter: Raw ``filter`` value from config (may contain ``%()s``).
|
||||
jail_name: The jail's section name, used to substitute ``%(__name__)s``.
|
||||
mode: The jail's ``mode`` value, used to substitute ``%(mode)s``.
|
||||
|
||||
Returns:
|
||||
Human-readable filter string.
|
||||
"""
|
||||
result = raw_filter.replace("%(__name__)s", jail_name)
|
||||
result = result.replace("%(mode)s", mode)
|
||||
return result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Internal helpers - from config_file_service for local use
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _set_jail_local_key_sync(
|
||||
config_dir: Path,
|
||||
jail_name: str,
|
||||
key: str,
|
||||
value: str,
|
||||
) -> None:
|
||||
"""Update ``jail.d/{jail_name}.local`` to set a single key in the jail section.
|
||||
|
||||
If the ``.local`` file already exists it is read, the key is updated (or
|
||||
added), and the file is written back atomically without disturbing other
|
||||
settings. If the file does not exist a new one is created containing
|
||||
only the BanGUI header comment, the jail section, and the requested key.
|
||||
|
||||
Args:
|
||||
config_dir: The fail2ban configuration root directory.
|
||||
jail_name: Validated jail name (used as section name and filename stem).
|
||||
key: Config key to set inside the jail section.
|
||||
value: Config value to assign.
|
||||
|
||||
Raises:
|
||||
ConfigWriteError: If writing fails.
|
||||
"""
|
||||
jail_d = config_dir / "jail.d"
|
||||
try:
|
||||
jail_d.mkdir(parents=True, exist_ok=True)
|
||||
except OSError as exc:
|
||||
raise ConfigWriteError(f"Cannot create jail.d directory: {exc}") from exc
|
||||
|
||||
local_path = jail_d / f"{jail_name}.local"
|
||||
|
||||
parser = _build_parser()
|
||||
if local_path.is_file():
|
||||
try:
|
||||
parser.read(str(local_path), encoding="utf-8")
|
||||
except (configparser.Error, OSError) as exc:
|
||||
log.warning(
|
||||
"jail_local_read_for_update_error",
|
||||
jail=jail_name,
|
||||
error=str(exc),
|
||||
)
|
||||
|
||||
if not parser.has_section(jail_name):
|
||||
parser.add_section(jail_name)
|
||||
parser.set(jail_name, key, value)
|
||||
|
||||
# Serialize: write a BanGUI header then the parser output.
|
||||
buf = io.StringIO()
|
||||
buf.write("# Managed by BanGUI — do not edit manually\n\n")
|
||||
parser.write(buf)
|
||||
content = buf.getvalue()
|
||||
|
||||
try:
|
||||
with tempfile.NamedTemporaryFile(
|
||||
mode="w",
|
||||
encoding="utf-8",
|
||||
dir=jail_d,
|
||||
delete=False,
|
||||
suffix=".tmp",
|
||||
) as tmp:
|
||||
tmp.write(content)
|
||||
tmp_name = tmp.name
|
||||
os.replace(tmp_name, local_path)
|
||||
except OSError as exc:
|
||||
with contextlib.suppress(OSError):
|
||||
os.unlink(tmp_name) # noqa: F821
|
||||
raise ConfigWriteError(f"Failed to write {local_path}: {exc}") from exc
|
||||
|
||||
log.info(
|
||||
"jail_local_key_set",
|
||||
jail=jail_name,
|
||||
key=key,
|
||||
path=str(local_path),
|
||||
)
|
||||
|
||||
|
||||
def _extract_filter_base_name(filter_raw: str) -> str:
|
||||
"""Extract the base filter name from a raw fail2ban filter string.
|
||||
|
||||
fail2ban jail configs may specify a filter with an optional mode suffix,
|
||||
e.g. ``sshd``, ``sshd[mode=aggressive]``, or
|
||||
``%(__name__)s[mode=%(mode)s]``. This function strips the ``[…]`` mode
|
||||
block and any leading/trailing whitespace to return just the file-system
|
||||
base name used to look up ``filter.d/{name}.conf``.
|
||||
|
||||
Args:
|
||||
filter_raw: Raw ``filter`` value from a jail config (already
|
||||
with ``%(__name__)s`` substituted by the caller).
|
||||
|
||||
Returns:
|
||||
Base filter name, e.g. ``"sshd"``.
|
||||
"""
|
||||
bracket = filter_raw.find("[")
|
||||
if bracket != -1:
|
||||
return filter_raw[:bracket].strip()
|
||||
return filter_raw.strip()
|
||||
|
||||
|
||||
def _build_filter_to_jails_map(
|
||||
all_jails: dict[str, dict[str, str]],
|
||||
active_names: set[str],
|
||||
) -> dict[str, list[str]]:
|
||||
"""Return a mapping of filter base name → list of active jail names.
|
||||
|
||||
Iterates over every jail whose name is in *active_names*, resolves its
|
||||
``filter`` config key, and records the jail against the base filter name.
|
||||
|
||||
Args:
|
||||
all_jails: Merged jail config dict — ``{jail_name: {key: value}}``.
|
||||
active_names: Set of jail names currently running in fail2ban.
|
||||
|
||||
Returns:
|
||||
``{filter_base_name: [jail_name, …]}``.
|
||||
"""
|
||||
mapping: dict[str, list[str]] = {}
|
||||
for jail_name, settings in all_jails.items():
|
||||
if jail_name not in active_names:
|
||||
continue
|
||||
raw_filter = settings.get("filter", "")
|
||||
mode = settings.get("mode", "normal")
|
||||
resolved = _resolve_filter(raw_filter, jail_name, mode) if raw_filter else jail_name
|
||||
base = _extract_filter_base_name(resolved)
|
||||
if base:
|
||||
mapping.setdefault(base, []).append(jail_name)
|
||||
return mapping
|
||||
|
||||
|
||||
def _parse_filters_sync(
|
||||
filter_d: Path,
|
||||
) -> list[tuple[str, str, str, bool, str]]:
|
||||
"""Synchronously scan ``filter.d/`` and return per-filter tuples.
|
||||
|
||||
Each tuple contains:
|
||||
|
||||
- ``name`` — filter base name (``"sshd"``).
|
||||
- ``filename`` — actual filename (``"sshd.conf"`` or ``"sshd.local"``).
|
||||
- ``content`` — merged file content (``conf`` overridden by ``local``).
|
||||
- ``has_local`` — whether a ``.local`` override exists alongside a ``.conf``.
|
||||
- ``source_path`` — absolute path to the primary (``conf``) source file, or
|
||||
to the ``.local`` file for user-created (local-only) filters.
|
||||
|
||||
Also discovers ``.local``-only files (user-created filters with no
|
||||
corresponding ``.conf``). These are returned with ``has_local = False``
|
||||
and ``source_path`` pointing to the ``.local`` file itself.
|
||||
|
||||
Args:
|
||||
filter_d: Path to the ``filter.d`` directory.
|
||||
|
||||
Returns:
|
||||
List of ``(name, filename, content, has_local, source_path)`` tuples,
|
||||
sorted by name.
|
||||
"""
|
||||
if not filter_d.is_dir():
|
||||
log.warning("filter_d_not_found", path=str(filter_d))
|
||||
return []
|
||||
|
||||
conf_names: set[str] = set()
|
||||
results: list[tuple[str, str, str, bool, str]] = []
|
||||
|
||||
# ---- .conf-based filters (with optional .local override) ----------------
|
||||
for conf_path in sorted(filter_d.glob("*.conf")):
|
||||
if not conf_path.is_file():
|
||||
continue
|
||||
name = conf_path.stem
|
||||
filename = conf_path.name
|
||||
conf_names.add(name)
|
||||
local_path = conf_path.with_suffix(".local")
|
||||
has_local = local_path.is_file()
|
||||
|
||||
try:
|
||||
content = conf_path.read_text(encoding="utf-8")
|
||||
except OSError as exc:
|
||||
log.warning("filter_read_error", name=name, path=str(conf_path), error=str(exc))
|
||||
continue
|
||||
|
||||
if has_local:
|
||||
try:
|
||||
local_content = local_path.read_text(encoding="utf-8")
|
||||
# Append local content after conf so configparser reads local
|
||||
# values last (higher priority).
|
||||
content = content + "\n" + local_content
|
||||
except OSError as exc:
|
||||
log.warning(
|
||||
"filter_local_read_error",
|
||||
name=name,
|
||||
path=str(local_path),
|
||||
error=str(exc),
|
||||
)
|
||||
|
||||
results.append((name, filename, content, has_local, str(conf_path)))
|
||||
|
||||
# ---- .local-only filters (user-created, no corresponding .conf) ----------
|
||||
for local_path in sorted(filter_d.glob("*.local")):
|
||||
if not local_path.is_file():
|
||||
continue
|
||||
name = local_path.stem
|
||||
if name in conf_names:
|
||||
# Already covered above as a .conf filter with a .local override.
|
||||
continue
|
||||
try:
|
||||
content = local_path.read_text(encoding="utf-8")
|
||||
except OSError as exc:
|
||||
log.warning(
|
||||
"filter_local_read_error",
|
||||
name=name,
|
||||
path=str(local_path),
|
||||
error=str(exc),
|
||||
)
|
||||
continue
|
||||
results.append((name, local_path.name, content, False, str(local_path)))
|
||||
|
||||
results.sort(key=lambda t: t[0])
|
||||
log.debug("filters_scanned", count=len(results), filter_d=str(filter_d))
|
||||
return results
|
||||
|
||||
|
||||
def _validate_regex_patterns(patterns: list[str]) -> None:
|
||||
"""Validate each pattern in *patterns* using Python's ``re`` module.
|
||||
|
||||
Args:
|
||||
patterns: List of regex strings to validate.
|
||||
|
||||
Raises:
|
||||
FilterInvalidRegexError: If any pattern fails to compile.
|
||||
"""
|
||||
for pattern in patterns:
|
||||
try:
|
||||
re.compile(pattern)
|
||||
except re.error as exc:
|
||||
raise FilterInvalidRegexError(pattern, str(exc)) from exc
|
||||
|
||||
|
||||
def _write_filter_local_sync(filter_d: Path, name: str, content: str) -> None:
|
||||
"""Write *content* to ``filter.d/{name}.local`` atomically.
|
||||
|
||||
The write is atomic: content is written to a temp file first, then
|
||||
renamed into place. The ``filter.d/`` directory is created if absent.
|
||||
|
||||
Args:
|
||||
filter_d: Path to the ``filter.d`` directory.
|
||||
name: Validated filter base name (used as filename stem).
|
||||
content: Full serialized filter content to write.
|
||||
|
||||
Raises:
|
||||
ConfigWriteError: If writing fails.
|
||||
"""
|
||||
try:
|
||||
filter_d.mkdir(parents=True, exist_ok=True)
|
||||
except OSError as exc:
|
||||
raise ConfigWriteError(f"Cannot create filter.d directory: {exc}") from exc
|
||||
|
||||
local_path = filter_d / f"{name}.local"
|
||||
try:
|
||||
with tempfile.NamedTemporaryFile(
|
||||
mode="w",
|
||||
encoding="utf-8",
|
||||
dir=filter_d,
|
||||
delete=False,
|
||||
suffix=".tmp",
|
||||
) as tmp:
|
||||
tmp.write(content)
|
||||
tmp_name = tmp.name
|
||||
os.replace(tmp_name, local_path)
|
||||
except OSError as exc:
|
||||
with contextlib.suppress(OSError):
|
||||
os.unlink(tmp_name) # noqa: F821
|
||||
raise ConfigWriteError(f"Failed to write {local_path}: {exc}") from exc
|
||||
|
||||
log.info("filter_local_written", filter=name, path=str(local_path))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public API — filter discovery
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def list_filters(
|
||||
config_dir: str,
|
||||
socket_path: str,
|
||||
) -> FilterListResponse:
|
||||
"""Return all available filters from ``filter.d/`` with active/inactive status.
|
||||
|
||||
Scans ``{config_dir}/filter.d/`` for ``.conf`` files, merges any
|
||||
corresponding ``.local`` overrides, parses each file into a
|
||||
:class:`~app.models.config.FilterConfig`, and cross-references with the
|
||||
currently running jails to determine which filters are active.
|
||||
|
||||
A filter is considered *active* when its base name matches the ``filter``
|
||||
field of at least one currently running jail.
|
||||
|
||||
Args:
|
||||
config_dir: Absolute path to the fail2ban configuration directory.
|
||||
socket_path: Path to the fail2ban Unix domain socket.
|
||||
|
||||
Returns:
|
||||
:class:`~app.models.config.FilterListResponse` with all filters
|
||||
sorted alphabetically, active ones carrying non-empty
|
||||
``used_by_jails`` lists.
|
||||
"""
|
||||
filter_d = Path(config_dir) / "filter.d"
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
# Run the synchronous scan in a thread-pool executor.
|
||||
raw_filters: list[tuple[str, str, str, bool, str]] = await loop.run_in_executor(None, _parse_filters_sync, filter_d)
|
||||
|
||||
# Fetch active jail names and their configs concurrently.
|
||||
all_jails_result, active_names = await asyncio.gather(
|
||||
loop.run_in_executor(None, _parse_jails_sync, Path(config_dir)),
|
||||
_get_active_jail_names(socket_path),
|
||||
)
|
||||
all_jails, _source_files = all_jails_result
|
||||
|
||||
filter_to_jails = _build_filter_to_jails_map(all_jails, active_names)
|
||||
|
||||
filters: list[FilterConfig] = []
|
||||
for name, filename, content, has_local, source_path in raw_filters:
|
||||
cfg = conffile_parser.parse_filter_file(content, name=name, filename=filename)
|
||||
used_by = sorted(filter_to_jails.get(name, []))
|
||||
filters.append(
|
||||
FilterConfig(
|
||||
name=cfg.name,
|
||||
filename=cfg.filename,
|
||||
before=cfg.before,
|
||||
after=cfg.after,
|
||||
variables=cfg.variables,
|
||||
prefregex=cfg.prefregex,
|
||||
failregex=cfg.failregex,
|
||||
ignoreregex=cfg.ignoreregex,
|
||||
maxlines=cfg.maxlines,
|
||||
datepattern=cfg.datepattern,
|
||||
journalmatch=cfg.journalmatch,
|
||||
active=len(used_by) > 0,
|
||||
used_by_jails=used_by,
|
||||
source_file=source_path,
|
||||
has_local_override=has_local,
|
||||
)
|
||||
)
|
||||
|
||||
log.info("filters_listed", total=len(filters), active=sum(1 for f in filters if f.active))
|
||||
return FilterListResponse(filters=filters, total=len(filters))
|
||||
|
||||
|
||||
async def get_filter(
|
||||
config_dir: str,
|
||||
socket_path: str,
|
||||
name: str,
|
||||
) -> FilterConfig:
|
||||
"""Return a single filter from ``filter.d/`` with active/inactive status.
|
||||
|
||||
Reads ``{config_dir}/filter.d/{name}.conf``, merges any ``.local``
|
||||
override, and enriches the parsed :class:`~app.models.config.FilterConfig`
|
||||
with ``active``, ``used_by_jails``, ``source_file``, and
|
||||
``has_local_override``.
|
||||
|
||||
Args:
|
||||
config_dir: Absolute path to the fail2ban configuration directory.
|
||||
socket_path: Path to the fail2ban Unix domain socket.
|
||||
name: Filter base name (e.g. ``"sshd"`` or ``"sshd.conf"``).
|
||||
|
||||
Returns:
|
||||
:class:`~app.models.config.FilterConfig` with status fields populated.
|
||||
|
||||
Raises:
|
||||
FilterNotFoundError: If no ``{name}.conf`` or ``{name}.local`` file
|
||||
exists in ``filter.d/``.
|
||||
"""
|
||||
# Normalise — strip extension if provided (.conf=5 chars, .local=6 chars).
|
||||
if name.endswith(".conf"):
|
||||
base_name = name[:-5]
|
||||
elif name.endswith(".local"):
|
||||
base_name = name[:-6]
|
||||
else:
|
||||
base_name = name
|
||||
|
||||
filter_d = Path(config_dir) / "filter.d"
|
||||
conf_path = filter_d / f"{base_name}.conf"
|
||||
local_path = filter_d / f"{base_name}.local"
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
def _read() -> tuple[str, bool, str]:
|
||||
"""Read filter content and return (content, has_local_override, source_path)."""
|
||||
has_local = local_path.is_file()
|
||||
if conf_path.is_file():
|
||||
content = conf_path.read_text(encoding="utf-8")
|
||||
if has_local:
|
||||
try:
|
||||
content += "\n" + local_path.read_text(encoding="utf-8")
|
||||
except OSError as exc:
|
||||
log.warning(
|
||||
"filter_local_read_error",
|
||||
name=base_name,
|
||||
path=str(local_path),
|
||||
error=str(exc),
|
||||
)
|
||||
return content, has_local, str(conf_path)
|
||||
elif has_local:
|
||||
# Local-only filter: created by the user, no shipped .conf base.
|
||||
content = local_path.read_text(encoding="utf-8")
|
||||
return content, False, str(local_path)
|
||||
else:
|
||||
raise FilterNotFoundError(base_name)
|
||||
|
||||
content, has_local, source_path = await loop.run_in_executor(None, _read)
|
||||
|
||||
cfg = conffile_parser.parse_filter_file(content, name=base_name, filename=f"{base_name}.conf")
|
||||
|
||||
all_jails_result, active_names = await asyncio.gather(
|
||||
loop.run_in_executor(None, _parse_jails_sync, Path(config_dir)),
|
||||
_get_active_jail_names(socket_path),
|
||||
)
|
||||
all_jails, _source_files = all_jails_result
|
||||
filter_to_jails = _build_filter_to_jails_map(all_jails, active_names)
|
||||
|
||||
used_by = sorted(filter_to_jails.get(base_name, []))
|
||||
log.info("filter_fetched", name=base_name, active=len(used_by) > 0)
|
||||
return FilterConfig(
|
||||
name=cfg.name,
|
||||
filename=cfg.filename,
|
||||
before=cfg.before,
|
||||
after=cfg.after,
|
||||
variables=cfg.variables,
|
||||
prefregex=cfg.prefregex,
|
||||
failregex=cfg.failregex,
|
||||
ignoreregex=cfg.ignoreregex,
|
||||
maxlines=cfg.maxlines,
|
||||
datepattern=cfg.datepattern,
|
||||
journalmatch=cfg.journalmatch,
|
||||
active=len(used_by) > 0,
|
||||
used_by_jails=used_by,
|
||||
source_file=source_path,
|
||||
has_local_override=has_local,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public API — filter write operations
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def update_filter(
|
||||
config_dir: str,
|
||||
socket_path: str,
|
||||
name: str,
|
||||
req: FilterUpdateRequest,
|
||||
do_reload: bool = False,
|
||||
) -> FilterConfig:
|
||||
"""Update a filter's ``.local`` override with new regex/pattern values.
|
||||
|
||||
Reads the current merged configuration for *name* (``conf`` + any existing
|
||||
``local``), applies the non-``None`` fields in *req* on top of it, and
|
||||
writes the resulting definition to ``filter.d/{name}.local``. The
|
||||
original ``.conf`` file is never modified.
|
||||
|
||||
All regex patterns in *req* are validated with Python's ``re`` module
|
||||
before any write occurs.
|
||||
|
||||
Args:
|
||||
config_dir: Absolute path to the fail2ban configuration directory.
|
||||
socket_path: Path to the fail2ban Unix domain socket.
|
||||
name: Filter base name (e.g. ``"sshd"`` or ``"sshd.conf"``).
|
||||
req: Partial update — only non-``None`` fields are applied.
|
||||
do_reload: When ``True``, trigger a full fail2ban reload after writing.
|
||||
|
||||
Returns:
|
||||
:class:`~app.models.config.FilterConfig` reflecting the updated state.
|
||||
|
||||
Raises:
|
||||
FilterNameError: If *name* contains invalid characters.
|
||||
FilterNotFoundError: If no ``{name}.conf`` or ``{name}.local`` exists.
|
||||
FilterInvalidRegexError: If any supplied regex pattern is invalid.
|
||||
ConfigWriteError: If writing the ``.local`` file fails.
|
||||
"""
|
||||
base_name = name[:-5] if name.endswith(".conf") or name.endswith(".local") else name
|
||||
_safe_filter_name(base_name)
|
||||
|
||||
# Validate regex patterns before touching the filesystem.
|
||||
patterns: list[str] = []
|
||||
if req.failregex is not None:
|
||||
patterns.extend(req.failregex)
|
||||
if req.ignoreregex is not None:
|
||||
patterns.extend(req.ignoreregex)
|
||||
_validate_regex_patterns(patterns)
|
||||
|
||||
# Fetch the current merged config (raises FilterNotFoundError if absent).
|
||||
current = await get_filter(config_dir, socket_path, base_name)
|
||||
|
||||
# Build a FilterConfigUpdate from the request fields.
|
||||
update = FilterConfigUpdate(
|
||||
failregex=req.failregex,
|
||||
ignoreregex=req.ignoreregex,
|
||||
datepattern=req.datepattern,
|
||||
journalmatch=req.journalmatch,
|
||||
)
|
||||
|
||||
merged = conffile_parser.merge_filter_update(current, update)
|
||||
content = conffile_parser.serialize_filter_config(merged)
|
||||
|
||||
filter_d = Path(config_dir) / "filter.d"
|
||||
loop = asyncio.get_event_loop()
|
||||
await loop.run_in_executor(None, _write_filter_local_sync, filter_d, base_name, content)
|
||||
|
||||
if do_reload:
|
||||
try:
|
||||
await reload_jails(socket_path)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
log.warning(
|
||||
"reload_after_filter_update_failed",
|
||||
filter=base_name,
|
||||
error=str(exc),
|
||||
)
|
||||
|
||||
log.info("filter_updated", filter=base_name, reload=do_reload)
|
||||
return await get_filter(config_dir, socket_path, base_name)
|
||||
|
||||
|
||||
async def create_filter(
|
||||
config_dir: str,
|
||||
socket_path: str,
|
||||
req: FilterCreateRequest,
|
||||
do_reload: bool = False,
|
||||
) -> FilterConfig:
|
||||
"""Create a brand-new user-defined filter in ``filter.d/{name}.local``.
|
||||
|
||||
No ``.conf`` is written; fail2ban loads ``.local`` files directly. If a
|
||||
``.conf`` or ``.local`` file already exists for the requested name, a
|
||||
:class:`FilterAlreadyExistsError` is raised.
|
||||
|
||||
All regex patterns are validated with Python's ``re`` module before
|
||||
writing.
|
||||
|
||||
Args:
|
||||
config_dir: Absolute path to the fail2ban configuration directory.
|
||||
socket_path: Path to the fail2ban Unix domain socket.
|
||||
req: Filter name and definition fields.
|
||||
do_reload: When ``True``, trigger a full fail2ban reload after writing.
|
||||
|
||||
Returns:
|
||||
:class:`~app.models.config.FilterConfig` for the newly created filter.
|
||||
|
||||
Raises:
|
||||
FilterNameError: If ``req.name`` contains invalid characters.
|
||||
FilterAlreadyExistsError: If a ``.conf`` or ``.local`` already exists.
|
||||
FilterInvalidRegexError: If any regex pattern is invalid.
|
||||
ConfigWriteError: If writing fails.
|
||||
"""
|
||||
_safe_filter_name(req.name)
|
||||
|
||||
filter_d = Path(config_dir) / "filter.d"
|
||||
conf_path = filter_d / f"{req.name}.conf"
|
||||
local_path = filter_d / f"{req.name}.local"
|
||||
|
||||
def _check_not_exists() -> None:
|
||||
if conf_path.is_file() or local_path.is_file():
|
||||
raise FilterAlreadyExistsError(req.name)
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
await loop.run_in_executor(None, _check_not_exists)
|
||||
|
||||
# Validate regex patterns.
|
||||
patterns: list[str] = list(req.failregex) + list(req.ignoreregex)
|
||||
_validate_regex_patterns(patterns)
|
||||
|
||||
# Build a FilterConfig and serialise it.
|
||||
cfg = FilterConfig(
|
||||
name=req.name,
|
||||
filename=f"{req.name}.local",
|
||||
failregex=req.failregex,
|
||||
ignoreregex=req.ignoreregex,
|
||||
prefregex=req.prefregex,
|
||||
datepattern=req.datepattern,
|
||||
journalmatch=req.journalmatch,
|
||||
)
|
||||
content = conffile_parser.serialize_filter_config(cfg)
|
||||
|
||||
await loop.run_in_executor(None, _write_filter_local_sync, filter_d, req.name, content)
|
||||
|
||||
if do_reload:
|
||||
try:
|
||||
await reload_jails(socket_path)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
log.warning(
|
||||
"reload_after_filter_create_failed",
|
||||
filter=req.name,
|
||||
error=str(exc),
|
||||
)
|
||||
|
||||
log.info("filter_created", filter=req.name, reload=do_reload)
|
||||
# Re-fetch to get the canonical FilterConfig (source_file, active, etc.).
|
||||
return await get_filter(config_dir, socket_path, req.name)
|
||||
|
||||
|
||||
async def delete_filter(
|
||||
config_dir: str,
|
||||
name: str,
|
||||
) -> None:
|
||||
"""Delete a user-created filter's ``.local`` file.
|
||||
|
||||
Deletion rules:
|
||||
- If only a ``.conf`` file exists (shipped default, no user override) →
|
||||
:class:`FilterReadonlyError`.
|
||||
- If a ``.local`` file exists (whether or not a ``.conf`` also exists) →
|
||||
the ``.local`` file is deleted. The shipped ``.conf`` is never touched.
|
||||
- If neither file exists → :class:`FilterNotFoundError`.
|
||||
|
||||
Args:
|
||||
config_dir: Absolute path to the fail2ban configuration directory.
|
||||
name: Filter base name (e.g. ``"sshd"``).
|
||||
|
||||
Raises:
|
||||
FilterNameError: If *name* contains invalid characters.
|
||||
FilterNotFoundError: If no filter file is found for *name*.
|
||||
FilterReadonlyError: If only a shipped ``.conf`` exists (no ``.local``).
|
||||
ConfigWriteError: If deletion of the ``.local`` file fails.
|
||||
"""
|
||||
base_name = name[:-5] if name.endswith(".conf") or name.endswith(".local") else name
|
||||
_safe_filter_name(base_name)
|
||||
|
||||
filter_d = Path(config_dir) / "filter.d"
|
||||
conf_path = filter_d / f"{base_name}.conf"
|
||||
local_path = filter_d / f"{base_name}.local"
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
def _delete() -> None:
|
||||
has_conf = conf_path.is_file()
|
||||
has_local = local_path.is_file()
|
||||
|
||||
if not has_conf and not has_local:
|
||||
raise FilterNotFoundError(base_name)
|
||||
|
||||
if has_conf and not has_local:
|
||||
# Shipped default — nothing user-writable to remove.
|
||||
raise FilterReadonlyError(base_name)
|
||||
|
||||
try:
|
||||
local_path.unlink()
|
||||
except OSError as exc:
|
||||
raise ConfigWriteError(f"Failed to delete {local_path}: {exc}") from exc
|
||||
|
||||
log.info("filter_local_deleted", filter=base_name, path=str(local_path))
|
||||
|
||||
await loop.run_in_executor(None, _delete)
|
||||
|
||||
|
||||
async def assign_filter_to_jail(
|
||||
config_dir: str,
|
||||
socket_path: str,
|
||||
jail_name: str,
|
||||
req: AssignFilterRequest,
|
||||
do_reload: bool = False,
|
||||
) -> None:
|
||||
"""Assign a filter to a jail by updating the jail's ``.local`` file.
|
||||
|
||||
Writes ``filter = {req.filter_name}`` into the ``[{jail_name}]`` section
|
||||
of ``jail.d/{jail_name}.local``. If the ``.local`` file already contains
|
||||
other settings for this jail they are preserved.
|
||||
|
||||
Args:
|
||||
config_dir: Absolute path to the fail2ban configuration directory.
|
||||
socket_path: Path to the fail2ban Unix domain socket.
|
||||
jail_name: Name of the jail to update.
|
||||
req: Request containing the filter name to assign.
|
||||
do_reload: When ``True``, trigger a full fail2ban reload after writing.
|
||||
|
||||
Raises:
|
||||
JailNameError: If *jail_name* contains invalid characters.
|
||||
FilterNameError: If ``req.filter_name`` contains invalid characters.
|
||||
JailNotFoundError: If *jail_name* is not defined in any config file.
|
||||
FilterNotFoundError: If ``req.filter_name`` does not exist in
|
||||
``filter.d/``.
|
||||
ConfigWriteError: If writing fails.
|
||||
"""
|
||||
_safe_jail_name(jail_name)
|
||||
_safe_filter_name(req.filter_name)
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
# Verify the jail exists in config.
|
||||
all_jails, _src = await loop.run_in_executor(None, _parse_jails_sync, Path(config_dir))
|
||||
if jail_name not in all_jails:
|
||||
raise JailNotFoundInConfigError(jail_name)
|
||||
|
||||
# Verify the filter exists (conf or local).
|
||||
filter_d = Path(config_dir) / "filter.d"
|
||||
|
||||
def _check_filter() -> None:
|
||||
conf_exists = (filter_d / f"{req.filter_name}.conf").is_file()
|
||||
local_exists = (filter_d / f"{req.filter_name}.local").is_file()
|
||||
if not conf_exists and not local_exists:
|
||||
raise FilterNotFoundError(req.filter_name)
|
||||
|
||||
await loop.run_in_executor(None, _check_filter)
|
||||
|
||||
await loop.run_in_executor(
|
||||
None,
|
||||
_set_jail_local_key_sync,
|
||||
Path(config_dir),
|
||||
jail_name,
|
||||
"filter",
|
||||
req.filter_name,
|
||||
)
|
||||
|
||||
if do_reload:
|
||||
try:
|
||||
await reload_jails(socket_path)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
log.warning(
|
||||
"reload_after_assign_filter_failed",
|
||||
jail=jail_name,
|
||||
filter=req.filter_name,
|
||||
error=str(exc),
|
||||
)
|
||||
|
||||
log.info(
|
||||
"filter_assigned_to_jail",
|
||||
jail=jail_name,
|
||||
filter=req.filter_name,
|
||||
reload=do_reload,
|
||||
)
|
||||
@@ -20,9 +20,7 @@ Usage::
|
||||
|
||||
import aiohttp
|
||||
import aiosqlite
|
||||
from app.services import geo_service
|
||||
|
||||
# warm the cache from the persistent store at startup
|
||||
# Use the geo_service directly in application startup
|
||||
async with aiosqlite.connect("bangui.db") as db:
|
||||
await geo_service.load_cache_from_db(db)
|
||||
|
||||
@@ -30,7 +28,8 @@ Usage::
|
||||
# single lookup
|
||||
info = await geo_service.lookup("1.2.3.4", session)
|
||||
if info:
|
||||
print(info.country_code) # "DE"
|
||||
# info.country_code == "DE"
|
||||
... # use the GeoInfo object in your application
|
||||
|
||||
# bulk lookup (more efficient for large sets)
|
||||
geo_map = await geo_service.lookup_batch(["1.2.3.4", "5.6.7.8"], session)
|
||||
@@ -40,12 +39,14 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import aiohttp
|
||||
import structlog
|
||||
|
||||
from app.models.geo import GeoInfo
|
||||
from app.repositories import geo_cache_repo
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import aiosqlite
|
||||
import geoip2.database
|
||||
@@ -90,32 +91,6 @@ _BATCH_DELAY: float = 1.5
|
||||
#: transient error (e.g. connection reset due to rate limiting).
|
||||
_BATCH_MAX_RETRIES: int = 2
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Domain model
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
|
||||
class GeoInfo:
|
||||
"""Geographical and network metadata for a single IP address.
|
||||
|
||||
All fields default to ``None`` when the information is unavailable or
|
||||
the lookup fails gracefully.
|
||||
"""
|
||||
|
||||
country_code: str | None
|
||||
"""ISO 3166-1 alpha-2 country code, e.g. ``"DE"``."""
|
||||
|
||||
country_name: str | None
|
||||
"""Human-readable country name, e.g. ``"Germany"``."""
|
||||
|
||||
asn: str | None
|
||||
"""Autonomous System Number string, e.g. ``"AS3320"``."""
|
||||
|
||||
org: str | None
|
||||
"""Organisation name associated with the IP, e.g. ``"Deutsche Telekom"``."""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Internal cache
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -184,11 +159,7 @@ async def cache_stats(db: aiosqlite.Connection) -> dict[str, int]:
|
||||
Dict with keys ``cache_size``, ``unresolved``, ``neg_cache_size``,
|
||||
and ``dirty_size``.
|
||||
"""
|
||||
async with db.execute(
|
||||
"SELECT COUNT(*) FROM geo_cache WHERE country_code IS NULL"
|
||||
) as cur:
|
||||
row = await cur.fetchone()
|
||||
unresolved: int = int(row[0]) if row else 0
|
||||
unresolved = await geo_cache_repo.count_unresolved(db)
|
||||
|
||||
return {
|
||||
"cache_size": len(_cache),
|
||||
@@ -198,6 +169,24 @@ async def cache_stats(db: aiosqlite.Connection) -> dict[str, int]:
|
||||
}
|
||||
|
||||
|
||||
async def count_unresolved(db: aiosqlite.Connection) -> int:
|
||||
"""Return the number of unresolved entries in the persistent geo cache."""
|
||||
|
||||
return await geo_cache_repo.count_unresolved(db)
|
||||
|
||||
|
||||
async def get_unresolved_ips(db: aiosqlite.Connection) -> list[str]:
|
||||
"""Return geo cache IPs where the country code has not yet been resolved.
|
||||
|
||||
Args:
|
||||
db: Open BanGUI application database connection.
|
||||
|
||||
Returns:
|
||||
List of IP addresses that are candidates for re-resolution.
|
||||
"""
|
||||
return await geo_cache_repo.get_unresolved_ips(db)
|
||||
|
||||
|
||||
def init_geoip(mmdb_path: str | None) -> None:
|
||||
"""Initialise the MaxMind GeoLite2-Country database reader.
|
||||
|
||||
@@ -268,21 +257,18 @@ async def load_cache_from_db(db: aiosqlite.Connection) -> None:
|
||||
database (not the fail2ban database).
|
||||
"""
|
||||
count = 0
|
||||
async with db.execute(
|
||||
"SELECT ip, country_code, country_name, asn, org FROM geo_cache"
|
||||
) as cur:
|
||||
async for row in cur:
|
||||
ip: str = str(row[0])
|
||||
country_code: str | None = row[1]
|
||||
if country_code is None:
|
||||
continue
|
||||
_cache[ip] = GeoInfo(
|
||||
country_code=country_code,
|
||||
country_name=row[2],
|
||||
asn=row[3],
|
||||
org=row[4],
|
||||
)
|
||||
count += 1
|
||||
for row in await geo_cache_repo.load_all(db):
|
||||
country_code: str | None = row["country_code"]
|
||||
if country_code is None:
|
||||
continue
|
||||
ip: str = row["ip"]
|
||||
_cache[ip] = GeoInfo(
|
||||
country_code=country_code,
|
||||
country_name=row["country_name"],
|
||||
asn=row["asn"],
|
||||
org=row["org"],
|
||||
)
|
||||
count += 1
|
||||
log.info("geo_cache_loaded_from_db", entries=count)
|
||||
|
||||
|
||||
@@ -301,18 +287,13 @@ async def _persist_entry(
|
||||
ip: IP address string.
|
||||
info: Resolved geo data to persist.
|
||||
"""
|
||||
await db.execute(
|
||||
"""
|
||||
INSERT INTO geo_cache (ip, country_code, country_name, asn, org)
|
||||
VALUES (?, ?, ?, ?, ?)
|
||||
ON CONFLICT(ip) DO UPDATE SET
|
||||
country_code = excluded.country_code,
|
||||
country_name = excluded.country_name,
|
||||
asn = excluded.asn,
|
||||
org = excluded.org,
|
||||
cached_at = strftime('%Y-%m-%dT%H:%M:%fZ', 'now')
|
||||
""",
|
||||
(ip, info.country_code, info.country_name, info.asn, info.org),
|
||||
await geo_cache_repo.upsert_entry(
|
||||
db=db,
|
||||
ip=ip,
|
||||
country_code=info.country_code,
|
||||
country_name=info.country_name,
|
||||
asn=info.asn,
|
||||
org=info.org,
|
||||
)
|
||||
|
||||
|
||||
@@ -326,10 +307,7 @@ async def _persist_neg_entry(db: aiosqlite.Connection, ip: str) -> None:
|
||||
db: BanGUI application database connection.
|
||||
ip: IP address string whose resolution failed.
|
||||
"""
|
||||
await db.execute(
|
||||
"INSERT OR IGNORE INTO geo_cache (ip) VALUES (?)",
|
||||
(ip,),
|
||||
)
|
||||
await geo_cache_repo.upsert_neg_entry(db=db, ip=ip)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -585,19 +563,7 @@ async def lookup_batch(
|
||||
if db is not None:
|
||||
if pos_rows:
|
||||
try:
|
||||
await db.executemany(
|
||||
"""
|
||||
INSERT INTO geo_cache (ip, country_code, country_name, asn, org)
|
||||
VALUES (?, ?, ?, ?, ?)
|
||||
ON CONFLICT(ip) DO UPDATE SET
|
||||
country_code = excluded.country_code,
|
||||
country_name = excluded.country_name,
|
||||
asn = excluded.asn,
|
||||
org = excluded.org,
|
||||
cached_at = strftime('%Y-%m-%dT%H:%M:%fZ', 'now')
|
||||
""",
|
||||
pos_rows,
|
||||
)
|
||||
await geo_cache_repo.bulk_upsert_entries(db, pos_rows)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
log.warning(
|
||||
"geo_batch_persist_failed",
|
||||
@@ -606,10 +572,7 @@ async def lookup_batch(
|
||||
)
|
||||
if neg_ips:
|
||||
try:
|
||||
await db.executemany(
|
||||
"INSERT OR IGNORE INTO geo_cache (ip) VALUES (?)",
|
||||
[(ip,) for ip in neg_ips],
|
||||
)
|
||||
await geo_cache_repo.bulk_upsert_neg_entries(db, neg_ips)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
log.warning(
|
||||
"geo_batch_persist_neg_failed",
|
||||
@@ -792,19 +755,7 @@ async def flush_dirty(db: aiosqlite.Connection) -> int:
|
||||
return 0
|
||||
|
||||
try:
|
||||
await db.executemany(
|
||||
"""
|
||||
INSERT INTO geo_cache (ip, country_code, country_name, asn, org)
|
||||
VALUES (?, ?, ?, ?, ?)
|
||||
ON CONFLICT(ip) DO UPDATE SET
|
||||
country_code = excluded.country_code,
|
||||
country_name = excluded.country_name,
|
||||
asn = excluded.asn,
|
||||
org = excluded.org,
|
||||
cached_at = strftime('%Y-%m-%dT%H:%M:%fZ', 'now')
|
||||
""",
|
||||
rows,
|
||||
)
|
||||
await geo_cache_repo.bulk_upsert_entries(db, rows)
|
||||
await db.commit()
|
||||
except Exception as exc: # noqa: BLE001
|
||||
log.warning("geo_flush_dirty_failed", error=str(exc))
|
||||
|
||||
@@ -9,12 +9,17 @@ seconds by the background health-check task, not on every HTTP request.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
import structlog
|
||||
|
||||
from app.models.server import ServerStatus
|
||||
from app.utils.fail2ban_client import Fail2BanClient, Fail2BanConnectionError, Fail2BanProtocolError
|
||||
from app.utils.fail2ban_client import (
|
||||
Fail2BanClient,
|
||||
Fail2BanConnectionError,
|
||||
Fail2BanProtocolError,
|
||||
Fail2BanResponse,
|
||||
)
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
|
||||
@@ -25,7 +30,7 @@ log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
_SOCKET_TIMEOUT: float = 5.0
|
||||
|
||||
|
||||
def _ok(response: Any) -> Any:
|
||||
def _ok(response: object) -> object:
|
||||
"""Extract the payload from a fail2ban ``(return_code, data)`` response.
|
||||
|
||||
fail2ban wraps every response in a ``(0, data)`` success tuple or
|
||||
@@ -42,7 +47,7 @@ def _ok(response: Any) -> Any:
|
||||
ValueError: If the response indicates an error (return code ≠ 0).
|
||||
"""
|
||||
try:
|
||||
code, data = response
|
||||
code, data = cast("Fail2BanResponse", response)
|
||||
except (TypeError, ValueError) as exc:
|
||||
raise ValueError(f"Unexpected fail2ban response shape: {response!r}") from exc
|
||||
|
||||
@@ -52,7 +57,7 @@ def _ok(response: Any) -> Any:
|
||||
return data
|
||||
|
||||
|
||||
def _to_dict(pairs: Any) -> dict[str, Any]:
|
||||
def _to_dict(pairs: object) -> dict[str, object]:
|
||||
"""Convert a list of ``(key, value)`` pairs to a plain dict.
|
||||
|
||||
fail2ban returns structured data as lists of 2-tuples rather than dicts.
|
||||
@@ -66,7 +71,7 @@ def _to_dict(pairs: Any) -> dict[str, Any]:
|
||||
"""
|
||||
if not isinstance(pairs, (list, tuple)):
|
||||
return {}
|
||||
result: dict[str, Any] = {}
|
||||
result: dict[str, object] = {}
|
||||
for item in pairs:
|
||||
try:
|
||||
k, v = item
|
||||
@@ -119,7 +124,7 @@ async def probe(socket_path: str, timeout: float = _SOCKET_TIMEOUT) -> ServerSta
|
||||
# 3. Global status — jail count and names #
|
||||
# ------------------------------------------------------------------ #
|
||||
status_data = _to_dict(_ok(await client.send(["status"])))
|
||||
active_jails: int = int(status_data.get("Number of jail", 0) or 0)
|
||||
active_jails: int = int(str(status_data.get("Number of jail", 0) or 0))
|
||||
jail_list_raw: str = str(status_data.get("Jail list", "") or "").strip()
|
||||
jail_names: list[str] = (
|
||||
[j.strip() for j in jail_list_raw.split(",") if j.strip()]
|
||||
@@ -138,8 +143,8 @@ async def probe(socket_path: str, timeout: float = _SOCKET_TIMEOUT) -> ServerSta
|
||||
jail_resp = _to_dict(_ok(await client.send(["status", jail_name])))
|
||||
filter_stats = _to_dict(jail_resp.get("Filter") or [])
|
||||
action_stats = _to_dict(jail_resp.get("Actions") or [])
|
||||
total_failures += int(filter_stats.get("Currently failed", 0) or 0)
|
||||
total_bans += int(action_stats.get("Currently banned", 0) or 0)
|
||||
total_failures += int(str(filter_stats.get("Currently failed", 0) or 0))
|
||||
total_bans += int(str(action_stats.get("Currently banned", 0) or 0))
|
||||
except (ValueError, TypeError, KeyError) as exc:
|
||||
log.warning(
|
||||
"fail2ban_jail_status_parse_error",
|
||||
|
||||
@@ -11,19 +11,22 @@ modifies or locks the fail2ban database.
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import aiosqlite
|
||||
import structlog
|
||||
|
||||
from app.models.ban import TIME_RANGE_SECONDS, TimeRange
|
||||
if TYPE_CHECKING:
|
||||
from app.models.geo import GeoEnricher
|
||||
|
||||
from app.models.ban import TIME_RANGE_SECONDS, BanOrigin, TimeRange
|
||||
from app.models.history import (
|
||||
HistoryBanItem,
|
||||
HistoryListResponse,
|
||||
IpDetailResponse,
|
||||
IpTimelineEvent,
|
||||
)
|
||||
from app.services.ban_service import _get_fail2ban_db_path, _parse_data_json, _ts_to_iso
|
||||
from app.repositories import fail2ban_db_repo
|
||||
from app.utils.fail2ban_db_utils import get_fail2ban_db_path, parse_data_json, ts_to_iso
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
|
||||
@@ -59,9 +62,10 @@ async def list_history(
|
||||
range_: TimeRange | None = None,
|
||||
jail: str | None = None,
|
||||
ip_filter: str | None = None,
|
||||
origin: BanOrigin | None = None,
|
||||
page: int = 1,
|
||||
page_size: int = _DEFAULT_PAGE_SIZE,
|
||||
geo_enricher: Any | None = None,
|
||||
geo_enricher: GeoEnricher | None = None,
|
||||
) -> HistoryListResponse:
|
||||
"""Return a paginated list of historical ban records with optional filters.
|
||||
|
||||
@@ -84,28 +88,13 @@ async def list_history(
|
||||
and the total matching count.
|
||||
"""
|
||||
effective_page_size: int = min(page_size, _MAX_PAGE_SIZE)
|
||||
offset: int = (page - 1) * effective_page_size
|
||||
|
||||
# Build WHERE clauses dynamically.
|
||||
wheres: list[str] = []
|
||||
params: list[Any] = []
|
||||
|
||||
since: int | None = None
|
||||
if range_ is not None:
|
||||
since: int = _since_unix(range_)
|
||||
wheres.append("timeofban >= ?")
|
||||
params.append(since)
|
||||
since = _since_unix(range_)
|
||||
|
||||
if jail is not None:
|
||||
wheres.append("jail = ?")
|
||||
params.append(jail)
|
||||
|
||||
if ip_filter is not None:
|
||||
wheres.append("ip LIKE ?")
|
||||
params.append(f"{ip_filter}%")
|
||||
|
||||
where_sql: str = ("WHERE " + " AND ".join(wheres)) if wheres else ""
|
||||
|
||||
db_path: str = await _get_fail2ban_db_path(socket_path)
|
||||
db_path: str = await get_fail2ban_db_path(socket_path)
|
||||
log.info(
|
||||
"history_service_list",
|
||||
db_path=db_path,
|
||||
@@ -115,32 +104,23 @@ async def list_history(
|
||||
page=page,
|
||||
)
|
||||
|
||||
async with aiosqlite.connect(f"file:{db_path}?mode=ro", uri=True) as f2b_db:
|
||||
f2b_db.row_factory = aiosqlite.Row
|
||||
|
||||
async with f2b_db.execute(
|
||||
f"SELECT COUNT(*) FROM bans {where_sql}", # noqa: S608
|
||||
params,
|
||||
) as cur:
|
||||
count_row = await cur.fetchone()
|
||||
total: int = int(count_row[0]) if count_row else 0
|
||||
|
||||
async with f2b_db.execute(
|
||||
f"SELECT jail, ip, timeofban, bancount, data " # noqa: S608
|
||||
f"FROM bans {where_sql} "
|
||||
"ORDER BY timeofban DESC "
|
||||
"LIMIT ? OFFSET ?",
|
||||
[*params, effective_page_size, offset],
|
||||
) as cur:
|
||||
rows = await cur.fetchall()
|
||||
rows, total = await fail2ban_db_repo.get_history_page(
|
||||
db_path=db_path,
|
||||
since=since,
|
||||
jail=jail,
|
||||
ip_filter=ip_filter,
|
||||
origin=origin,
|
||||
page=page,
|
||||
page_size=effective_page_size,
|
||||
)
|
||||
|
||||
items: list[HistoryBanItem] = []
|
||||
for row in rows:
|
||||
jail_name: str = str(row["jail"])
|
||||
ip: str = str(row["ip"])
|
||||
banned_at: str = _ts_to_iso(int(row["timeofban"]))
|
||||
ban_count: int = int(row["bancount"])
|
||||
matches, failures = _parse_data_json(row["data"])
|
||||
jail_name: str = row.jail
|
||||
ip: str = row.ip
|
||||
banned_at: str = ts_to_iso(row.timeofban)
|
||||
ban_count: int = row.bancount
|
||||
matches, failures = parse_data_json(row.data)
|
||||
|
||||
country_code: str | None = None
|
||||
country_name: str | None = None
|
||||
@@ -185,7 +165,7 @@ async def get_ip_detail(
|
||||
socket_path: str,
|
||||
ip: str,
|
||||
*,
|
||||
geo_enricher: Any | None = None,
|
||||
geo_enricher: GeoEnricher | None = None,
|
||||
) -> IpDetailResponse | None:
|
||||
"""Return the full historical record for a single IP address.
|
||||
|
||||
@@ -202,19 +182,10 @@ async def get_ip_detail(
|
||||
:class:`~app.models.history.IpDetailResponse` if any records exist
|
||||
for *ip*, or ``None`` if the IP has no history in the database.
|
||||
"""
|
||||
db_path: str = await _get_fail2ban_db_path(socket_path)
|
||||
db_path: str = await get_fail2ban_db_path(socket_path)
|
||||
log.info("history_service_ip_detail", db_path=db_path, ip=ip)
|
||||
|
||||
async with aiosqlite.connect(f"file:{db_path}?mode=ro", uri=True) as f2b_db:
|
||||
f2b_db.row_factory = aiosqlite.Row
|
||||
async with f2b_db.execute(
|
||||
"SELECT jail, ip, timeofban, bancount, data "
|
||||
"FROM bans "
|
||||
"WHERE ip = ? "
|
||||
"ORDER BY timeofban DESC",
|
||||
(ip,),
|
||||
) as cur:
|
||||
rows = await cur.fetchall()
|
||||
rows = await fail2ban_db_repo.get_history_for_ip(db_path=db_path, ip=ip)
|
||||
|
||||
if not rows:
|
||||
return None
|
||||
@@ -223,10 +194,10 @@ async def get_ip_detail(
|
||||
total_failures: int = 0
|
||||
|
||||
for row in rows:
|
||||
jail_name: str = str(row["jail"])
|
||||
banned_at: str = _ts_to_iso(int(row["timeofban"]))
|
||||
ban_count: int = int(row["bancount"])
|
||||
matches, failures = _parse_data_json(row["data"])
|
||||
jail_name: str = row.jail
|
||||
banned_at: str = ts_to_iso(row.timeofban)
|
||||
ban_count: int = row.bancount
|
||||
matches, failures = parse_data_json(row.data)
|
||||
total_failures += failures
|
||||
timeline.append(
|
||||
IpTimelineEvent(
|
||||
|
||||
993
backend/app/services/jail_config_service.py
Normal file
993
backend/app/services/jail_config_service.py
Normal file
@@ -0,0 +1,993 @@
|
||||
"""Jail configuration management for BanGUI.
|
||||
|
||||
Handles parsing, validation, and lifecycle operations (activate/deactivate)
|
||||
for fail2ban jail configurations. Provides functions to discover inactive
|
||||
jails, validate their configurations before activation, and manage jail
|
||||
overrides in jail.d/*.local files.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import configparser
|
||||
import contextlib
|
||||
import io
|
||||
import os
|
||||
import re
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import cast
|
||||
|
||||
import structlog
|
||||
|
||||
from app.exceptions import JailNotFoundError
|
||||
from app.models.config import (
|
||||
ActivateJailRequest,
|
||||
InactiveJail,
|
||||
InactiveJailListResponse,
|
||||
JailActivationResponse,
|
||||
JailValidationResult,
|
||||
RollbackResponse,
|
||||
)
|
||||
from app.utils.config_file_utils import (
|
||||
_build_inactive_jail,
|
||||
_get_active_jail_names,
|
||||
_parse_jails_sync,
|
||||
_validate_jail_config_sync,
|
||||
)
|
||||
from app.utils.fail2ban_client import Fail2BanClient
|
||||
from app.utils.jail_utils import reload_jails
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Constants
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_SOCKET_TIMEOUT: float = 10.0
|
||||
|
||||
# Allowlist pattern for jail names used in path construction.
|
||||
_SAFE_JAIL_NAME_RE: re.Pattern[str] = re.compile(r"^[A-Za-z0-9][A-Za-z0-9._-]{0,127}$")
|
||||
|
||||
# Sections that are not jail definitions.
|
||||
_META_SECTIONS: frozenset[str] = frozenset({"INCLUDES", "DEFAULT"})
|
||||
|
||||
# True-ish values for the ``enabled`` key.
|
||||
_TRUE_VALUES: frozenset[str] = frozenset({"true", "yes", "1"})
|
||||
|
||||
# False-ish values for the ``enabled`` key.
|
||||
_FALSE_VALUES: frozenset[str] = frozenset({"false", "no", "0"})
|
||||
|
||||
# Seconds to wait between fail2ban liveness probes after a reload.
|
||||
_POST_RELOAD_PROBE_INTERVAL: float = 2.0
|
||||
|
||||
# Maximum number of post-reload probe attempts (initial attempt + retries).
|
||||
_POST_RELOAD_MAX_ATTEMPTS: int = 4
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Custom exceptions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class JailNotFoundInConfigError(Exception):
|
||||
"""Raised when the requested jail name is not defined in any config file."""
|
||||
|
||||
def __init__(self, name: str) -> None:
|
||||
"""Initialise with the jail name that was not found.
|
||||
|
||||
Args:
|
||||
name: The jail name that could not be located.
|
||||
"""
|
||||
self.name: str = name
|
||||
super().__init__(f"Jail not found in config files: {name!r}")
|
||||
|
||||
|
||||
class JailAlreadyActiveError(Exception):
|
||||
"""Raised when trying to activate a jail that is already active."""
|
||||
|
||||
def __init__(self, name: str) -> None:
|
||||
"""Initialise with the jail name.
|
||||
|
||||
Args:
|
||||
name: The jail that is already active.
|
||||
"""
|
||||
self.name: str = name
|
||||
super().__init__(f"Jail is already active: {name!r}")
|
||||
|
||||
|
||||
class JailAlreadyInactiveError(Exception):
|
||||
"""Raised when trying to deactivate a jail that is already inactive."""
|
||||
|
||||
def __init__(self, name: str) -> None:
|
||||
"""Initialise with the jail name.
|
||||
|
||||
Args:
|
||||
name: The jail that is already inactive.
|
||||
"""
|
||||
self.name: str = name
|
||||
super().__init__(f"Jail is already inactive: {name!r}")
|
||||
|
||||
|
||||
class JailNameError(Exception):
|
||||
"""Raised when a jail name contains invalid characters."""
|
||||
|
||||
|
||||
class ConfigWriteError(Exception):
|
||||
"""Raised when writing a ``.local`` override file fails."""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Internal helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _safe_jail_name(name: str) -> str:
|
||||
"""Validate *name* and return it unchanged or raise :class:`JailNameError`.
|
||||
|
||||
Args:
|
||||
name: Proposed jail name.
|
||||
|
||||
Returns:
|
||||
The name unchanged if valid.
|
||||
|
||||
Raises:
|
||||
JailNameError: If *name* contains unsafe characters.
|
||||
"""
|
||||
if not _SAFE_JAIL_NAME_RE.match(name):
|
||||
raise JailNameError(
|
||||
f"Jail name {name!r} contains invalid characters. "
|
||||
"Only alphanumeric characters, hyphens, underscores, and dots are "
|
||||
"allowed; must start with an alphanumeric character."
|
||||
)
|
||||
return name
|
||||
|
||||
|
||||
def _build_parser() -> configparser.RawConfigParser:
|
||||
"""Create a :class:`configparser.RawConfigParser` for fail2ban configs.
|
||||
|
||||
Returns:
|
||||
Parser with interpolation disabled and case-sensitive option names.
|
||||
"""
|
||||
parser = configparser.RawConfigParser(interpolation=None, strict=False)
|
||||
# fail2ban keys are lowercase but preserve case to be safe.
|
||||
parser.optionxform = str # type: ignore[assignment]
|
||||
return parser
|
||||
|
||||
|
||||
def _is_truthy(value: str) -> bool:
|
||||
"""Return ``True`` if *value* is a fail2ban boolean true string.
|
||||
|
||||
Args:
|
||||
value: Raw string from config (e.g. ``"true"``, ``"yes"``, ``"1"``).
|
||||
|
||||
Returns:
|
||||
``True`` when the value represents enabled.
|
||||
"""
|
||||
return value.strip().lower() in _TRUE_VALUES
|
||||
|
||||
|
||||
def _write_local_override_sync(
|
||||
config_dir: Path,
|
||||
jail_name: str,
|
||||
enabled: bool,
|
||||
overrides: dict[str, object],
|
||||
) -> None:
|
||||
"""Write a ``jail.d/{name}.local`` file atomically.
|
||||
|
||||
Always writes to ``jail.d/{jail_name}.local``. If the file already
|
||||
exists it is replaced entirely. The write is atomic: content is
|
||||
written to a temp file first, then renamed into place.
|
||||
|
||||
Args:
|
||||
config_dir: The fail2ban configuration root directory.
|
||||
jail_name: Validated jail name (used as filename stem).
|
||||
enabled: Value to write for ``enabled =``.
|
||||
overrides: Optional setting overrides (bantime, findtime, maxretry,
|
||||
port, logpath).
|
||||
|
||||
Raises:
|
||||
ConfigWriteError: If writing fails.
|
||||
"""
|
||||
jail_d = config_dir / "jail.d"
|
||||
try:
|
||||
jail_d.mkdir(parents=True, exist_ok=True)
|
||||
except OSError as exc:
|
||||
raise ConfigWriteError(f"Cannot create jail.d directory: {exc}") from exc
|
||||
|
||||
local_path = jail_d / f"{jail_name}.local"
|
||||
|
||||
lines: list[str] = [
|
||||
"# Managed by BanGUI — do not edit manually",
|
||||
"",
|
||||
f"[{jail_name}]",
|
||||
"",
|
||||
f"enabled = {'true' if enabled else 'false'}",
|
||||
# Provide explicit banaction defaults so fail2ban can resolve the
|
||||
# %(banaction)s interpolation used in the built-in action_ chain.
|
||||
"banaction = iptables-multiport",
|
||||
"banaction_allports = iptables-allports",
|
||||
]
|
||||
|
||||
if overrides.get("bantime") is not None:
|
||||
lines.append(f"bantime = {overrides['bantime']}")
|
||||
if overrides.get("findtime") is not None:
|
||||
lines.append(f"findtime = {overrides['findtime']}")
|
||||
if overrides.get("maxretry") is not None:
|
||||
lines.append(f"maxretry = {overrides['maxretry']}")
|
||||
if overrides.get("port") is not None:
|
||||
lines.append(f"port = {overrides['port']}")
|
||||
if overrides.get("logpath"):
|
||||
paths: list[str] = cast("list[str]", overrides["logpath"])
|
||||
if paths:
|
||||
lines.append(f"logpath = {paths[0]}")
|
||||
for p in paths[1:]:
|
||||
lines.append(f" {p}")
|
||||
|
||||
content = "\n".join(lines) + "\n"
|
||||
|
||||
try:
|
||||
with tempfile.NamedTemporaryFile(
|
||||
mode="w",
|
||||
encoding="utf-8",
|
||||
dir=jail_d,
|
||||
delete=False,
|
||||
suffix=".tmp",
|
||||
) as tmp:
|
||||
tmp.write(content)
|
||||
tmp_name = tmp.name
|
||||
os.replace(tmp_name, local_path)
|
||||
except OSError as exc:
|
||||
# Clean up temp file if rename failed.
|
||||
with contextlib.suppress(OSError):
|
||||
os.unlink(tmp_name) # noqa: F821 — only reachable when tmp_name is set
|
||||
raise ConfigWriteError(f"Failed to write {local_path}: {exc}") from exc
|
||||
|
||||
log.info(
|
||||
"jail_local_written",
|
||||
jail=jail_name,
|
||||
path=str(local_path),
|
||||
enabled=enabled,
|
||||
)
|
||||
|
||||
|
||||
def _restore_local_file_sync(local_path: Path, original_content: bytes | None) -> None:
|
||||
"""Restore a ``.local`` file to its pre-activation state.
|
||||
|
||||
If *original_content* is ``None``, the file is deleted (it did not exist
|
||||
before the activation). Otherwise the original bytes are written back
|
||||
atomically via a temp-file rename.
|
||||
|
||||
Args:
|
||||
local_path: Absolute path to the ``.local`` file to restore.
|
||||
original_content: Original raw bytes to write back, or ``None`` to
|
||||
delete the file.
|
||||
|
||||
Raises:
|
||||
ConfigWriteError: If the write or delete operation fails.
|
||||
"""
|
||||
if original_content is None:
|
||||
try:
|
||||
local_path.unlink(missing_ok=True)
|
||||
except OSError as exc:
|
||||
raise ConfigWriteError(f"Failed to delete {local_path} during rollback: {exc}") from exc
|
||||
return
|
||||
|
||||
tmp_name: str | None = None
|
||||
try:
|
||||
with tempfile.NamedTemporaryFile(
|
||||
mode="wb",
|
||||
dir=local_path.parent,
|
||||
delete=False,
|
||||
suffix=".tmp",
|
||||
) as tmp:
|
||||
tmp.write(original_content)
|
||||
tmp_name = tmp.name
|
||||
os.replace(tmp_name, local_path)
|
||||
except OSError as exc:
|
||||
with contextlib.suppress(OSError):
|
||||
if tmp_name is not None:
|
||||
os.unlink(tmp_name)
|
||||
raise ConfigWriteError(f"Failed to restore {local_path} during rollback: {exc}") from exc
|
||||
|
||||
|
||||
def _validate_regex_patterns(patterns: list[str]) -> None:
|
||||
"""Validate each pattern in *patterns* using Python's ``re`` module.
|
||||
|
||||
Args:
|
||||
patterns: List of regex strings to validate.
|
||||
|
||||
Raises:
|
||||
FilterInvalidRegexError: If any pattern fails to compile.
|
||||
"""
|
||||
for pattern in patterns:
|
||||
try:
|
||||
re.compile(pattern)
|
||||
except re.error as exc:
|
||||
# Import here to avoid circular dependency
|
||||
from app.exceptions import FilterInvalidRegexError
|
||||
raise FilterInvalidRegexError(pattern, str(exc)) from exc
|
||||
|
||||
|
||||
def _set_jail_local_key_sync(
    config_dir: Path,
    jail_name: str,
    key: str,
    value: str,
) -> None:
    """Update ``jail.d/{jail_name}.local`` to set a single key in the jail section.

    If the ``.local`` file already exists it is read, the key is updated (or
    added), and the file is written back atomically without disturbing other
    settings. If the file does not exist a new one is created containing
    only the BanGUI header comment, the jail section, and the requested key.

    Args:
        config_dir: The fail2ban configuration root directory.
        jail_name: Validated jail name (used as section name and filename stem).
        key: Config key to set inside the jail section.
        value: Config value to assign.

    Raises:
        ConfigWriteError: If writing fails.
    """
    jail_d = config_dir / "jail.d"
    try:
        jail_d.mkdir(parents=True, exist_ok=True)
    except OSError as exc:
        raise ConfigWriteError(f"Cannot create jail.d directory: {exc}") from exc

    local_path = jail_d / f"{jail_name}.local"

    parser = _build_parser()
    if local_path.is_file():
        try:
            parser.read(str(local_path), encoding="utf-8")
        except (configparser.Error, OSError) as exc:
            # A corrupt/unreadable file is not fatal: fall through and
            # rewrite it from scratch with just the requested key.
            log.warning(
                "jail_local_read_for_update_error",
                jail=jail_name,
                error=str(exc),
            )

    if not parser.has_section(jail_name):
        parser.add_section(jail_name)
    parser.set(jail_name, key, value)

    # Serialize: write a BanGUI header then the parser output.
    buf = io.StringIO()
    buf.write("# Managed by BanGUI — do not edit manually\n\n")
    parser.write(buf)
    content = buf.getvalue()

    # Bug fix: initialise before the try block. If NamedTemporaryFile()
    # itself raises OSError, the previous code hit a NameError on the
    # unbound tmp_name inside the except handler (the suppressed
    # exception type was OSError, which does not cover NameError).
    tmp_name: str | None = None
    try:
        with tempfile.NamedTemporaryFile(
            mode="w",
            encoding="utf-8",
            dir=jail_d,
            delete=False,
            suffix=".tmp",
        ) as tmp:
            tmp.write(content)
            tmp_name = tmp.name
        # Atomic swap so readers never see a partially written file.
        os.replace(tmp_name, local_path)
    except OSError as exc:
        # Clean up the temp file if the write or rename failed.
        if tmp_name is not None:
            with contextlib.suppress(OSError):
                os.unlink(tmp_name)
        raise ConfigWriteError(f"Failed to write {local_path}: {exc}") from exc

    log.info(
        "jail_local_key_set",
        jail=jail_name,
        key=key,
        path=str(local_path),
    )
|
||||
|
||||
|
||||
async def _probe_fail2ban_running(socket_path: str) -> bool:
    """Return ``True`` if the fail2ban socket responds to a ping.

    Args:
        socket_path: Path to the fail2ban Unix domain socket.

    Returns:
        ``True`` when fail2ban is reachable, ``False`` otherwise.
    """
    try:
        probe_client = Fail2BanClient(socket_path=socket_path, timeout=5.0)
        reply = await probe_client.send(["ping"])
        # A healthy daemon answers with a (0, ...) sequence.
        is_ok = isinstance(reply, (list, tuple)) and reply[0] == 0
    except Exception:  # noqa: BLE001
        # Any failure (connect, timeout, protocol, bad shape) means "down".
        return False
    return is_ok
|
||||
|
||||
|
||||
async def wait_for_fail2ban(
    socket_path: str,
    max_wait_seconds: float = 10.0,
    poll_interval: float = 2.0,
) -> bool:
    """Poll the fail2ban socket until it responds or the timeout expires.

    Args:
        socket_path: Path to the fail2ban Unix domain socket.
        max_wait_seconds: Total time budget in seconds.
        poll_interval: Delay between probe attempts in seconds.

    Returns:
        ``True`` if fail2ban came online within the budget.
    """
    waited = 0.0
    while True:
        # Budget exhausted before this probe would start — give up.
        if waited >= max_wait_seconds:
            return False
        if await _probe_fail2ban_running(socket_path):
            return True
        await asyncio.sleep(poll_interval)
        waited += poll_interval
|
||||
|
||||
|
||||
async def start_daemon(start_cmd_parts: list[str]) -> bool:
    """Start the fail2ban daemon using *start_cmd_parts*.

    Uses :func:`asyncio.create_subprocess_exec` (no shell interpretation)
    to avoid command injection.

    Args:
        start_cmd_parts: Command and arguments, e.g.
            ``["fail2ban-client", "start"]``.

    Returns:
        ``True`` when the process exited with code 0.
    """
    if not start_cmd_parts:
        log.warning("fail2ban_start_cmd_empty")
        return False
    try:
        proc = await asyncio.create_subprocess_exec(
            *start_cmd_parts,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
        )
    except OSError as exc:
        # Spawning failed (missing binary, permissions, ...).
        log.warning("fail2ban_start_cmd_error", cmd=start_cmd_parts, error=str(exc))
        return False
    try:
        await asyncio.wait_for(proc.wait(), timeout=30.0)
    except (TimeoutError, OSError) as exc:
        # Bug fix: kill the child on timeout so it is not leaked as a
        # lingering/zombie process (the previous code returned without
        # reaping it).
        with contextlib.suppress(ProcessLookupError, OSError):
            proc.kill()
        log.warning("fail2ban_start_cmd_error", cmd=start_cmd_parts, error=str(exc))
        return False
    success = proc.returncode == 0
    if not success:
        log.warning(
            "fail2ban_start_cmd_nonzero",
            cmd=start_cmd_parts,
            returncode=proc.returncode,
        )
    return success
|
||||
|
||||
|
||||
# Shared functions from config_file_service are imported from app.utils.config_file_utils
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public API
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def list_inactive_jails(
    config_dir: str,
    socket_path: str,
) -> InactiveJailListResponse:
    """Return all jails defined in config files that are not currently active.

    Parses ``jail.conf``, ``jail.local``, and ``jail.d/`` following the
    fail2ban merge order. A jail is considered inactive when:

    - Its merged ``enabled`` value is ``false`` (or absent, which defaults to
      ``false`` in fail2ban), **or**
    - Its ``enabled`` value is ``true`` in config but fail2ban does not report
      it as running.

    Args:
        config_dir: Absolute path to the fail2ban configuration directory.
        socket_path: Path to the fail2ban Unix domain socket.

    Returns:
        :class:`~app.models.config.InactiveJailListResponse` with all
        inactive jails.
    """
    # get_running_loop() is the documented call inside a coroutine;
    # get_event_loop() is deprecated in that context.
    loop = asyncio.get_running_loop()
    parsed_result: tuple[dict[str, dict[str, str]], dict[str, str]] = await loop.run_in_executor(
        None, _parse_jails_sync, Path(config_dir)
    )
    all_jails, source_files = parsed_result
    active_names: set[str] = await _get_active_jail_names(socket_path)

    inactive: list[InactiveJail] = []
    for jail_name, settings in sorted(all_jails.items()):
        if jail_name in active_names:
            # fail2ban reports this jail as running — skip it.
            continue

        source = source_files.get(jail_name, config_dir)
        inactive.append(_build_inactive_jail(jail_name, settings, source, Path(config_dir)))

    log.info(
        "inactive_jails_listed",
        total_defined=len(all_jails),
        active=len(active_names),
        inactive=len(inactive),
    )
    return InactiveJailListResponse(jails=inactive, total=len(inactive))
|
||||
|
||||
|
||||
async def activate_jail(
    config_dir: str,
    socket_path: str,
    name: str,
    req: ActivateJailRequest,
) -> JailActivationResponse:
    """Enable an inactive jail and reload fail2ban.

    Performs pre-activation validation, writes ``enabled = true`` (plus any
    override values from *req*) to ``jail.d/{name}.local``, and triggers a
    full fail2ban reload. After the reload a multi-attempt health probe
    determines whether fail2ban (and the specific jail) are still running.

    Args:
        config_dir: Absolute path to the fail2ban configuration directory.
        socket_path: Path to the fail2ban Unix domain socket.
        name: Name of the jail to activate. Must exist in the parsed config.
        req: Optional override values to write alongside ``enabled = true``.

    Returns:
        :class:`~app.models.config.JailActivationResponse` including
        ``fail2ban_running`` and ``validation_warnings`` fields.

    Raises:
        JailNameError: If *name* contains invalid characters.
        JailNotFoundInConfigError: If *name* is not defined in any config file.
        JailAlreadyActiveError: If fail2ban already reports *name* as running.
        ConfigWriteError: If writing the ``.local`` file fails.
        ~app.utils.fail2ban_client.Fail2BanConnectionError: If the fail2ban
            socket is unreachable during reload.
    """
    _safe_jail_name(name)

    # get_running_loop() is the documented call inside a coroutine;
    # get_event_loop() is deprecated in that context.
    loop = asyncio.get_running_loop()
    all_jails, _source_files = await loop.run_in_executor(None, _parse_jails_sync, Path(config_dir))

    if name not in all_jails:
        raise JailNotFoundInConfigError(name)

    active_names = await _get_active_jail_names(socket_path)
    if name in active_names:
        raise JailAlreadyActiveError(name)

    # ---------------------------------------------------------------------- #
    # Pre-activation validation — collect warnings but do not block          #
    # ---------------------------------------------------------------------- #
    validation_result: JailValidationResult = await loop.run_in_executor(
        None, _validate_jail_config_sync, Path(config_dir), name
    )
    warnings: list[str] = [f"{i.field}: {i.message}" for i in validation_result.issues]
    if warnings:
        log.warning(
            "jail_activation_validation_warnings",
            jail=name,
            warnings=warnings,
        )

    # Block activation on critical validation failures (missing filter or logpath).
    blocking = [i for i in validation_result.issues if i.field in ("filter", "logpath")]
    if blocking:
        log.warning(
            "jail_activation_blocked",
            jail=name,
            issues=[f"{i.field}: {i.message}" for i in blocking],
        )
        return JailActivationResponse(
            name=name,
            active=False,
            fail2ban_running=True,
            validation_warnings=warnings,
            message=(f"Jail {name!r} cannot be activated: " + "; ".join(i.message for i in blocking)),
        )

    overrides: dict[str, object] = {
        "bantime": req.bantime,
        "findtime": req.findtime,
        "maxretry": req.maxretry,
        "port": req.port,
        "logpath": req.logpath,
    }

    # ---------------------------------------------------------------------- #
    # Backup the existing .local file (if any) before overwriting it so that #
    # we can restore it if activation fails.                                 #
    # ---------------------------------------------------------------------- #
    local_path = Path(config_dir) / "jail.d" / f"{name}.local"
    original_content: bytes | None = await loop.run_in_executor(
        None,
        lambda: local_path.read_bytes() if local_path.exists() else None,
    )

    await loop.run_in_executor(
        None,
        _write_local_override_sync,
        Path(config_dir),
        name,
        True,
        overrides,
    )

    # ---------------------------------------------------------------------- #
    # Activation reload — if it fails, roll back immediately                 #
    # ---------------------------------------------------------------------- #
    try:
        await reload_jails(socket_path, include_jails=[name])
    except JailNotFoundError as exc:
        # Jail configuration is invalid (e.g. missing logpath that prevents
        # fail2ban from loading the jail). Roll back and provide a specific error.
        log.warning(
            "reload_after_activate_failed_jail_not_found",
            jail=name,
            error=str(exc),
        )
        recovered = await _rollback_activation_async(config_dir, name, socket_path, original_content)
        return JailActivationResponse(
            name=name,
            active=False,
            fail2ban_running=False,
            recovered=recovered,
            validation_warnings=warnings,
            message=(
                f"Jail {name!r} activation failed: {str(exc)}. "
                "Check that all logpath files exist and are readable. "
                "The configuration was "
                + ("automatically recovered." if recovered else "not recovered — manual intervention is required.")
            ),
        )
    except Exception as exc:  # noqa: BLE001
        log.warning("reload_after_activate_failed", jail=name, error=str(exc))
        recovered = await _rollback_activation_async(config_dir, name, socket_path, original_content)
        return JailActivationResponse(
            name=name,
            active=False,
            fail2ban_running=False,
            recovered=recovered,
            validation_warnings=warnings,
            message=(
                f"Jail {name!r} activation failed during reload and the "
                "configuration was "
                + ("automatically recovered." if recovered else "not recovered — manual intervention is required.")
            ),
        )

    # ---------------------------------------------------------------------- #
    # Post-reload health probe with retries                                  #
    # ---------------------------------------------------------------------- #
    fail2ban_running = False
    for attempt in range(_POST_RELOAD_MAX_ATTEMPTS):
        if attempt > 0:
            await asyncio.sleep(_POST_RELOAD_PROBE_INTERVAL)
        if await _probe_fail2ban_running(socket_path):
            fail2ban_running = True
            break

    if not fail2ban_running:
        log.warning(
            "fail2ban_down_after_activate",
            jail=name,
            message="fail2ban socket unreachable after reload — initiating rollback.",
        )
        recovered = await _rollback_activation_async(config_dir, name, socket_path, original_content)
        return JailActivationResponse(
            name=name,
            active=False,
            fail2ban_running=False,
            recovered=recovered,
            validation_warnings=warnings,
            message=(
                f"Jail {name!r} activation failed: fail2ban stopped responding "
                "after reload. The configuration was "
                + ("automatically recovered." if recovered else "not recovered — manual intervention is required.")
            ),
        )

    # Verify the jail actually started (config error may prevent it silently).
    post_reload_names = await _get_active_jail_names(socket_path)
    actually_running = name in post_reload_names
    if not actually_running:
        log.warning(
            "jail_activation_unverified",
            jail=name,
            message="Jail did not appear in running jails — initiating rollback.",
        )
        recovered = await _rollback_activation_async(config_dir, name, socket_path, original_content)
        return JailActivationResponse(
            name=name,
            active=False,
            fail2ban_running=True,
            recovered=recovered,
            validation_warnings=warnings,
            message=(
                f"Jail {name!r} was written to config but did not start after "
                "reload. The configuration was "
                + ("automatically recovered." if recovered else "not recovered — manual intervention is required.")
            ),
        )

    log.info("jail_activated", jail=name)
    return JailActivationResponse(
        name=name,
        active=True,
        fail2ban_running=True,
        validation_warnings=warnings,
        message=f"Jail {name!r} activated successfully.",
    )
|
||||
|
||||
|
||||
async def _rollback_activation_async(
    config_dir: str,
    name: str,
    socket_path: str,
    original_content: bytes | None,
) -> bool:
    """Restore the pre-activation ``.local`` file and reload fail2ban.

    Called internally by :func:`activate_jail` when the activation fails after
    the config file was already written. Tries to:

    1. Restore the original file content (or delete the file if it was newly
       created by the activation attempt).
    2. Reload fail2ban so the daemon runs with the restored configuration.
    3. Probe fail2ban to confirm it came back up.

    Args:
        config_dir: Absolute path to the fail2ban configuration directory.
        name: Name of the jail whose ``.local`` file should be restored.
        socket_path: Path to the fail2ban Unix domain socket.
        original_content: Raw bytes of the original ``.local`` file, or
            ``None`` if the file did not exist before the activation.

    Returns:
        ``True`` if fail2ban is responsive again after the rollback, ``False``
        if recovery also failed.
    """
    # get_running_loop() is the documented call inside a coroutine;
    # get_event_loop() is deprecated in that context.
    loop = asyncio.get_running_loop()
    local_path = Path(config_dir) / "jail.d" / f"{name}.local"

    # Step 1 — restore original file (or delete it).
    try:
        await loop.run_in_executor(None, _restore_local_file_sync, local_path, original_content)
        log.info("jail_activation_rollback_file_restored", jail=name)
    except ConfigWriteError as exc:
        log.error("jail_activation_rollback_restore_failed", jail=name, error=str(exc))
        return False

    # Step 2 — reload fail2ban with the restored config.
    try:
        await reload_jails(socket_path)
        log.info("jail_activation_rollback_reload_ok", jail=name)
    except Exception as exc:  # noqa: BLE001
        log.warning("jail_activation_rollback_reload_failed", jail=name, error=str(exc))
        return False

    # Step 3 — wait for fail2ban to come back.
    for attempt in range(_POST_RELOAD_MAX_ATTEMPTS):
        if attempt > 0:
            await asyncio.sleep(_POST_RELOAD_PROBE_INTERVAL)
        if await _probe_fail2ban_running(socket_path):
            log.info("jail_activation_rollback_recovered", jail=name)
            return True

    log.warning("jail_activation_rollback_still_down", jail=name)
    return False
|
||||
|
||||
|
||||
async def deactivate_jail(
    config_dir: str,
    socket_path: str,
    name: str,
) -> JailActivationResponse:
    """Disable an active jail and reload fail2ban.

    Writes ``enabled = false`` to ``jail.d/{name}.local`` and triggers a
    full fail2ban reload so the jail stops immediately.

    Args:
        config_dir: Absolute path to the fail2ban configuration directory.
        socket_path: Path to the fail2ban Unix domain socket.
        name: Name of the jail to deactivate. Must exist in the parsed config.

    Returns:
        :class:`~app.models.config.JailActivationResponse`.

    Raises:
        JailNameError: If *name* contains invalid characters.
        JailNotFoundInConfigError: If *name* is not defined in any config file.
        JailAlreadyInactiveError: If fail2ban already reports *name* as not
            running.
        ConfigWriteError: If writing the ``.local`` file fails.
        ~app.utils.fail2ban_client.Fail2BanConnectionError: If the fail2ban
            socket is unreachable during reload.
    """
    _safe_jail_name(name)

    # get_running_loop() is the documented call inside a coroutine;
    # get_event_loop() is deprecated in that context.
    loop = asyncio.get_running_loop()
    all_jails, _source_files = await loop.run_in_executor(None, _parse_jails_sync, Path(config_dir))

    if name not in all_jails:
        raise JailNotFoundInConfigError(name)

    active_names = await _get_active_jail_names(socket_path)
    if name not in active_names:
        raise JailAlreadyInactiveError(name)

    await loop.run_in_executor(
        None,
        _write_local_override_sync,
        Path(config_dir),
        name,
        False,
        {},
    )

    try:
        await reload_jails(socket_path, exclude_jails=[name])
    except Exception as exc:  # noqa: BLE001
        # Best-effort: the override file is already written, so the jail
        # stays disabled on the next successful reload even if this fails.
        log.warning("reload_after_deactivate_failed", jail=name, error=str(exc))

    log.info("jail_deactivated", jail=name)
    return JailActivationResponse(
        name=name,
        active=False,
        message=f"Jail {name!r} deactivated successfully.",
    )
|
||||
|
||||
|
||||
async def delete_jail_local_override(
    config_dir: str,
    socket_path: str,
    name: str,
) -> None:
    """Delete the ``jail.d/{name}.local`` override file for an inactive jail.

    This is the clean-up action shown in the config UI when an inactive jail
    still has a ``.local`` override file (e.g. ``enabled = false``). The
    file is deleted outright; no fail2ban reload is required because the jail
    is already inactive.

    Args:
        config_dir: Absolute path to the fail2ban configuration directory.
        socket_path: Path to the fail2ban Unix domain socket.
        name: Name of the jail whose ``.local`` file should be removed.

    Raises:
        JailNameError: If *name* contains invalid characters.
        JailNotFoundInConfigError: If *name* is not defined in any config file.
        JailAlreadyActiveError: If the jail is currently active (refusing to
            delete the live config file).
        ConfigWriteError: If the file cannot be deleted.
    """
    _safe_jail_name(name)

    # get_running_loop() is the documented call inside a coroutine;
    # get_event_loop() is deprecated in that context.
    loop = asyncio.get_running_loop()
    all_jails, _source_files = await loop.run_in_executor(None, _parse_jails_sync, Path(config_dir))

    if name not in all_jails:
        raise JailNotFoundInConfigError(name)

    active_names = await _get_active_jail_names(socket_path)
    if name in active_names:
        raise JailAlreadyActiveError(name)

    local_path = Path(config_dir) / "jail.d" / f"{name}.local"
    try:
        await loop.run_in_executor(None, lambda: local_path.unlink(missing_ok=True))
    except OSError as exc:
        raise ConfigWriteError(f"Failed to delete {local_path}: {exc}") from exc

    log.info("jail_local_override_deleted", jail=name, path=str(local_path))
|
||||
|
||||
|
||||
async def validate_jail_config(
    config_dir: str,
    name: str,
) -> JailValidationResult:
    """Run pre-activation validation checks on a jail configuration.

    Validates that referenced filter and action files exist in ``filter.d/``
    and ``action.d/``, that all regex patterns compile, and that declared log
    paths exist on disk.

    Args:
        config_dir: Absolute path to the fail2ban configuration directory.
        name: Name of the jail to validate.

    Returns:
        :class:`~app.models.config.JailValidationResult` with any issues found.

    Raises:
        JailNameError: If *name* contains invalid characters.
    """
    _safe_jail_name(name)
    # get_running_loop() is the documented call inside a coroutine;
    # get_event_loop() is deprecated in that context.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(
        None,
        _validate_jail_config_sync,
        Path(config_dir),
        name,
    )
|
||||
|
||||
|
||||
async def rollback_jail(
    config_dir: str,
    socket_path: str,
    name: str,
    start_cmd_parts: list[str],
) -> RollbackResponse:
    """Disable a bad jail config and restart the fail2ban daemon.

    Writes ``enabled = false`` to ``jail.d/{name}.local`` (works even when
    fail2ban is down — only a file write), then attempts to start the daemon
    with *start_cmd_parts*. Waits up to 10 seconds for the socket to respond.

    Args:
        config_dir: Absolute path to the fail2ban configuration directory.
        socket_path: Path to the fail2ban Unix domain socket.
        name: Name of the jail to disable.
        start_cmd_parts: Argument list for the daemon start command, e.g.
            ``["fail2ban-client", "start"]``.

    Returns:
        :class:`~app.models.config.RollbackResponse`.

    Raises:
        JailNameError: If *name* contains invalid characters.
        ConfigWriteError: If writing the ``.local`` file fails.
    """
    _safe_jail_name(name)

    # get_running_loop() is the documented call inside a coroutine;
    # get_event_loop() is deprecated in that context.
    loop = asyncio.get_running_loop()

    # Write enabled=false — this must succeed even when fail2ban is down.
    await loop.run_in_executor(
        None,
        _write_local_override_sync,
        Path(config_dir),
        name,
        False,
        {},
    )
    log.info("jail_rolled_back_disabled", jail=name)

    # Attempt to start the daemon.
    started = await start_daemon(start_cmd_parts)
    log.info("jail_rollback_start_attempted", jail=name, start_ok=started)

    # Wait for the socket to come back.
    fail2ban_running = await wait_for_fail2ban(socket_path, max_wait_seconds=10.0, poll_interval=2.0)

    # Merged the two identical `if fail2ban_running:` branches from the
    # previous version into one success path.
    if fail2ban_running:
        names = await _get_active_jail_names(socket_path)
        active_jails = len(names)
        log.info("jail_rollback_success", jail=name, active_jails=active_jails)
        return RollbackResponse(
            jail_name=name,
            disabled=True,
            fail2ban_running=True,
            active_jails=active_jails,
            message=(f"Jail {name!r} disabled and fail2ban restarted successfully with {active_jails} active jail(s)."),
        )

    log.warning("jail_rollback_fail2ban_still_down", jail=name)
    return RollbackResponse(
        jail_name=name,
        disabled=True,
        fail2ban_running=False,
        active_jails=0,
        message=(
            f"Jail {name!r} was disabled but fail2ban did not come back online. "
            "Check the fail2ban log for additional errors."
        ),
    )
|
||||
@@ -14,10 +14,11 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
import contextlib
|
||||
import ipaddress
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING, TypedDict, cast
|
||||
|
||||
import structlog
|
||||
|
||||
from app.exceptions import JailNotFoundError, JailOperationError
|
||||
from app.models.ban import ActiveBan, ActiveBanListResponse, JailBannedIpsResponse
|
||||
from app.models.config import BantimeEscalation
|
||||
from app.models.jail import (
|
||||
@@ -27,10 +28,36 @@ from app.models.jail import (
|
||||
JailStatus,
|
||||
JailSummary,
|
||||
)
|
||||
from app.utils.fail2ban_client import Fail2BanClient, Fail2BanConnectionError
|
||||
from app.utils.fail2ban_client import (
|
||||
Fail2BanClient,
|
||||
Fail2BanCommand,
|
||||
Fail2BanConnectionError,
|
||||
Fail2BanResponse,
|
||||
Fail2BanToken,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Awaitable
|
||||
|
||||
import aiohttp
|
||||
import aiosqlite
|
||||
|
||||
from app.models.geo import GeoBatchLookup, GeoEnricher, GeoInfo
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
|
||||
class IpLookupResult(TypedDict):
|
||||
"""Result returned by :func:`lookup_ip`.
|
||||
|
||||
This is intentionally a :class:`TypedDict` to provide precise typing for
|
||||
callers (e.g. routers) while keeping the implementation flexible.
|
||||
"""
|
||||
|
||||
ip: str
|
||||
currently_banned_in: list[str]
|
||||
geo: GeoInfo | None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Constants
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -55,29 +82,12 @@ _backend_cmd_lock: asyncio.Lock = asyncio.Lock()
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class JailNotFoundError(Exception):
|
||||
"""Raised when a requested jail name does not exist in fail2ban."""
|
||||
|
||||
def __init__(self, name: str) -> None:
|
||||
"""Initialise with the jail name that was not found.
|
||||
|
||||
Args:
|
||||
name: The jail name that could not be located.
|
||||
"""
|
||||
self.name: str = name
|
||||
super().__init__(f"Jail not found: {name!r}")
|
||||
|
||||
|
||||
class JailOperationError(Exception):
|
||||
"""Raised when a jail control command fails for a non-auth reason."""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Internal helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _ok(response: Any) -> Any:
|
||||
def _ok(response: object) -> object:
|
||||
"""Extract the payload from a fail2ban ``(return_code, data)`` response.
|
||||
|
||||
Args:
|
||||
@@ -90,7 +100,7 @@ def _ok(response: Any) -> Any:
|
||||
ValueError: If the response indicates an error (return code ≠ 0).
|
||||
"""
|
||||
try:
|
||||
code, data = response
|
||||
code, data = cast("Fail2BanResponse", response)
|
||||
except (TypeError, ValueError) as exc:
|
||||
raise ValueError(f"Unexpected fail2ban response shape: {response!r}") from exc
|
||||
|
||||
@@ -100,7 +110,7 @@ def _ok(response: Any) -> Any:
|
||||
return data
|
||||
|
||||
|
||||
def _to_dict(pairs: Any) -> dict[str, Any]:
|
||||
def _to_dict(pairs: object) -> dict[str, object]:
|
||||
"""Convert a list of ``(key, value)`` pairs to a plain dict.
|
||||
|
||||
Args:
|
||||
@@ -111,7 +121,7 @@ def _to_dict(pairs: Any) -> dict[str, Any]:
|
||||
"""
|
||||
if not isinstance(pairs, (list, tuple)):
|
||||
return {}
|
||||
result: dict[str, Any] = {}
|
||||
result: dict[str, object] = {}
|
||||
for item in pairs:
|
||||
try:
|
||||
k, v = item
|
||||
@@ -121,7 +131,7 @@ def _to_dict(pairs: Any) -> dict[str, Any]:
|
||||
return result
|
||||
|
||||
|
||||
def _ensure_list(value: Any) -> list[str]:
|
||||
def _ensure_list(value: object | None) -> list[str]:
|
||||
"""Coerce a fail2ban response value to a list of strings.
|
||||
|
||||
Some fail2ban ``get`` responses return ``None`` or a single string
|
||||
@@ -170,9 +180,9 @@ def _is_not_found_error(exc: Exception) -> bool:
|
||||
|
||||
async def _safe_get(
|
||||
client: Fail2BanClient,
|
||||
command: list[Any],
|
||||
default: Any = None,
|
||||
) -> Any:
|
||||
command: Fail2BanCommand,
|
||||
default: object | None = None,
|
||||
) -> object | None:
|
||||
"""Send a ``get`` command and return ``default`` on error.
|
||||
|
||||
Errors during optional detail queries (logpath, regex, etc.) should
|
||||
@@ -187,7 +197,8 @@ async def _safe_get(
|
||||
The response payload, or *default* on any error.
|
||||
"""
|
||||
try:
|
||||
return _ok(await client.send(command))
|
||||
response = await client.send(command)
|
||||
return _ok(cast("Fail2BanResponse", response))
|
||||
except (ValueError, TypeError, Exception):
|
||||
return default
|
||||
|
||||
@@ -309,7 +320,7 @@ async def _fetch_jail_summary(
|
||||
backend_cmd_is_supported = await _check_backend_cmd_supported(client, name)
|
||||
|
||||
# Build the gather list based on command support.
|
||||
gather_list: list[Any] = [
|
||||
gather_list: list[Awaitable[object]] = [
|
||||
client.send(["status", name, "short"]),
|
||||
client.send(["get", name, "bantime"]),
|
||||
client.send(["get", name, "findtime"]),
|
||||
@@ -322,25 +333,23 @@ async def _fetch_jail_summary(
|
||||
client.send(["get", name, "backend"]),
|
||||
client.send(["get", name, "idle"]),
|
||||
])
|
||||
uses_backend_backend_commands = True
|
||||
else:
|
||||
# Commands not supported; return default values without sending.
|
||||
async def _return_default(value: Any) -> tuple[int, Any]:
|
||||
async def _return_default(value: object | None) -> Fail2BanResponse:
|
||||
return (0, value)
|
||||
|
||||
gather_list.extend([
|
||||
_return_default("polling"), # backend default
|
||||
_return_default(False), # idle default
|
||||
])
|
||||
uses_backend_backend_commands = False
|
||||
|
||||
_r = await asyncio.gather(*gather_list, return_exceptions=True)
|
||||
status_raw: Any = _r[0]
|
||||
bantime_raw: Any = _r[1]
|
||||
findtime_raw: Any = _r[2]
|
||||
maxretry_raw: Any = _r[3]
|
||||
backend_raw: Any = _r[4]
|
||||
idle_raw: Any = _r[5]
|
||||
status_raw: object | Exception = _r[0]
|
||||
bantime_raw: object | Exception = _r[1]
|
||||
findtime_raw: object | Exception = _r[2]
|
||||
maxretry_raw: object | Exception = _r[3]
|
||||
backend_raw: object | Exception = _r[4]
|
||||
idle_raw: object | Exception = _r[5]
|
||||
|
||||
# Parse jail status (filter + actions).
|
||||
jail_status: JailStatus | None = None
|
||||
@@ -350,35 +359,35 @@ async def _fetch_jail_summary(
|
||||
filter_stats = _to_dict(raw.get("Filter") or [])
|
||||
action_stats = _to_dict(raw.get("Actions") or [])
|
||||
jail_status = JailStatus(
|
||||
currently_banned=int(action_stats.get("Currently banned", 0) or 0),
|
||||
total_banned=int(action_stats.get("Total banned", 0) or 0),
|
||||
currently_failed=int(filter_stats.get("Currently failed", 0) or 0),
|
||||
total_failed=int(filter_stats.get("Total failed", 0) or 0),
|
||||
currently_banned=int(str(action_stats.get("Currently banned", 0) or 0)),
|
||||
total_banned=int(str(action_stats.get("Total banned", 0) or 0)),
|
||||
currently_failed=int(str(filter_stats.get("Currently failed", 0) or 0)),
|
||||
total_failed=int(str(filter_stats.get("Total failed", 0) or 0)),
|
||||
)
|
||||
except (ValueError, TypeError) as exc:
|
||||
log.warning("jail_status_parse_error", jail=name, error=str(exc))
|
||||
|
||||
def _safe_int(raw: Any, fallback: int) -> int:
|
||||
def _safe_int(raw: object | Exception, fallback: int) -> int:
|
||||
if isinstance(raw, Exception):
|
||||
return fallback
|
||||
try:
|
||||
return int(_ok(raw))
|
||||
return int(str(_ok(cast("Fail2BanResponse", raw))))
|
||||
except (ValueError, TypeError):
|
||||
return fallback
|
||||
|
||||
def _safe_str(raw: Any, fallback: str) -> str:
|
||||
def _safe_str(raw: object | Exception, fallback: str) -> str:
|
||||
if isinstance(raw, Exception):
|
||||
return fallback
|
||||
try:
|
||||
return str(_ok(raw))
|
||||
return str(_ok(cast("Fail2BanResponse", raw)))
|
||||
except (ValueError, TypeError):
|
||||
return fallback
|
||||
|
||||
def _safe_bool(raw: Any, fallback: bool = False) -> bool:
|
||||
def _safe_bool(raw: object | Exception, fallback: bool = False) -> bool:
|
||||
if isinstance(raw, Exception):
|
||||
return fallback
|
||||
try:
|
||||
return bool(_ok(raw))
|
||||
return bool(_ok(cast("Fail2BanResponse", raw)))
|
||||
except (ValueError, TypeError):
|
||||
return fallback
|
||||
|
||||
@@ -428,10 +437,10 @@ async def get_jail(socket_path: str, name: str) -> JailDetailResponse:
|
||||
action_stats = _to_dict(raw.get("Actions") or [])
|
||||
|
||||
jail_status = JailStatus(
|
||||
currently_banned=int(action_stats.get("Currently banned", 0) or 0),
|
||||
total_banned=int(action_stats.get("Total banned", 0) or 0),
|
||||
currently_failed=int(filter_stats.get("Currently failed", 0) or 0),
|
||||
total_failed=int(filter_stats.get("Total failed", 0) or 0),
|
||||
currently_banned=int(str(action_stats.get("Currently banned", 0) or 0)),
|
||||
total_banned=int(str(action_stats.get("Total banned", 0) or 0)),
|
||||
currently_failed=int(str(filter_stats.get("Currently failed", 0) or 0)),
|
||||
total_failed=int(str(filter_stats.get("Total failed", 0) or 0)),
|
||||
)
|
||||
|
||||
# Fetch all detail fields in parallel.
|
||||
@@ -480,11 +489,11 @@ async def get_jail(socket_path: str, name: str) -> JailDetailResponse:
|
||||
bt_increment: bool = bool(bt_increment_raw)
|
||||
bantime_escalation = BantimeEscalation(
|
||||
increment=bt_increment,
|
||||
factor=float(bt_factor_raw) if bt_factor_raw is not None else None,
|
||||
factor=float(str(bt_factor_raw)) if bt_factor_raw is not None else None,
|
||||
formula=str(bt_formula_raw) if bt_formula_raw else None,
|
||||
multipliers=str(bt_multipliers_raw) if bt_multipliers_raw else None,
|
||||
max_time=int(bt_maxtime_raw) if bt_maxtime_raw is not None else None,
|
||||
rnd_time=int(bt_rndtime_raw) if bt_rndtime_raw is not None else None,
|
||||
max_time=int(str(bt_maxtime_raw)) if bt_maxtime_raw is not None else None,
|
||||
rnd_time=int(str(bt_rndtime_raw)) if bt_rndtime_raw is not None else None,
|
||||
overall_jails=bool(bt_overalljails_raw),
|
||||
)
|
||||
|
||||
@@ -500,9 +509,9 @@ async def get_jail(socket_path: str, name: str) -> JailDetailResponse:
|
||||
ignore_ips=_ensure_list(ignoreip_raw),
|
||||
date_pattern=str(datepattern_raw) if datepattern_raw else None,
|
||||
log_encoding=str(logencoding_raw or "UTF-8"),
|
||||
find_time=int(findtime_raw or 600),
|
||||
ban_time=int(bantime_raw or 600),
|
||||
max_retry=int(maxretry_raw or 5),
|
||||
find_time=int(str(findtime_raw or 600)),
|
||||
ban_time=int(str(bantime_raw or 600)),
|
||||
max_retry=int(str(maxretry_raw or 5)),
|
||||
bantime_escalation=bantime_escalation,
|
||||
status=jail_status,
|
||||
actions=_ensure_list(actions_raw),
|
||||
@@ -671,8 +680,8 @@ async def reload_all(
|
||||
if exclude_jails:
|
||||
names_set -= set(exclude_jails)
|
||||
|
||||
stream: list[list[str]] = [["start", n] for n in sorted(names_set)]
|
||||
_ok(await client.send(["reload", "--all", [], stream]))
|
||||
stream: list[list[object]] = [["start", n] for n in sorted(names_set)]
|
||||
_ok(await client.send(["reload", "--all", [], cast("Fail2BanToken", stream)]))
|
||||
log.info("all_jails_reloaded")
|
||||
except ValueError as exc:
|
||||
# Detect UnknownJailException (missing or invalid jail configuration)
|
||||
@@ -685,24 +694,29 @@ async def reload_all(
|
||||
|
||||
|
||||
async def restart(socket_path: str) -> None:
|
||||
"""Restart the fail2ban service (daemon).
|
||||
"""Stop the fail2ban daemon via the Unix socket.
|
||||
|
||||
Sends the 'restart' command to the fail2ban daemon via the Unix socket.
|
||||
All jails are stopped and the daemon is restarted, re-reading all
|
||||
configuration from scratch.
|
||||
Sends ``["stop"]`` to the fail2ban daemon, which calls ``server.quit()``
|
||||
on the daemon side and tears down all jails. The caller is responsible
|
||||
for starting the daemon again (e.g. via ``fail2ban-client start``).
|
||||
|
||||
Note:
|
||||
``["restart"]`` is a *client-side* orchestration command that is not
|
||||
handled by the fail2ban server transmitter — sending it to the socket
|
||||
raises ``"Invalid command"`` in the daemon.
|
||||
|
||||
Args:
|
||||
socket_path: Path to the fail2ban Unix domain socket.
|
||||
|
||||
Raises:
|
||||
JailOperationError: If fail2ban reports the operation failed.
|
||||
JailOperationError: If fail2ban reports the stop command failed.
|
||||
~app.utils.fail2ban_client.Fail2BanConnectionError: If the socket
|
||||
cannot be reached.
|
||||
"""
|
||||
client = Fail2BanClient(socket_path=socket_path, timeout=_SOCKET_TIMEOUT)
|
||||
try:
|
||||
_ok(await client.send(["restart"]))
|
||||
log.info("fail2ban_restarted")
|
||||
_ok(await client.send(["stop"]))
|
||||
log.info("fail2ban_stopped_for_restart")
|
||||
except ValueError as exc:
|
||||
raise JailOperationError(str(exc)) from exc
|
||||
|
||||
@@ -790,9 +804,10 @@ async def unban_ip(
|
||||
|
||||
async def get_active_bans(
|
||||
socket_path: str,
|
||||
geo_enricher: Any | None = None,
|
||||
http_session: Any | None = None,
|
||||
app_db: Any | None = None,
|
||||
geo_batch_lookup: GeoBatchLookup | None = None,
|
||||
geo_enricher: GeoEnricher | None = None,
|
||||
http_session: aiohttp.ClientSession | None = None,
|
||||
app_db: aiosqlite.Connection | None = None,
|
||||
) -> ActiveBanListResponse:
|
||||
"""Return all currently banned IPs across every jail.
|
||||
|
||||
@@ -827,7 +842,6 @@ async def get_active_bans(
|
||||
~app.utils.fail2ban_client.Fail2BanConnectionError: If the socket
|
||||
cannot be reached.
|
||||
"""
|
||||
from app.services import geo_service # noqa: PLC0415
|
||||
|
||||
client = Fail2BanClient(socket_path=socket_path, timeout=_SOCKET_TIMEOUT)
|
||||
|
||||
@@ -844,7 +858,7 @@ async def get_active_bans(
|
||||
return ActiveBanListResponse(bans=[], total=0)
|
||||
|
||||
# For each jail, fetch the ban list with time info in parallel.
|
||||
results: list[Any] = await asyncio.gather(
|
||||
results: list[object | Exception] = await asyncio.gather(
|
||||
*[client.send(["get", jn, "banip", "--with-time"]) for jn in jail_names],
|
||||
return_exceptions=True,
|
||||
)
|
||||
@@ -860,7 +874,7 @@ async def get_active_bans(
|
||||
continue
|
||||
|
||||
try:
|
||||
ban_list: list[str] = _ok(raw_result) or []
|
||||
ban_list: list[str] = cast("list[str]", _ok(raw_result)) or []
|
||||
except (TypeError, ValueError) as exc:
|
||||
log.warning(
|
||||
"active_bans_parse_error",
|
||||
@@ -875,10 +889,10 @@ async def get_active_bans(
|
||||
bans.append(ban)
|
||||
|
||||
# Enrich with geo data — prefer batch lookup over per-IP enricher.
|
||||
if http_session is not None and bans:
|
||||
if http_session is not None and bans and geo_batch_lookup is not None:
|
||||
all_ips: list[str] = [ban.ip for ban in bans]
|
||||
try:
|
||||
geo_map = await geo_service.lookup_batch(all_ips, http_session, db=app_db)
|
||||
geo_map = await geo_batch_lookup(all_ips, http_session, db=app_db)
|
||||
except Exception: # noqa: BLE001
|
||||
log.warning("active_bans_batch_geo_failed")
|
||||
geo_map = {}
|
||||
@@ -987,8 +1001,9 @@ async def get_jail_banned_ips(
|
||||
page: int = 1,
|
||||
page_size: int = 25,
|
||||
search: str | None = None,
|
||||
http_session: Any | None = None,
|
||||
app_db: Any | None = None,
|
||||
geo_batch_lookup: GeoBatchLookup | None = None,
|
||||
http_session: aiohttp.ClientSession | None = None,
|
||||
app_db: aiosqlite.Connection | None = None,
|
||||
) -> JailBannedIpsResponse:
|
||||
"""Return a paginated list of currently banned IPs for a single jail.
|
||||
|
||||
@@ -1014,8 +1029,6 @@ async def get_jail_banned_ips(
|
||||
~app.utils.fail2ban_client.Fail2BanConnectionError: If the socket is
|
||||
unreachable.
|
||||
"""
|
||||
from app.services import geo_service # noqa: PLC0415
|
||||
|
||||
# Clamp page_size to the allowed maximum.
|
||||
page_size = min(page_size, _MAX_PAGE_SIZE)
|
||||
|
||||
@@ -1035,7 +1048,7 @@ async def get_jail_banned_ips(
|
||||
except (ValueError, TypeError):
|
||||
raw_result = []
|
||||
|
||||
ban_list: list[str] = raw_result or []
|
||||
ban_list: list[str] = cast("list[str]", raw_result) or []
|
||||
|
||||
# Parse all entries.
|
||||
all_bans: list[ActiveBan] = []
|
||||
@@ -1056,10 +1069,10 @@ async def get_jail_banned_ips(
|
||||
page_bans = all_bans[start : start + page_size]
|
||||
|
||||
# Geo-enrich only the page slice.
|
||||
if http_session is not None and page_bans:
|
||||
if http_session is not None and page_bans and geo_batch_lookup is not None:
|
||||
page_ips = [b.ip for b in page_bans]
|
||||
try:
|
||||
geo_map = await geo_service.lookup_batch(page_ips, http_session, db=app_db)
|
||||
geo_map = await geo_batch_lookup(page_ips, http_session, db=app_db)
|
||||
except Exception: # noqa: BLE001
|
||||
log.warning("jail_banned_ips_geo_failed", jail=jail_name)
|
||||
geo_map = {}
|
||||
@@ -1089,7 +1102,7 @@ async def get_jail_banned_ips(
|
||||
|
||||
async def _enrich_bans(
|
||||
bans: list[ActiveBan],
|
||||
geo_enricher: Any,
|
||||
geo_enricher: GeoEnricher,
|
||||
) -> list[ActiveBan]:
|
||||
"""Enrich ban records with geo data asynchronously.
|
||||
|
||||
@@ -1100,14 +1113,15 @@ async def _enrich_bans(
|
||||
Returns:
|
||||
The same list with ``country`` fields populated where lookup succeeded.
|
||||
"""
|
||||
geo_results: list[Any] = await asyncio.gather(
|
||||
*[geo_enricher(ban.ip) for ban in bans],
|
||||
geo_results: list[object | Exception] = await asyncio.gather(
|
||||
*[cast("Awaitable[object]", geo_enricher(ban.ip)) for ban in bans],
|
||||
return_exceptions=True,
|
||||
)
|
||||
enriched: list[ActiveBan] = []
|
||||
for ban, geo in zip(bans, geo_results, strict=False):
|
||||
if geo is not None and not isinstance(geo, Exception):
|
||||
enriched.append(ban.model_copy(update={"country": geo.country_code}))
|
||||
geo_info = cast("GeoInfo", geo)
|
||||
enriched.append(ban.model_copy(update={"country": geo_info.country_code}))
|
||||
else:
|
||||
enriched.append(ban)
|
||||
return enriched
|
||||
@@ -1255,8 +1269,8 @@ async def set_ignore_self(socket_path: str, name: str, *, on: bool) -> None:
|
||||
async def lookup_ip(
|
||||
socket_path: str,
|
||||
ip: str,
|
||||
geo_enricher: Any | None = None,
|
||||
) -> dict[str, Any]:
|
||||
geo_enricher: GeoEnricher | None = None,
|
||||
) -> IpLookupResult:
|
||||
"""Return ban status and history for a single IP address.
|
||||
|
||||
Checks every running jail for whether the IP is currently banned.
|
||||
@@ -1299,7 +1313,7 @@ async def lookup_ip(
|
||||
)
|
||||
|
||||
# Check ban status per jail in parallel.
|
||||
ban_results: list[Any] = await asyncio.gather(
|
||||
ban_results: list[object | Exception] = await asyncio.gather(
|
||||
*[client.send(["get", jn, "banip"]) for jn in jail_names],
|
||||
return_exceptions=True,
|
||||
)
|
||||
@@ -1309,7 +1323,7 @@ async def lookup_ip(
|
||||
if isinstance(result, Exception):
|
||||
continue
|
||||
try:
|
||||
ban_list: list[str] = _ok(result) or []
|
||||
ban_list: list[str] = cast("list[str]", _ok(result)) or []
|
||||
if ip in ban_list:
|
||||
currently_banned_in.append(jail_name)
|
||||
except (ValueError, TypeError):
|
||||
@@ -1346,6 +1360,6 @@ async def unban_all_ips(socket_path: str) -> int:
|
||||
cannot be reached.
|
||||
"""
|
||||
client = Fail2BanClient(socket_path=socket_path, timeout=_SOCKET_TIMEOUT)
|
||||
count: int = int(_ok(await client.send(["unban", "--all"])))
|
||||
count: int = int(str(_ok(await client.send(["unban", "--all"])) or 0))
|
||||
log.info("all_ips_unbanned", count=count)
|
||||
return count
|
||||
|
||||
128
backend/app/services/log_service.py
Normal file
128
backend/app/services/log_service.py
Normal file
@@ -0,0 +1,128 @@
|
||||
"""Log helper service.
|
||||
|
||||
Contains regex test and log preview helpers that are independent of
|
||||
fail2ban socket operations.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
from app.models.config import (
|
||||
LogPreviewLine,
|
||||
LogPreviewRequest,
|
||||
LogPreviewResponse,
|
||||
RegexTestRequest,
|
||||
RegexTestResponse,
|
||||
)
|
||||
|
||||
|
||||
def test_regex(request: RegexTestRequest) -> RegexTestResponse:
|
||||
"""Test a regex pattern against a sample log line.
|
||||
|
||||
Args:
|
||||
request: The regex test payload.
|
||||
|
||||
Returns:
|
||||
RegexTestResponse with match result, groups and optional error.
|
||||
"""
|
||||
try:
|
||||
compiled = re.compile(request.fail_regex)
|
||||
except re.error as exc:
|
||||
return RegexTestResponse(matched=False, groups=[], error=str(exc))
|
||||
|
||||
match = compiled.search(request.log_line)
|
||||
if match is None:
|
||||
return RegexTestResponse(matched=False)
|
||||
|
||||
groups: list[str] = list(match.groups() or [])
|
||||
return RegexTestResponse(matched=True, groups=[str(g) for g in groups if g is not None])
|
||||
|
||||
|
||||
async def preview_log(req: LogPreviewRequest) -> LogPreviewResponse:
|
||||
"""Inspect the last lines of a log file and evaluate regex matches.
|
||||
|
||||
Args:
|
||||
req: Log preview request.
|
||||
|
||||
Returns:
|
||||
LogPreviewResponse with lines, total_lines and matched_count, or error.
|
||||
"""
|
||||
try:
|
||||
compiled = re.compile(req.fail_regex)
|
||||
except re.error as exc:
|
||||
return LogPreviewResponse(
|
||||
lines=[],
|
||||
total_lines=0,
|
||||
matched_count=0,
|
||||
regex_error=str(exc),
|
||||
)
|
||||
|
||||
path = Path(req.log_path)
|
||||
if not path.is_file():
|
||||
return LogPreviewResponse(
|
||||
lines=[],
|
||||
total_lines=0,
|
||||
matched_count=0,
|
||||
regex_error=f"File not found: {req.log_path!r}",
|
||||
)
|
||||
|
||||
try:
|
||||
raw_lines = await asyncio.get_event_loop().run_in_executor(
|
||||
None,
|
||||
_read_tail_lines,
|
||||
str(path),
|
||||
req.num_lines,
|
||||
)
|
||||
except OSError as exc:
|
||||
return LogPreviewResponse(
|
||||
lines=[],
|
||||
total_lines=0,
|
||||
matched_count=0,
|
||||
regex_error=f"Cannot read file: {exc}",
|
||||
)
|
||||
|
||||
result_lines: list[LogPreviewLine] = []
|
||||
matched_count = 0
|
||||
for line in raw_lines:
|
||||
m = compiled.search(line)
|
||||
groups = [str(g) for g in (m.groups() or []) if g is not None] if m else []
|
||||
result_lines.append(
|
||||
LogPreviewLine(line=line, matched=(m is not None), groups=groups),
|
||||
)
|
||||
if m:
|
||||
matched_count += 1
|
||||
|
||||
return LogPreviewResponse(
|
||||
lines=result_lines,
|
||||
total_lines=len(result_lines),
|
||||
matched_count=matched_count,
|
||||
)
|
||||
|
||||
|
||||
def _read_tail_lines(file_path: str, num_lines: int) -> list[str]:
|
||||
"""Read the last *num_lines* from *file_path* in a memory-efficient way."""
|
||||
chunk_size = 8192
|
||||
raw_lines: list[bytes] = []
|
||||
with open(file_path, "rb") as fh:
|
||||
fh.seek(0, 2)
|
||||
end_pos = fh.tell()
|
||||
if end_pos == 0:
|
||||
return []
|
||||
|
||||
buf = b""
|
||||
pos = end_pos
|
||||
while len(raw_lines) <= num_lines and pos > 0:
|
||||
read_size = min(chunk_size, pos)
|
||||
pos -= read_size
|
||||
fh.seek(pos)
|
||||
chunk = fh.read(read_size)
|
||||
buf = chunk + buf
|
||||
raw_lines = buf.split(b"\n")
|
||||
|
||||
if pos > 0 and len(raw_lines) > 1:
|
||||
raw_lines = raw_lines[1:]
|
||||
|
||||
return [ln.decode("utf-8", errors="replace").rstrip() for ln in raw_lines[-num_lines:] if ln.strip()]
|
||||
@@ -817,7 +817,7 @@ async def get_parsed_filter_file(config_dir: str, name: str) -> FilterConfig:
|
||||
"""Parse a filter definition file and return its structured representation.
|
||||
|
||||
Reads the raw ``.conf``/``.local`` file from ``filter.d/``, parses it with
|
||||
:func:`~app.services.conffile_parser.parse_filter_file`, and returns the
|
||||
:func:`~app.utils.conffile_parser.parse_filter_file`, and returns the
|
||||
result.
|
||||
|
||||
Args:
|
||||
@@ -831,7 +831,7 @@ async def get_parsed_filter_file(config_dir: str, name: str) -> FilterConfig:
|
||||
ConfigFileNotFoundError: If no matching file is found.
|
||||
ConfigDirError: If *config_dir* does not exist.
|
||||
"""
|
||||
from app.services.conffile_parser import parse_filter_file # avoid circular imports
|
||||
from app.utils.conffile_parser import parse_filter_file # avoid circular imports
|
||||
|
||||
def _do() -> FilterConfig:
|
||||
filter_d = _resolve_subdir(config_dir, "filter.d")
|
||||
@@ -863,7 +863,7 @@ async def update_parsed_filter_file(
|
||||
ConfigFileWriteError: If the file cannot be written.
|
||||
ConfigDirError: If *config_dir* does not exist.
|
||||
"""
|
||||
from app.services.conffile_parser import ( # avoid circular imports
|
||||
from app.utils.conffile_parser import ( # avoid circular imports
|
||||
merge_filter_update,
|
||||
parse_filter_file,
|
||||
serialize_filter_config,
|
||||
@@ -901,7 +901,7 @@ async def get_parsed_action_file(config_dir: str, name: str) -> ActionConfig:
|
||||
ConfigFileNotFoundError: If no matching file is found.
|
||||
ConfigDirError: If *config_dir* does not exist.
|
||||
"""
|
||||
from app.services.conffile_parser import parse_action_file # avoid circular imports
|
||||
from app.utils.conffile_parser import parse_action_file # avoid circular imports
|
||||
|
||||
def _do() -> ActionConfig:
|
||||
action_d = _resolve_subdir(config_dir, "action.d")
|
||||
@@ -930,7 +930,7 @@ async def update_parsed_action_file(
|
||||
ConfigFileWriteError: If the file cannot be written.
|
||||
ConfigDirError: If *config_dir* does not exist.
|
||||
"""
|
||||
from app.services.conffile_parser import ( # avoid circular imports
|
||||
from app.utils.conffile_parser import ( # avoid circular imports
|
||||
merge_action_update,
|
||||
parse_action_file,
|
||||
serialize_action_config,
|
||||
@@ -963,7 +963,7 @@ async def get_parsed_jail_file(config_dir: str, filename: str) -> JailFileConfig
|
||||
ConfigFileNotFoundError: If no matching file is found.
|
||||
ConfigDirError: If *config_dir* does not exist.
|
||||
"""
|
||||
from app.services.conffile_parser import parse_jail_file # avoid circular imports
|
||||
from app.utils.conffile_parser import parse_jail_file # avoid circular imports
|
||||
|
||||
def _do() -> JailFileConfig:
|
||||
jail_d = _resolve_subdir(config_dir, "jail.d")
|
||||
@@ -992,7 +992,7 @@ async def update_parsed_jail_file(
|
||||
ConfigFileWriteError: If the file cannot be written.
|
||||
ConfigDirError: If *config_dir* does not exist.
|
||||
"""
|
||||
from app.services.conffile_parser import ( # avoid circular imports
|
||||
from app.utils.conffile_parser import ( # avoid circular imports
|
||||
merge_jail_file_update,
|
||||
parse_jail_file,
|
||||
serialize_jail_file_config,
|
||||
@@ -10,25 +10,50 @@ HTTP/FastAPI concerns.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
import structlog
|
||||
|
||||
from app.exceptions import ServerOperationError
|
||||
from app.exceptions import ServerOperationError
|
||||
from app.models.server import ServerSettings, ServerSettingsResponse, ServerSettingsUpdate
|
||||
from app.utils.fail2ban_client import Fail2BanClient
|
||||
from app.utils.fail2ban_client import Fail2BanClient, Fail2BanCommand, Fail2BanResponse
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Types
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
type Fail2BanSettingValue = str | int | bool
|
||||
"""Allowed values for server settings commands."""
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
|
||||
_SOCKET_TIMEOUT: float = 10.0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Custom exceptions
|
||||
# ---------------------------------------------------------------------------
|
||||
def _to_int(value: object | None, default: int) -> int:
|
||||
"""Convert a raw value to an int, falling back to a default.
|
||||
|
||||
The fail2ban control socket can return either int or str values for some
|
||||
settings, so we normalise them here in a type-safe way.
|
||||
"""
|
||||
if isinstance(value, int):
|
||||
return value
|
||||
if isinstance(value, float):
|
||||
return int(value)
|
||||
if isinstance(value, str):
|
||||
try:
|
||||
return int(value)
|
||||
except ValueError:
|
||||
return default
|
||||
return default
|
||||
|
||||
|
||||
class ServerOperationError(Exception):
|
||||
"""Raised when a server-level set command fails."""
|
||||
def _to_str(value: object | None, default: str) -> str:
|
||||
"""Convert a raw value to a string, falling back to a default."""
|
||||
if value is None:
|
||||
return default
|
||||
return str(value)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -36,7 +61,7 @@ class ServerOperationError(Exception):
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _ok(response: Any) -> Any:
|
||||
def _ok(response: Fail2BanResponse) -> object:
|
||||
"""Extract payload from a fail2ban ``(code, data)`` response.
|
||||
|
||||
Args:
|
||||
@@ -59,9 +84,9 @@ def _ok(response: Any) -> Any:
|
||||
|
||||
async def _safe_get(
|
||||
client: Fail2BanClient,
|
||||
command: list[Any],
|
||||
default: Any = None,
|
||||
) -> Any:
|
||||
command: Fail2BanCommand,
|
||||
default: object | None = None,
|
||||
) -> object | None:
|
||||
"""Send a command and silently return *default* on any error.
|
||||
|
||||
Args:
|
||||
@@ -73,7 +98,8 @@ async def _safe_get(
|
||||
The successful response, or *default*.
|
||||
"""
|
||||
try:
|
||||
return _ok(await client.send(command))
|
||||
response = await client.send(command)
|
||||
return _ok(cast("Fail2BanResponse", response))
|
||||
except Exception:
|
||||
return default
|
||||
|
||||
@@ -118,13 +144,20 @@ async def get_settings(socket_path: str) -> ServerSettingsResponse:
|
||||
_safe_get(client, ["get", "dbmaxmatches"], 10),
|
||||
)
|
||||
|
||||
log_level = _to_str(log_level_raw, "INFO").upper()
|
||||
log_target = _to_str(log_target_raw, "STDOUT")
|
||||
syslog_socket = _to_str(syslog_socket_raw, "") or None
|
||||
db_path = _to_str(db_path_raw, "/var/lib/fail2ban/fail2ban.sqlite3")
|
||||
db_purge_age = _to_int(db_purge_age_raw, 86400)
|
||||
db_max_matches = _to_int(db_max_matches_raw, 10)
|
||||
|
||||
settings = ServerSettings(
|
||||
log_level=str(log_level_raw or "INFO").upper(),
|
||||
log_target=str(log_target_raw or "STDOUT"),
|
||||
syslog_socket=str(syslog_socket_raw) if syslog_socket_raw else None,
|
||||
db_path=str(db_path_raw or "/var/lib/fail2ban/fail2ban.sqlite3"),
|
||||
db_purge_age=int(db_purge_age_raw or 86400),
|
||||
db_max_matches=int(db_max_matches_raw or 10),
|
||||
log_level=log_level,
|
||||
log_target=log_target,
|
||||
syslog_socket=syslog_socket,
|
||||
db_path=db_path,
|
||||
db_purge_age=db_purge_age,
|
||||
db_max_matches=db_max_matches,
|
||||
)
|
||||
|
||||
log.info("server_settings_fetched")
|
||||
@@ -146,9 +179,10 @@ async def update_settings(socket_path: str, update: ServerSettingsUpdate) -> Non
|
||||
"""
|
||||
client = Fail2BanClient(socket_path=socket_path, timeout=_SOCKET_TIMEOUT)
|
||||
|
||||
async def _set(key: str, value: Any) -> None:
|
||||
async def _set(key: str, value: Fail2BanSettingValue) -> None:
|
||||
try:
|
||||
_ok(await client.send(["set", key, value]))
|
||||
response = await client.send(["set", key, value])
|
||||
_ok(cast("Fail2BanResponse", response))
|
||||
except ValueError as exc:
|
||||
raise ServerOperationError(f"Failed to set {key!r} = {value!r}: {exc}") from exc
|
||||
|
||||
@@ -182,7 +216,8 @@ async def flush_logs(socket_path: str) -> str:
|
||||
"""
|
||||
client = Fail2BanClient(socket_path=socket_path, timeout=_SOCKET_TIMEOUT)
|
||||
try:
|
||||
result = _ok(await client.send(["flushlogs"]))
|
||||
response = await client.send(["flushlogs"])
|
||||
result = _ok(cast("Fail2BanResponse", response))
|
||||
log.info("logs_flushed", result=result)
|
||||
return str(result)
|
||||
except ValueError as exc:
|
||||
|
||||
@@ -102,30 +102,20 @@ async def run_setup(
|
||||
log.info("bangui_setup_completed")
|
||||
|
||||
|
||||
from app.utils.setup_utils import (
|
||||
get_map_color_thresholds as util_get_map_color_thresholds,
|
||||
get_password_hash as util_get_password_hash,
|
||||
set_map_color_thresholds as util_set_map_color_thresholds,
|
||||
)
|
||||
|
||||
|
||||
async def get_password_hash(db: aiosqlite.Connection) -> str | None:
|
||||
"""Return the stored bcrypt password hash, or ``None`` if not set.
|
||||
|
||||
Args:
|
||||
db: Active aiosqlite connection.
|
||||
|
||||
Returns:
|
||||
The bcrypt hash string, or ``None``.
|
||||
"""
|
||||
return await settings_repo.get_setting(db, _KEY_PASSWORD_HASH)
|
||||
"""Return the stored bcrypt password hash, or ``None`` if not set."""
|
||||
return await util_get_password_hash(db)
|
||||
|
||||
|
||||
async def get_timezone(db: aiosqlite.Connection) -> str:
|
||||
"""Return the configured IANA timezone string.
|
||||
|
||||
Falls back to ``"UTC"`` when no timezone has been stored (e.g. before
|
||||
setup completes or for legacy databases).
|
||||
|
||||
Args:
|
||||
db: Active aiosqlite connection.
|
||||
|
||||
Returns:
|
||||
An IANA timezone identifier such as ``"Europe/Berlin"`` or ``"UTC"``.
|
||||
"""
|
||||
"""Return the configured IANA timezone string."""
|
||||
tz = await settings_repo.get_setting(db, _KEY_TIMEZONE)
|
||||
return tz if tz else "UTC"
|
||||
|
||||
@@ -133,31 +123,8 @@ async def get_timezone(db: aiosqlite.Connection) -> str:
|
||||
async def get_map_color_thresholds(
|
||||
db: aiosqlite.Connection,
|
||||
) -> tuple[int, int, int]:
|
||||
"""Return the configured map color thresholds (high, medium, low).
|
||||
|
||||
Falls back to default values (100, 50, 20) if not set.
|
||||
|
||||
Args:
|
||||
db: Active aiosqlite connection.
|
||||
|
||||
Returns:
|
||||
A tuple of (threshold_high, threshold_medium, threshold_low).
|
||||
"""
|
||||
high = await settings_repo.get_setting(
|
||||
db, _KEY_MAP_COLOR_THRESHOLD_HIGH
|
||||
)
|
||||
medium = await settings_repo.get_setting(
|
||||
db, _KEY_MAP_COLOR_THRESHOLD_MEDIUM
|
||||
)
|
||||
low = await settings_repo.get_setting(
|
||||
db, _KEY_MAP_COLOR_THRESHOLD_LOW
|
||||
)
|
||||
|
||||
return (
|
||||
int(high) if high else 100,
|
||||
int(medium) if medium else 50,
|
||||
int(low) if low else 20,
|
||||
)
|
||||
"""Return the configured map color thresholds (high, medium, low)."""
|
||||
return await util_get_map_color_thresholds(db)
|
||||
|
||||
|
||||
async def set_map_color_thresholds(
|
||||
@@ -167,31 +134,12 @@ async def set_map_color_thresholds(
|
||||
threshold_medium: int,
|
||||
threshold_low: int,
|
||||
) -> None:
|
||||
"""Update the map color threshold configuration.
|
||||
|
||||
Args:
|
||||
db: Active aiosqlite connection.
|
||||
threshold_high: Ban count for red coloring.
|
||||
threshold_medium: Ban count for yellow coloring.
|
||||
threshold_low: Ban count for green coloring.
|
||||
|
||||
Raises:
|
||||
ValueError: If thresholds are not positive integers or if
|
||||
high <= medium <= low.
|
||||
"""
|
||||
if threshold_high <= 0 or threshold_medium <= 0 or threshold_low <= 0:
|
||||
raise ValueError("All thresholds must be positive integers.")
|
||||
if not (threshold_high > threshold_medium > threshold_low):
|
||||
raise ValueError("Thresholds must satisfy: high > medium > low.")
|
||||
|
||||
await settings_repo.set_setting(
|
||||
db, _KEY_MAP_COLOR_THRESHOLD_HIGH, str(threshold_high)
|
||||
)
|
||||
await settings_repo.set_setting(
|
||||
db, _KEY_MAP_COLOR_THRESHOLD_MEDIUM, str(threshold_medium)
|
||||
)
|
||||
await settings_repo.set_setting(
|
||||
db, _KEY_MAP_COLOR_THRESHOLD_LOW, str(threshold_low)
|
||||
"""Update the map color threshold configuration."""
|
||||
await util_set_map_color_thresholds(
|
||||
db,
|
||||
threshold_high=threshold_high,
|
||||
threshold_medium=threshold_medium,
|
||||
threshold_low=threshold_low,
|
||||
)
|
||||
log.info(
|
||||
"map_color_thresholds_updated",
|
||||
|
||||
@@ -43,9 +43,15 @@ async def _run_import(app: Any) -> None:
|
||||
http_session = app.state.http_session
|
||||
socket_path: str = app.state.settings.fail2ban_socket
|
||||
|
||||
from app.services import jail_service
|
||||
|
||||
log.info("blocklist_import_starting")
|
||||
try:
|
||||
result = await blocklist_service.import_all(db, http_session, socket_path)
|
||||
result = await blocklist_service.import_all(
|
||||
db,
|
||||
http_session,
|
||||
socket_path,
|
||||
)
|
||||
log.info(
|
||||
"blocklist_import_finished",
|
||||
total_imported=result.total_imported,
|
||||
|
||||
@@ -17,7 +17,7 @@ The task runs every 10 minutes. On each invocation it:
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import structlog
|
||||
|
||||
@@ -35,7 +35,7 @@ GEO_RE_RESOLVE_INTERVAL: int = 600
|
||||
JOB_ID: str = "geo_re_resolve"
|
||||
|
||||
|
||||
async def _run_re_resolve(app: Any) -> None:
|
||||
async def _run_re_resolve(app: FastAPI) -> None:
|
||||
"""Query NULL-country IPs from the database and re-resolve them.
|
||||
|
||||
Reads shared resources from ``app.state`` and delegates to
|
||||
@@ -49,12 +49,7 @@ async def _run_re_resolve(app: Any) -> None:
|
||||
http_session = app.state.http_session
|
||||
|
||||
# Fetch all IPs with NULL country_code from the persistent cache.
|
||||
unresolved_ips: list[str] = []
|
||||
async with db.execute(
|
||||
"SELECT ip FROM geo_cache WHERE country_code IS NULL"
|
||||
) as cursor:
|
||||
async for row in cursor:
|
||||
unresolved_ips.append(str(row[0]))
|
||||
unresolved_ips = await geo_service.get_unresolved_ips(db)
|
||||
|
||||
if not unresolved_ips:
|
||||
log.debug("geo_re_resolve_skip", reason="no_unresolved_ips")
|
||||
|
||||
@@ -18,7 +18,7 @@ within 60 seconds of that activation, a
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from typing import TYPE_CHECKING, TypedDict
|
||||
|
||||
import structlog
|
||||
|
||||
@@ -31,6 +31,14 @@ if TYPE_CHECKING: # pragma: no cover
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
|
||||
|
||||
class ActivationRecord(TypedDict):
|
||||
"""Stored timestamp data for a jail activation event."""
|
||||
|
||||
jail_name: str
|
||||
at: datetime.datetime
|
||||
|
||||
|
||||
#: How often the probe fires (seconds).
|
||||
HEALTH_CHECK_INTERVAL: int = 30
|
||||
|
||||
@@ -39,7 +47,7 @@ HEALTH_CHECK_INTERVAL: int = 30
|
||||
_ACTIVATION_CRASH_WINDOW: int = 60
|
||||
|
||||
|
||||
async def _run_probe(app: Any) -> None:
|
||||
async def _run_probe(app: FastAPI) -> None:
|
||||
"""Probe fail2ban and cache the result on *app.state*.
|
||||
|
||||
Detects online/offline state transitions. When fail2ban goes offline
|
||||
@@ -86,7 +94,7 @@ async def _run_probe(app: Any) -> None:
|
||||
elif not status.online and prev_status.online:
|
||||
log.warning("fail2ban_went_offline")
|
||||
# Check whether this crash happened shortly after a jail activation.
|
||||
last_activation: dict[str, Any] | None = getattr(
|
||||
last_activation: ActivationRecord | None = getattr(
|
||||
app.state, "last_activation", None
|
||||
)
|
||||
if last_activation is not None:
|
||||
|
||||
21
backend/app/utils/config_file_utils.py
Normal file
21
backend/app/utils/config_file_utils.py
Normal file
@@ -0,0 +1,21 @@
|
||||
"""Utilities re-exported from config_file_service for cross-module usage."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from app.services.config_file_service import (
|
||||
_build_inactive_jail,
|
||||
_get_active_jail_names,
|
||||
_ordered_config_files,
|
||||
_parse_jails_sync,
|
||||
_validate_jail_config_sync,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"_ordered_config_files",
|
||||
"_parse_jails_sync",
|
||||
"_build_inactive_jail",
|
||||
"_get_active_jail_names",
|
||||
"_validate_jail_config_sync",
|
||||
]
|
||||
@@ -21,14 +21,52 @@ import contextlib
|
||||
import errno
|
||||
import socket
|
||||
import time
|
||||
from collections.abc import Mapping, Sequence, Set
|
||||
from pickle import HIGHEST_PROTOCOL, dumps, loads
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import structlog
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Types
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Use covariant container types so callers can pass ``list[int]`` / ``dict[str, str]``
|
||||
# without needing to cast. At runtime we only accept the basic built-in
|
||||
# containers supported by fail2ban's protocol (list/dict/set) and stringify
|
||||
# anything else.
|
||||
#
|
||||
# NOTE: ``Sequence`` will also accept tuples, but tuples are stringified at
|
||||
# runtime because fail2ban only understands lists.
|
||||
|
||||
type Fail2BanToken = (
|
||||
str
|
||||
| int
|
||||
| float
|
||||
| bool
|
||||
| None
|
||||
| Mapping[str, object]
|
||||
| Sequence[object]
|
||||
| Set[object]
|
||||
)
|
||||
"""A single token in a fail2ban command.
|
||||
|
||||
Fail2ban accepts simple types (str/int/float/bool) plus compound types
|
||||
(list/dict/set). Complex objects are stringified before being sent.
|
||||
"""
|
||||
|
||||
type Fail2BanCommand = Sequence[Fail2BanToken]
|
||||
"""A command sent to fail2ban over the socket.
|
||||
|
||||
Commands are pickle serialised sequences of tokens.
|
||||
"""
|
||||
|
||||
type Fail2BanResponse = tuple[int, object]
|
||||
"""A typical fail2ban response containing a status code and payload."""
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from types import TracebackType
|
||||
|
||||
import structlog
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
|
||||
# fail2ban protocol constants — inline to avoid a hard import dependency
|
||||
@@ -81,9 +119,9 @@ class Fail2BanProtocolError(Exception):
|
||||
|
||||
def _send_command_sync(
|
||||
socket_path: str,
|
||||
command: list[Any],
|
||||
command: Fail2BanCommand,
|
||||
timeout: float,
|
||||
) -> Any:
|
||||
) -> object:
|
||||
"""Send a command to fail2ban and return the parsed response.
|
||||
|
||||
This is a **synchronous** function intended to be called from within
|
||||
@@ -180,7 +218,7 @@ def _send_command_sync(
|
||||
) from last_oserror
|
||||
|
||||
|
||||
def _coerce_command_token(token: Any) -> Any:
|
||||
def _coerce_command_token(token: object) -> Fail2BanToken:
|
||||
"""Coerce a command token to a type that fail2ban understands.
|
||||
|
||||
fail2ban's ``CSocket.convert`` accepts ``str``, ``bool``, ``int``,
|
||||
@@ -229,7 +267,7 @@ class Fail2BanClient:
|
||||
self.socket_path: str = socket_path
|
||||
self.timeout: float = timeout
|
||||
|
||||
async def send(self, command: list[Any]) -> Any:
|
||||
async def send(self, command: Fail2BanCommand) -> object:
|
||||
"""Send a command to fail2ban and return the response.
|
||||
|
||||
Acquires the module-level concurrency semaphore before dispatching
|
||||
@@ -267,13 +305,13 @@ class Fail2BanClient:
|
||||
log.debug("fail2ban_sending_command", command=command)
|
||||
loop: asyncio.AbstractEventLoop = asyncio.get_event_loop()
|
||||
try:
|
||||
response: Any = await loop.run_in_executor(
|
||||
None,
|
||||
_send_command_sync,
|
||||
self.socket_path,
|
||||
command,
|
||||
self.timeout,
|
||||
)
|
||||
response: object = await loop.run_in_executor(
|
||||
None,
|
||||
_send_command_sync,
|
||||
self.socket_path,
|
||||
command,
|
||||
self.timeout,
|
||||
)
|
||||
except Fail2BanConnectionError:
|
||||
log.warning(
|
||||
"fail2ban_connection_error",
|
||||
@@ -300,7 +338,7 @@ class Fail2BanClient:
|
||||
``True`` when the daemon responds correctly, ``False`` otherwise.
|
||||
"""
|
||||
try:
|
||||
response: Any = await self.send(["ping"])
|
||||
response: object = await self.send(["ping"])
|
||||
return bool(response == 1) # fail2ban returns 1 on successful ping
|
||||
except (Fail2BanConnectionError, Fail2BanProtocolError):
|
||||
return False
|
||||
|
||||
63
backend/app/utils/fail2ban_db_utils.py
Normal file
63
backend/app/utils/fail2ban_db_utils.py
Normal file
@@ -0,0 +1,63 @@
|
||||
"""Utilities shared by fail2ban-related services."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from datetime import UTC, datetime
|
||||
|
||||
|
||||
def ts_to_iso(unix_ts: int) -> str:
|
||||
"""Convert a Unix timestamp to an ISO 8601 UTC string."""
|
||||
return datetime.fromtimestamp(unix_ts, tz=UTC).isoformat()
|
||||
|
||||
|
||||
async def get_fail2ban_db_path(socket_path: str) -> str:
|
||||
"""Query fail2ban for the path to its SQLite database file."""
|
||||
from app.utils.fail2ban_client import Fail2BanClient # pragma: no cover
|
||||
|
||||
socket_timeout: float = 5.0
|
||||
|
||||
async with Fail2BanClient(socket_path, timeout=socket_timeout) as client:
|
||||
response = await client.send(["get", "dbfile"])
|
||||
|
||||
if not isinstance(response, tuple) or len(response) != 2:
|
||||
raise RuntimeError(f"Unexpected response from fail2ban: {response!r}")
|
||||
|
||||
code, data = response
|
||||
if code != 0:
|
||||
raise RuntimeError(f"fail2ban error code {code}: {data!r}")
|
||||
|
||||
if data is None:
|
||||
raise RuntimeError("fail2ban has no database configured (dbfile is None)")
|
||||
|
||||
return str(data)
|
||||
|
||||
|
||||
def parse_data_json(raw: object) -> tuple[list[str], int]:
|
||||
"""Extract matches and failure count from the fail2ban bans.data value."""
|
||||
if raw is None:
|
||||
return [], 0
|
||||
|
||||
obj: dict[str, object] = {}
|
||||
if isinstance(raw, str):
|
||||
try:
|
||||
parsed = json.loads(raw)
|
||||
if isinstance(parsed, dict):
|
||||
obj = parsed
|
||||
except json.JSONDecodeError:
|
||||
return [], 0
|
||||
elif isinstance(raw, dict):
|
||||
obj = raw
|
||||
|
||||
raw_matches = obj.get("matches")
|
||||
matches = [str(m) for m in raw_matches] if isinstance(raw_matches, list) else []
|
||||
|
||||
raw_failures = obj.get("failures")
|
||||
failures = 0
|
||||
if isinstance(raw_failures, (int, float, str)):
|
||||
try:
|
||||
failures = int(raw_failures)
|
||||
except (ValueError, TypeError):
|
||||
failures = 0
|
||||
|
||||
return matches, failures
|
||||
93
backend/app/utils/jail_config.py
Normal file
93
backend/app/utils/jail_config.py
Normal file
@@ -0,0 +1,93 @@
|
||||
"""Utilities for ensuring required fail2ban jail configuration files exist.
|
||||
|
||||
BanGUI requires two custom jails — ``manual-Jail`` and ``blocklist-import``
|
||||
— to be present in the fail2ban ``jail.d`` directory. This module provides
|
||||
:func:`ensure_jail_configs` which checks each of the four files
|
||||
(``*.conf`` template + ``*.local`` override) and creates any that are missing
|
||||
with the correct default content.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import structlog
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pathlib import Path
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Default file contents
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_MANUAL_JAIL_CONF = """\
|
||||
[manual-Jail]
|
||||
|
||||
enabled = false
|
||||
filter = manual-Jail
|
||||
logpath = /remotelogs/bangui/auth.log
|
||||
backend = polling
|
||||
maxretry = 3
|
||||
findtime = 120
|
||||
bantime = 60
|
||||
ignoreip = 127.0.0.0/8 ::1 172.16.0.0/12
|
||||
"""
|
||||
|
||||
_MANUAL_JAIL_LOCAL = """\
|
||||
[manual-Jail]
|
||||
enabled = true
|
||||
"""
|
||||
|
||||
_BLOCKLIST_IMPORT_CONF = """\
|
||||
[blocklist-import]
|
||||
|
||||
enabled = false
|
||||
filter =
|
||||
logpath = /dev/null
|
||||
backend = auto
|
||||
maxretry = 1
|
||||
findtime = 1d
|
||||
bantime = 86400
|
||||
ignoreip = 127.0.0.0/8 ::1 172.16.0.0/12
|
||||
"""
|
||||
|
||||
_BLOCKLIST_IMPORT_LOCAL = """\
|
||||
[blocklist-import]
|
||||
enabled = true
|
||||
"""
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# File registry: (filename, default_content)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_JAIL_FILES: list[tuple[str, str]] = [
|
||||
("manual-Jail.conf", _MANUAL_JAIL_CONF),
|
||||
("manual-Jail.local", _MANUAL_JAIL_LOCAL),
|
||||
("blocklist-import.conf", _BLOCKLIST_IMPORT_CONF),
|
||||
("blocklist-import.local", _BLOCKLIST_IMPORT_LOCAL),
|
||||
]
|
||||
|
||||
|
||||
def ensure_jail_configs(jail_d_path: Path) -> None:
|
||||
"""Ensure the required fail2ban jail configuration files exist.
|
||||
|
||||
Checks for ``manual-Jail.conf``, ``manual-Jail.local``,
|
||||
``blocklist-import.conf``, and ``blocklist-import.local`` inside
|
||||
*jail_d_path*. Any file that is missing is created with its default
|
||||
content. Existing files are **never** overwritten.
|
||||
|
||||
Args:
|
||||
jail_d_path: Path to the fail2ban ``jail.d`` directory. Will be
|
||||
created (including all parents) if it does not already exist.
|
||||
"""
|
||||
jail_d_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
for filename, default_content in _JAIL_FILES:
|
||||
file_path = jail_d_path / filename
|
||||
if file_path.exists():
|
||||
log.debug("jail_config_already_exists", path=str(file_path))
|
||||
else:
|
||||
file_path.write_text(default_content, encoding="utf-8")
|
||||
log.info("jail_config_created", path=str(file_path))
|
||||
20
backend/app/utils/jail_utils.py
Normal file
20
backend/app/utils/jail_utils.py
Normal file
@@ -0,0 +1,20 @@
|
||||
"""Jail helpers to decouple service layer dependencies."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
from app.services.jail_service import reload_all
|
||||
|
||||
|
||||
async def reload_jails(
|
||||
socket_path: str,
|
||||
include_jails: Sequence[str] | None = None,
|
||||
exclude_jails: Sequence[str] | None = None,
|
||||
) -> None:
|
||||
"""Reload fail2ban jails using shared jail service helper."""
|
||||
await reload_all(
|
||||
socket_path,
|
||||
include_jails=list(include_jails) if include_jails is not None else None,
|
||||
exclude_jails=list(exclude_jails) if exclude_jails is not None else None,
|
||||
)
|
||||
14
backend/app/utils/log_utils.py
Normal file
14
backend/app/utils/log_utils.py
Normal file
@@ -0,0 +1,14 @@
|
||||
"""Log-related helpers to avoid direct service-to-service imports."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.models.config import LogPreviewRequest, LogPreviewResponse, RegexTestRequest, RegexTestResponse
|
||||
from app.services.log_service import preview_log as _preview_log, test_regex as _test_regex
|
||||
|
||||
|
||||
async def preview_log(req: LogPreviewRequest) -> LogPreviewResponse:
|
||||
return await _preview_log(req)
|
||||
|
||||
|
||||
def test_regex(req: RegexTestRequest) -> RegexTestResponse:
|
||||
return _test_regex(req)
|
||||
47
backend/app/utils/setup_utils.py
Normal file
47
backend/app/utils/setup_utils.py
Normal file
@@ -0,0 +1,47 @@
|
||||
"""Setup-related utilities shared by multiple services."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.repositories import settings_repo
|
||||
|
||||
_KEY_PASSWORD_HASH = "master_password_hash"
|
||||
_KEY_SETUP_DONE = "setup_completed"
|
||||
_KEY_MAP_COLOR_THRESHOLD_HIGH = "map_color_threshold_high"
|
||||
_KEY_MAP_COLOR_THRESHOLD_MEDIUM = "map_color_threshold_medium"
|
||||
_KEY_MAP_COLOR_THRESHOLD_LOW = "map_color_threshold_low"
|
||||
|
||||
|
||||
async def get_password_hash(db):
|
||||
"""Return the stored master password hash or None."""
|
||||
return await settings_repo.get_setting(db, _KEY_PASSWORD_HASH)
|
||||
|
||||
|
||||
async def get_map_color_thresholds(db):
|
||||
"""Return map color thresholds as tuple (high, medium, low)."""
|
||||
high = await settings_repo.get_setting(db, _KEY_MAP_COLOR_THRESHOLD_HIGH)
|
||||
medium = await settings_repo.get_setting(db, _KEY_MAP_COLOR_THRESHOLD_MEDIUM)
|
||||
low = await settings_repo.get_setting(db, _KEY_MAP_COLOR_THRESHOLD_LOW)
|
||||
|
||||
return (
|
||||
int(high) if high else 100,
|
||||
int(medium) if medium else 50,
|
||||
int(low) if low else 20,
|
||||
)
|
||||
|
||||
|
||||
async def set_map_color_thresholds(
|
||||
db,
|
||||
*,
|
||||
threshold_high: int,
|
||||
threshold_medium: int,
|
||||
threshold_low: int,
|
||||
) -> None:
|
||||
"""Persist map color thresholds after validating values."""
|
||||
if threshold_high <= 0 or threshold_medium <= 0 or threshold_low <= 0:
|
||||
raise ValueError("All thresholds must be positive integers.")
|
||||
if not (threshold_high > threshold_medium > threshold_low):
|
||||
raise ValueError("Thresholds must satisfy: high > medium > low.")
|
||||
|
||||
await settings_repo.set_setting(db, _KEY_MAP_COLOR_THRESHOLD_HIGH, str(threshold_high))
|
||||
await settings_repo.set_setting(db, _KEY_MAP_COLOR_THRESHOLD_MEDIUM, str(threshold_medium))
|
||||
await settings_repo.set_setting(db, _KEY_MAP_COLOR_THRESHOLD_LOW, str(threshold_low))
|
||||
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
|
||||
|
||||
[project]
|
||||
name = "bangui-backend"
|
||||
version = "0.1.0"
|
||||
version = "0.9.8"
|
||||
description = "BanGUI backend — fail2ban web management interface"
|
||||
requires-python = ">=3.12"
|
||||
dependencies = [
|
||||
@@ -60,4 +60,5 @@ plugins = ["pydantic.mypy"]
|
||||
asyncio_mode = "auto"
|
||||
pythonpath = [".", "../fail2ban-master"]
|
||||
testpaths = ["tests"]
|
||||
addopts = "--cov=app --cov-report=term-missing"
|
||||
addopts = "--asyncio-mode=auto --cov=app --cov-report=term-missing"
|
||||
filterwarnings = ["ignore::pytest.PytestRemovedIn9Warning"]
|
||||
|
||||
@@ -37,9 +37,15 @@ def test_settings(tmp_path: Path) -> Settings:
|
||||
Returns:
|
||||
A :class:`~app.config.Settings` instance with overridden paths.
|
||||
"""
|
||||
config_dir = tmp_path / "fail2ban"
|
||||
(config_dir / "jail.d").mkdir(parents=True)
|
||||
(config_dir / "filter.d").mkdir(parents=True)
|
||||
(config_dir / "action.d").mkdir(parents=True)
|
||||
|
||||
return Settings(
|
||||
database_path=str(tmp_path / "test_bangui.db"),
|
||||
fail2ban_socket="/tmp/fake_fail2ban.sock",
|
||||
fail2ban_config_dir=str(config_dir),
|
||||
session_secret="test-secret-key-do-not-use-in-production",
|
||||
session_duration_minutes=60,
|
||||
timezone="UTC",
|
||||
|
||||
276
backend/tests/test_regression_500s.py
Normal file
276
backend/tests/test_regression_500s.py
Normal file
@@ -0,0 +1,276 @@
|
||||
"""Regression tests for the four 500-error bugs discovered on 2026-03-22.
|
||||
|
||||
Each test targets the exact code path that caused a 500 Internal Server Error.
|
||||
These tests call the **real** service/repository functions (not the router)
|
||||
so they fail even if the route layer is mocked in router-level tests.
|
||||
|
||||
Bugs covered:
|
||||
1. ``list_history`` rejected the ``origin`` keyword argument (TypeError).
|
||||
2. ``jail_config_service`` used ``_get_active_jail_names`` without importing it.
|
||||
3. ``filter_config_service`` used ``_parse_jails_sync`` / ``_get_active_jail_names``
|
||||
without importing them.
|
||||
4. ``config_service.get_service_status`` omitted the required ``bangui_version``
|
||||
field from the ``ServiceStatusResponse`` constructor (Pydantic ValidationError).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
import json
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import aiosqlite
|
||||
import pytest
|
||||
|
||||
# ── Bug 1 ─────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestHistoryOriginParameter:
|
||||
"""Bug 1: ``origin`` parameter must be threaded through service → repo."""
|
||||
|
||||
# -- Service layer --
|
||||
|
||||
async def test_list_history_accepts_origin_kwarg(self) -> None:
|
||||
"""``history_service.list_history()`` must accept an ``origin`` keyword."""
|
||||
from app.services import history_service
|
||||
|
||||
sig = inspect.signature(history_service.list_history)
|
||||
assert "origin" in sig.parameters, (
|
||||
"list_history() is missing the 'origin' parameter — "
|
||||
"the router passes origin=… which would cause a TypeError"
|
||||
)
|
||||
|
||||
async def test_list_history_forwards_origin_to_repo(
|
||||
self, tmp_path: Path
|
||||
) -> None:
|
||||
"""``list_history(origin='blocklist')`` must forward origin to the DB repo."""
|
||||
from app.services import history_service
|
||||
|
||||
db_path = str(tmp_path / "f2b.db")
|
||||
async with aiosqlite.connect(db_path) as db:
|
||||
await db.execute(
|
||||
"CREATE TABLE jails (name TEXT, enabled INTEGER DEFAULT 1)"
|
||||
)
|
||||
await db.execute(
|
||||
"CREATE TABLE bans "
|
||||
"(jail TEXT, ip TEXT, timeofban INTEGER, bantime INTEGER, "
|
||||
"bancount INTEGER DEFAULT 1, data JSON)"
|
||||
)
|
||||
await db.execute(
|
||||
"INSERT INTO bans VALUES (?, ?, ?, ?, ?, ?)",
|
||||
("blocklist-import", "10.0.0.1", int(time.time()), 3600, 1, "{}"),
|
||||
)
|
||||
await db.execute(
|
||||
"INSERT INTO bans VALUES (?, ?, ?, ?, ?, ?)",
|
||||
("sshd", "10.0.0.2", int(time.time()), 3600, 1, "{}"),
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
with patch(
|
||||
"app.services.history_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=db_path),
|
||||
):
|
||||
result = await history_service.list_history(
|
||||
"fake_socket", origin="blocklist"
|
||||
)
|
||||
|
||||
assert all(
|
||||
item.jail == "blocklist-import" for item in result.items
|
||||
), "origin='blocklist' must filter to blocklist-import jail only"
|
||||
|
||||
# -- Repository layer --
|
||||
|
||||
async def test_get_history_page_accepts_origin_kwarg(self) -> None:
|
||||
"""``fail2ban_db_repo.get_history_page()`` must accept ``origin``."""
|
||||
from app.repositories import fail2ban_db_repo
|
||||
|
||||
sig = inspect.signature(fail2ban_db_repo.get_history_page)
|
||||
assert "origin" in sig.parameters, (
|
||||
"get_history_page() is missing the 'origin' parameter"
|
||||
)
|
||||
|
||||
async def test_get_history_page_filters_by_origin(
|
||||
self, tmp_path: Path
|
||||
) -> None:
|
||||
"""``get_history_page(origin='selfblock')`` excludes blocklist-import."""
|
||||
from app.repositories import fail2ban_db_repo
|
||||
|
||||
db_path = str(tmp_path / "f2b.db")
|
||||
async with aiosqlite.connect(db_path) as db:
|
||||
await db.execute(
|
||||
"CREATE TABLE bans "
|
||||
"(jail TEXT, ip TEXT, timeofban INTEGER, bancount INTEGER, data TEXT)"
|
||||
)
|
||||
await db.executemany(
|
||||
"INSERT INTO bans VALUES (?, ?, ?, ?, ?)",
|
||||
[
|
||||
("blocklist-import", "10.0.0.1", 100, 1, "{}"),
|
||||
("sshd", "10.0.0.2", 200, 1, "{}"),
|
||||
("sshd", "10.0.0.3", 300, 1, "{}"),
|
||||
],
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
rows, total = await fail2ban_db_repo.get_history_page(
|
||||
db_path=db_path, origin="selfblock"
|
||||
)
|
||||
|
||||
assert total == 2
|
||||
assert all(r.jail != "blocklist-import" for r in rows)
|
||||
|
||||
|
||||
# ── Bug 2 ─────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestJailConfigImports:
|
||||
"""Bug 2: ``jail_config_service`` must import ``_get_active_jail_names``."""
|
||||
|
||||
async def test_get_active_jail_names_is_importable(self) -> None:
|
||||
"""The module must successfully import ``_get_active_jail_names``."""
|
||||
import app.services.jail_config_service as mod
|
||||
|
||||
assert hasattr(mod, "_get_active_jail_names") or callable(
|
||||
getattr(mod, "_get_active_jail_names", None)
|
||||
), (
|
||||
"_get_active_jail_names is not available in jail_config_service — "
|
||||
"any call site will raise NameError → 500"
|
||||
)
|
||||
|
||||
async def test_list_inactive_jails_does_not_raise_name_error(
|
||||
self, tmp_path: Path
|
||||
) -> None:
|
||||
"""``list_inactive_jails`` must not crash with NameError."""
|
||||
from app.services import jail_config_service
|
||||
|
||||
config_dir = str(tmp_path / "fail2ban")
|
||||
Path(config_dir).mkdir()
|
||||
(Path(config_dir) / "jail.conf").write_text("[DEFAULT]\n")
|
||||
|
||||
with patch(
|
||||
"app.services.jail_config_service._get_active_jail_names",
|
||||
new=AsyncMock(return_value=set()),
|
||||
):
|
||||
result = await jail_config_service.list_inactive_jails(
|
||||
config_dir, "/fake/socket"
|
||||
)
|
||||
|
||||
assert result.total >= 0
|
||||
|
||||
|
||||
# ── Bug 3 ─────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestFilterConfigImports:
|
||||
"""Bug 3: ``filter_config_service`` must import ``_parse_jails_sync``
|
||||
and ``_get_active_jail_names``."""
|
||||
|
||||
async def test_parse_jails_sync_is_available(self) -> None:
|
||||
"""``_parse_jails_sync`` must be resolvable at module scope."""
|
||||
import app.services.filter_config_service as mod
|
||||
|
||||
assert hasattr(mod, "_parse_jails_sync"), (
|
||||
"_parse_jails_sync is not available in filter_config_service — "
|
||||
"list_filters() will raise NameError → 500"
|
||||
)
|
||||
|
||||
async def test_get_active_jail_names_is_available(self) -> None:
|
||||
"""``_get_active_jail_names`` must be resolvable at module scope."""
|
||||
import app.services.filter_config_service as mod
|
||||
|
||||
assert hasattr(mod, "_get_active_jail_names"), (
|
||||
"_get_active_jail_names is not available in filter_config_service — "
|
||||
"list_filters() will raise NameError → 500"
|
||||
)
|
||||
|
||||
async def test_list_filters_does_not_raise_name_error(
|
||||
self, tmp_path: Path
|
||||
) -> None:
|
||||
"""``list_filters`` must not crash with NameError."""
|
||||
from app.services import filter_config_service
|
||||
|
||||
config_dir = str(tmp_path / "fail2ban")
|
||||
filter_d = Path(config_dir) / "filter.d"
|
||||
filter_d.mkdir(parents=True)
|
||||
|
||||
# Create a minimal filter file so _parse_filters_sync has something to scan.
|
||||
(filter_d / "sshd.conf").write_text(
|
||||
"[Definition]\nfailregex = ^Failed password\n"
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"app.services.filter_config_service._parse_jails_sync",
|
||||
return_value=({}, {}),
|
||||
),
|
||||
patch(
|
||||
"app.services.filter_config_service._get_active_jail_names",
|
||||
new=AsyncMock(return_value=set()),
|
||||
),
|
||||
):
|
||||
result = await filter_config_service.list_filters(
|
||||
config_dir, "/fake/socket"
|
||||
)
|
||||
|
||||
assert result.total >= 0
|
||||
|
||||
|
||||
# ── Bug 4 ─────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestServiceStatusBanguiVersion:
|
||||
"""Bug 4: ``get_service_status`` must include application version
|
||||
in the ``version`` field of the ``ServiceStatusResponse``."""
|
||||
|
||||
async def test_online_response_contains_bangui_version(self) -> None:
|
||||
"""The returned model must contain the ``bangui_version`` field."""
|
||||
from app.models.server import ServerStatus
|
||||
from app.services import config_service
|
||||
import app
|
||||
|
||||
online_status = ServerStatus(
|
||||
online=True,
|
||||
version="1.0.0",
|
||||
active_jails=2,
|
||||
total_bans=5,
|
||||
total_failures=3,
|
||||
)
|
||||
|
||||
async def _send(command: list[Any]) -> Any:
|
||||
key = "|".join(str(c) for c in command)
|
||||
if key == "get|loglevel":
|
||||
return (0, "INFO")
|
||||
if key == "get|logtarget":
|
||||
return (0, "/var/log/fail2ban.log")
|
||||
return (0, None)
|
||||
|
||||
class _FakeClient:
|
||||
def __init__(self, **_kw: Any) -> None:
|
||||
self.send = AsyncMock(side_effect=_send)
|
||||
|
||||
with patch("app.services.config_service.Fail2BanClient", _FakeClient):
|
||||
result = await config_service.get_service_status(
|
||||
"/fake/socket",
|
||||
probe_fn=AsyncMock(return_value=online_status),
|
||||
)
|
||||
|
||||
assert result.version == app.__version__, (
|
||||
"ServiceStatusResponse must expose BanGUI version in version field"
|
||||
)
|
||||
|
||||
async def test_offline_response_contains_bangui_version(self) -> None:
|
||||
"""Even when fail2ban is offline, ``bangui_version`` must be present."""
|
||||
from app.models.server import ServerStatus
|
||||
from app.services import config_service
|
||||
import app
|
||||
|
||||
offline_status = ServerStatus(online=False)
|
||||
|
||||
result = await config_service.get_service_status(
|
||||
"/fake/socket",
|
||||
probe_fn=AsyncMock(return_value=offline_status),
|
||||
)
|
||||
|
||||
assert result.version == app.__version__
|
||||
167
backend/tests/test_repositories/test_fail2ban_db_repo.py
Normal file
167
backend/tests/test_repositories/test_fail2ban_db_repo.py
Normal file
@@ -0,0 +1,167 @@
|
||||
"""Tests for the fail2ban_db repository.
|
||||
|
||||
These tests use an in-memory sqlite file created under pytest's tmp_path and
|
||||
exercise the core query functions used by the services.
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import aiosqlite
|
||||
import pytest
|
||||
|
||||
from app.repositories import fail2ban_db_repo
|
||||
|
||||
|
||||
async def _create_bans_table(db: aiosqlite.Connection) -> None:
|
||||
await db.execute(
|
||||
"""
|
||||
CREATE TABLE bans (
|
||||
jail TEXT,
|
||||
ip TEXT,
|
||||
timeofban INTEGER,
|
||||
bancount INTEGER,
|
||||
data TEXT
|
||||
)
|
||||
"""
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_db_nonempty_returns_false_when_table_is_empty(tmp_path: Path) -> None:
|
||||
db_path = str(tmp_path / "fail2ban.db")
|
||||
async with aiosqlite.connect(db_path) as db:
|
||||
await _create_bans_table(db)
|
||||
|
||||
assert await fail2ban_db_repo.check_db_nonempty(db_path) is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_db_nonempty_returns_true_when_row_exists(tmp_path: Path) -> None:
|
||||
db_path = str(tmp_path / "fail2ban.db")
|
||||
async with aiosqlite.connect(db_path) as db:
|
||||
await _create_bans_table(db)
|
||||
await db.execute(
|
||||
"INSERT INTO bans (jail, ip, timeofban, bancount, data) VALUES (?, ?, ?, ?, ?)",
|
||||
("jail1", "1.2.3.4", 123, 1, "{}"),
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
assert await fail2ban_db_repo.check_db_nonempty(db_path) is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_currently_banned_filters_and_pagination(tmp_path: Path) -> None:
|
||||
db_path = str(tmp_path / "fail2ban.db")
|
||||
async with aiosqlite.connect(db_path) as db:
|
||||
await _create_bans_table(db)
|
||||
# Three bans; one is from the blocklist-import jail.
|
||||
await db.executemany(
|
||||
"INSERT INTO bans (jail, ip, timeofban, bancount, data) VALUES (?, ?, ?, ?, ?)",
|
||||
[
|
||||
("jail1", "1.1.1.1", 10, 1, "{}"),
|
||||
("blocklist-import", "2.2.2.2", 20, 2, "{}"),
|
||||
("jail1", "3.3.3.3", 30, 3, "{}"),
|
||||
],
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
records, total = await fail2ban_db_repo.get_currently_banned(
|
||||
db_path=db_path,
|
||||
since=15,
|
||||
origin="selfblock",
|
||||
limit=10,
|
||||
offset=0,
|
||||
)
|
||||
|
||||
# Only the non-blocklist row with timeofban >= 15 should remain.
|
||||
assert total == 1
|
||||
assert len(records) == 1
|
||||
assert records[0].ip == "3.3.3.3"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_ban_counts_by_bucket_ignores_out_of_range_buckets(tmp_path: Path) -> None:
|
||||
db_path = str(tmp_path / "fail2ban.db")
|
||||
async with aiosqlite.connect(db_path) as db:
|
||||
await _create_bans_table(db)
|
||||
await db.executemany(
|
||||
"INSERT INTO bans (jail, ip, timeofban, bancount, data) VALUES (?, ?, ?, ?, ?)",
|
||||
[
|
||||
("jail1", "1.1.1.1", 5, 1, "{}"),
|
||||
("jail1", "2.2.2.2", 15, 1, "{}"),
|
||||
("jail1", "3.3.3.3", 35, 1, "{}"),
|
||||
],
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
counts = await fail2ban_db_repo.get_ban_counts_by_bucket(
|
||||
db_path=db_path,
|
||||
since=0,
|
||||
bucket_secs=10,
|
||||
num_buckets=3,
|
||||
)
|
||||
|
||||
assert counts == [1, 1, 0]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_history_page_and_for_ip(tmp_path: Path) -> None:
|
||||
db_path = str(tmp_path / "fail2ban.db")
|
||||
async with aiosqlite.connect(db_path) as db:
|
||||
await _create_bans_table(db)
|
||||
await db.executemany(
|
||||
"INSERT INTO bans (jail, ip, timeofban, bancount, data) VALUES (?, ?, ?, ?, ?)",
|
||||
[
|
||||
("jail1", "1.1.1.1", 100, 1, "{}"),
|
||||
("jail1", "1.1.1.1", 200, 2, "{}"),
|
||||
("jail1", "2.2.2.2", 300, 3, "{}"),
|
||||
],
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
page, total = await fail2ban_db_repo.get_history_page(
|
||||
db_path=db_path,
|
||||
since=None,
|
||||
jail="jail1",
|
||||
ip_filter="1.1.1",
|
||||
page=1,
|
||||
page_size=10,
|
||||
)
|
||||
|
||||
assert total == 2
|
||||
assert len(page) == 2
|
||||
assert page[0].ip == "1.1.1.1"
|
||||
|
||||
history_for_ip = await fail2ban_db_repo.get_history_for_ip(db_path=db_path, ip="2.2.2.2")
|
||||
assert len(history_for_ip) == 1
|
||||
assert history_for_ip[0].ip == "2.2.2.2"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_history_page_origin_filter(tmp_path: Path) -> None:
|
||||
db_path = str(tmp_path / "fail2ban.db")
|
||||
async with aiosqlite.connect(db_path) as db:
|
||||
await _create_bans_table(db)
|
||||
await db.executemany(
|
||||
"INSERT INTO bans (jail, ip, timeofban, bancount, data) VALUES (?, ?, ?, ?, ?)",
|
||||
[
|
||||
("jail1", "1.1.1.1", 100, 1, "{}"),
|
||||
("blocklist-import", "2.2.2.2", 200, 1, "{}"),
|
||||
],
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
page, total = await fail2ban_db_repo.get_history_page(
|
||||
db_path=db_path,
|
||||
since=None,
|
||||
jail=None,
|
||||
ip_filter=None,
|
||||
origin="selfblock",
|
||||
page=1,
|
||||
page_size=10,
|
||||
)
|
||||
|
||||
assert total == 1
|
||||
assert len(page) == 1
|
||||
assert page[0].ip == "1.1.1.1"
|
||||
140
backend/tests/test_repositories/test_geo_cache_repo.py
Normal file
140
backend/tests/test_repositories/test_geo_cache_repo.py
Normal file
@@ -0,0 +1,140 @@
|
||||
"""Tests for the geo cache repository."""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import aiosqlite
|
||||
import pytest
|
||||
|
||||
from app.repositories import geo_cache_repo
|
||||
|
||||
|
||||
async def _create_geo_cache_table(db: aiosqlite.Connection) -> None:
|
||||
await db.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS geo_cache (
|
||||
ip TEXT PRIMARY KEY,
|
||||
country_code TEXT,
|
||||
country_name TEXT,
|
||||
asn TEXT,
|
||||
org TEXT,
|
||||
cached_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%fZ', 'now'))
|
||||
)
|
||||
"""
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_unresolved_ips_returns_empty_when_none_exist(tmp_path: Path) -> None:
|
||||
db_path = str(tmp_path / "geo_cache.db")
|
||||
async with aiosqlite.connect(db_path) as db:
|
||||
await _create_geo_cache_table(db)
|
||||
await db.execute(
|
||||
"INSERT INTO geo_cache (ip, country_code, country_name, asn, org) VALUES (?, ?, ?, ?, ?)",
|
||||
("1.1.1.1", "DE", "Germany", "AS123", "Test"),
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
async with aiosqlite.connect(db_path) as db:
|
||||
ips = await geo_cache_repo.get_unresolved_ips(db)
|
||||
|
||||
assert ips == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_unresolved_ips_returns_pending_ips(tmp_path: Path) -> None:
|
||||
db_path = str(tmp_path / "geo_cache.db")
|
||||
async with aiosqlite.connect(db_path) as db:
|
||||
await _create_geo_cache_table(db)
|
||||
await db.executemany(
|
||||
"INSERT INTO geo_cache (ip, country_code) VALUES (?, ?)",
|
||||
[
|
||||
("2.2.2.2", None),
|
||||
("3.3.3.3", None),
|
||||
("4.4.4.4", "US"),
|
||||
],
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
async with aiosqlite.connect(db_path) as db:
|
||||
ips = await geo_cache_repo.get_unresolved_ips(db)
|
||||
|
||||
assert sorted(ips) == ["2.2.2.2", "3.3.3.3"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_all_and_count_unresolved(tmp_path: Path) -> None:
|
||||
db_path = str(tmp_path / "geo_cache.db")
|
||||
async with aiosqlite.connect(db_path) as db:
|
||||
await _create_geo_cache_table(db)
|
||||
await db.executemany(
|
||||
"INSERT INTO geo_cache (ip, country_code, country_name, asn, org) VALUES (?, ?, ?, ?, ?)",
|
||||
[
|
||||
("5.5.5.5", None, None, None, None),
|
||||
("6.6.6.6", "FR", "France", "AS456", "TestOrg"),
|
||||
],
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
async with aiosqlite.connect(db_path) as db:
|
||||
rows = await geo_cache_repo.load_all(db)
|
||||
unresolved = await geo_cache_repo.count_unresolved(db)
|
||||
|
||||
assert unresolved == 1
|
||||
assert any(row["ip"] == "6.6.6.6" and row["country_code"] == "FR" for row in rows)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_upsert_entry_and_neg_entry(tmp_path: Path) -> None:
|
||||
db_path = str(tmp_path / "geo_cache.db")
|
||||
async with aiosqlite.connect(db_path) as db:
|
||||
await _create_geo_cache_table(db)
|
||||
|
||||
await geo_cache_repo.upsert_entry(
|
||||
db,
|
||||
"7.7.7.7",
|
||||
"GB",
|
||||
"United Kingdom",
|
||||
"AS789",
|
||||
"TestOrg",
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
await geo_cache_repo.upsert_neg_entry(db, "8.8.8.8")
|
||||
await db.commit()
|
||||
|
||||
# Ensure positive entry is present.
|
||||
async with db.execute("SELECT country_code FROM geo_cache WHERE ip = ?", ("7.7.7.7",)) as cur:
|
||||
row = await cur.fetchone()
|
||||
assert row is not None
|
||||
assert row[0] == "GB"
|
||||
|
||||
# Ensure negative entry exists with NULL country_code.
|
||||
async with db.execute("SELECT country_code FROM geo_cache WHERE ip = ?", ("8.8.8.8",)) as cur:
|
||||
row = await cur.fetchone()
|
||||
assert row is not None
|
||||
assert row[0] is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bulk_upsert_entries_and_neg_entries(tmp_path: Path) -> None:
|
||||
db_path = str(tmp_path / "geo_cache.db")
|
||||
async with aiosqlite.connect(db_path) as db:
|
||||
await _create_geo_cache_table(db)
|
||||
|
||||
rows = [
|
||||
("9.9.9.9", "NL", "Netherlands", "AS101", "Test"),
|
||||
("10.10.10.10", "JP", "Japan", "AS102", "Test"),
|
||||
]
|
||||
count = await geo_cache_repo.bulk_upsert_entries(db, rows)
|
||||
assert count == 2
|
||||
|
||||
neg_count = await geo_cache_repo.bulk_upsert_neg_entries(db, ["11.11.11.11", "12.12.12.12"])
|
||||
assert neg_count == 2
|
||||
|
||||
await db.commit()
|
||||
|
||||
async with db.execute("SELECT COUNT(*) FROM geo_cache") as cur:
|
||||
row = await cur.fetchone()
|
||||
assert row is not None
|
||||
assert int(row[0]) == 4
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Generator
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
@@ -157,12 +158,12 @@ class TestRequireAuthSessionCache:
|
||||
"""In-memory session token cache inside ``require_auth``."""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_cache(self) -> None: # type: ignore[misc]
|
||||
def reset_cache(self) -> Generator[None, None, None]:
|
||||
"""Flush the session cache before and after every test in this class."""
|
||||
from app import dependencies
|
||||
|
||||
dependencies.clear_session_cache()
|
||||
yield # type: ignore[misc]
|
||||
yield
|
||||
dependencies.clear_session_cache()
|
||||
|
||||
async def test_second_request_skips_db(self, client: AsyncClient) -> None:
|
||||
|
||||
@@ -9,6 +9,8 @@ import aiosqlite
|
||||
import pytest
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
|
||||
import app
|
||||
|
||||
from app.config import Settings
|
||||
from app.db import init_db
|
||||
from app.main import create_app
|
||||
@@ -370,6 +372,124 @@ class TestReloadFail2ban:
|
||||
|
||||
assert resp.status_code == 204
|
||||
|
||||
async def test_502_when_fail2ban_unreachable(self, config_client: AsyncClient) -> None:
|
||||
"""POST /api/config/reload returns 502 when fail2ban socket is unreachable."""
|
||||
from app.utils.fail2ban_client import Fail2BanConnectionError
|
||||
|
||||
with patch(
|
||||
"app.routers.config.jail_service.reload_all",
|
||||
AsyncMock(side_effect=Fail2BanConnectionError("no socket", "/fake.sock")),
|
||||
):
|
||||
resp = await config_client.post("/api/config/reload")
|
||||
|
||||
assert resp.status_code == 502
|
||||
|
||||
async def test_409_when_reload_operation_fails(self, config_client: AsyncClient) -> None:
|
||||
"""POST /api/config/reload returns 409 when fail2ban reports a reload error."""
|
||||
from app.services.jail_service import JailOperationError
|
||||
|
||||
with patch(
|
||||
"app.routers.config.jail_service.reload_all",
|
||||
AsyncMock(side_effect=JailOperationError("reload rejected")),
|
||||
):
|
||||
resp = await config_client.post("/api/config/reload")
|
||||
|
||||
assert resp.status_code == 409
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# POST /api/config/restart
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRestartFail2ban:
|
||||
"""Tests for ``POST /api/config/restart``."""
|
||||
|
||||
async def test_204_on_success(self, config_client: AsyncClient) -> None:
|
||||
"""POST /api/config/restart returns 204 when fail2ban restarts cleanly."""
|
||||
with (
|
||||
patch(
|
||||
"app.routers.config.jail_service.restart",
|
||||
AsyncMock(return_value=None),
|
||||
),
|
||||
patch(
|
||||
"app.routers.config.config_file_service.start_daemon",
|
||||
AsyncMock(return_value=True),
|
||||
),
|
||||
patch(
|
||||
"app.routers.config.config_file_service.wait_for_fail2ban",
|
||||
AsyncMock(return_value=True),
|
||||
),
|
||||
):
|
||||
resp = await config_client.post("/api/config/restart")
|
||||
|
||||
assert resp.status_code == 204
|
||||
|
||||
async def test_503_when_fail2ban_does_not_come_back(self, config_client: AsyncClient) -> None:
|
||||
"""POST /api/config/restart returns 503 when fail2ban does not come back online."""
|
||||
with (
|
||||
patch(
|
||||
"app.routers.config.jail_service.restart",
|
||||
AsyncMock(return_value=None),
|
||||
),
|
||||
patch(
|
||||
"app.routers.config.config_file_service.start_daemon",
|
||||
AsyncMock(return_value=True),
|
||||
),
|
||||
patch(
|
||||
"app.routers.config.config_file_service.wait_for_fail2ban",
|
||||
AsyncMock(return_value=False),
|
||||
),
|
||||
):
|
||||
resp = await config_client.post("/api/config/restart")
|
||||
|
||||
assert resp.status_code == 503
|
||||
|
||||
async def test_409_when_stop_command_fails(self, config_client: AsyncClient) -> None:
|
||||
"""POST /api/config/restart returns 409 when fail2ban rejects the stop command."""
|
||||
from app.services.jail_service import JailOperationError
|
||||
|
||||
with patch(
|
||||
"app.routers.config.jail_service.restart",
|
||||
AsyncMock(side_effect=JailOperationError("stop failed")),
|
||||
):
|
||||
resp = await config_client.post("/api/config/restart")
|
||||
|
||||
assert resp.status_code == 409
|
||||
|
||||
async def test_502_when_fail2ban_unreachable(self, config_client: AsyncClient) -> None:
|
||||
"""POST /api/config/restart returns 502 when fail2ban socket is unreachable."""
|
||||
from app.utils.fail2ban_client import Fail2BanConnectionError
|
||||
|
||||
with patch(
|
||||
"app.routers.config.jail_service.restart",
|
||||
AsyncMock(side_effect=Fail2BanConnectionError("no socket", "/fake.sock")),
|
||||
):
|
||||
resp = await config_client.post("/api/config/restart")
|
||||
|
||||
assert resp.status_code == 502
|
||||
|
||||
async def test_start_daemon_called_after_stop(self, config_client: AsyncClient) -> None:
|
||||
"""start_daemon is called after a successful stop."""
|
||||
mock_start = AsyncMock(return_value=True)
|
||||
with (
|
||||
patch(
|
||||
"app.routers.config.jail_service.restart",
|
||||
AsyncMock(return_value=None),
|
||||
),
|
||||
patch(
|
||||
"app.routers.config.config_file_service.start_daemon",
|
||||
mock_start,
|
||||
),
|
||||
patch(
|
||||
"app.routers.config.config_file_service.wait_for_fail2ban",
|
||||
AsyncMock(return_value=True),
|
||||
),
|
||||
):
|
||||
await config_client.post("/api/config/restart")
|
||||
|
||||
mock_start.assert_awaited_once()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# POST /api/config/regex-test
|
||||
@@ -383,7 +503,7 @@ class TestRegexTest:
|
||||
"""POST /api/config/regex-test returns matched=true for a valid match."""
|
||||
mock_response = RegexTestResponse(matched=True, groups=["1.2.3.4"], error=None)
|
||||
with patch(
|
||||
"app.routers.config.config_service.test_regex",
|
||||
"app.routers.config.log_service.test_regex",
|
||||
return_value=mock_response,
|
||||
):
|
||||
resp = await config_client.post(
|
||||
@@ -401,7 +521,7 @@ class TestRegexTest:
|
||||
"""POST /api/config/regex-test returns matched=false for no match."""
|
||||
mock_response = RegexTestResponse(matched=False, groups=[], error=None)
|
||||
with patch(
|
||||
"app.routers.config.config_service.test_regex",
|
||||
"app.routers.config.log_service.test_regex",
|
||||
return_value=mock_response,
|
||||
):
|
||||
resp = await config_client.post(
|
||||
@@ -479,7 +599,7 @@ class TestPreviewLog:
|
||||
matched_count=1,
|
||||
)
|
||||
with patch(
|
||||
"app.routers.config.config_service.preview_log",
|
||||
"app.routers.config.log_service.preview_log",
|
||||
AsyncMock(return_value=mock_response),
|
||||
):
|
||||
resp = await config_client.post(
|
||||
@@ -607,7 +727,7 @@ class TestGetInactiveJails:
|
||||
mock_response = InactiveJailListResponse(jails=[mock_jail], total=1)
|
||||
|
||||
with patch(
|
||||
"app.routers.config.config_file_service.list_inactive_jails",
|
||||
"app.routers.config.jail_config_service.list_inactive_jails",
|
||||
AsyncMock(return_value=mock_response),
|
||||
):
|
||||
resp = await config_client.get("/api/config/jails/inactive")
|
||||
@@ -622,7 +742,7 @@ class TestGetInactiveJails:
|
||||
from app.models.config import InactiveJailListResponse
|
||||
|
||||
with patch(
|
||||
"app.routers.config.config_file_service.list_inactive_jails",
|
||||
"app.routers.config.jail_config_service.list_inactive_jails",
|
||||
AsyncMock(return_value=InactiveJailListResponse(jails=[], total=0)),
|
||||
):
|
||||
resp = await config_client.get("/api/config/jails/inactive")
|
||||
@@ -658,7 +778,7 @@ class TestActivateJail:
|
||||
message="Jail 'apache-auth' activated successfully.",
|
||||
)
|
||||
with patch(
|
||||
"app.routers.config.config_file_service.activate_jail",
|
||||
"app.routers.config.jail_config_service.activate_jail",
|
||||
AsyncMock(return_value=mock_response),
|
||||
):
|
||||
resp = await config_client.post(
|
||||
@@ -678,7 +798,7 @@ class TestActivateJail:
|
||||
name="apache-auth", active=True, message="Activated."
|
||||
)
|
||||
with patch(
|
||||
"app.routers.config.config_file_service.activate_jail",
|
||||
"app.routers.config.jail_config_service.activate_jail",
|
||||
AsyncMock(return_value=mock_response),
|
||||
) as mock_activate:
|
||||
resp = await config_client.post(
|
||||
@@ -694,10 +814,10 @@ class TestActivateJail:
|
||||
|
||||
async def test_404_for_unknown_jail(self, config_client: AsyncClient) -> None:
|
||||
"""POST /api/config/jails/missing/activate returns 404."""
|
||||
from app.services.config_file_service import JailNotFoundInConfigError
|
||||
from app.services.jail_config_service import JailNotFoundInConfigError
|
||||
|
||||
with patch(
|
||||
"app.routers.config.config_file_service.activate_jail",
|
||||
"app.routers.config.jail_config_service.activate_jail",
|
||||
AsyncMock(side_effect=JailNotFoundInConfigError("missing")),
|
||||
):
|
||||
resp = await config_client.post(
|
||||
@@ -708,10 +828,10 @@ class TestActivateJail:
|
||||
|
||||
async def test_409_when_already_active(self, config_client: AsyncClient) -> None:
|
||||
"""POST /api/config/jails/sshd/activate returns 409 if already active."""
|
||||
from app.services.config_file_service import JailAlreadyActiveError
|
||||
from app.services.jail_config_service import JailAlreadyActiveError
|
||||
|
||||
with patch(
|
||||
"app.routers.config.config_file_service.activate_jail",
|
||||
"app.routers.config.jail_config_service.activate_jail",
|
||||
AsyncMock(side_effect=JailAlreadyActiveError("sshd")),
|
||||
):
|
||||
resp = await config_client.post(
|
||||
@@ -722,10 +842,10 @@ class TestActivateJail:
|
||||
|
||||
async def test_400_for_invalid_jail_name(self, config_client: AsyncClient) -> None:
|
||||
"""POST /api/config/jails/ with bad name returns 400."""
|
||||
from app.services.config_file_service import JailNameError
|
||||
from app.services.jail_config_service import JailNameError
|
||||
|
||||
with patch(
|
||||
"app.routers.config.config_file_service.activate_jail",
|
||||
"app.routers.config.jail_config_service.activate_jail",
|
||||
AsyncMock(side_effect=JailNameError("bad name")),
|
||||
):
|
||||
resp = await config_client.post(
|
||||
@@ -754,7 +874,7 @@ class TestActivateJail:
|
||||
message="Jail 'airsonic-auth' cannot be activated: log file '/var/log/airsonic/airsonic.log' not found",
|
||||
)
|
||||
with patch(
|
||||
"app.routers.config.config_file_service.activate_jail",
|
||||
"app.routers.config.jail_config_service.activate_jail",
|
||||
AsyncMock(return_value=blocked_response),
|
||||
):
|
||||
resp = await config_client.post(
|
||||
@@ -787,7 +907,7 @@ class TestDeactivateJail:
|
||||
message="Jail 'sshd' deactivated successfully.",
|
||||
)
|
||||
with patch(
|
||||
"app.routers.config.config_file_service.deactivate_jail",
|
||||
"app.routers.config.jail_config_service.deactivate_jail",
|
||||
AsyncMock(return_value=mock_response),
|
||||
):
|
||||
resp = await config_client.post("/api/config/jails/sshd/deactivate")
|
||||
@@ -799,10 +919,10 @@ class TestDeactivateJail:
|
||||
|
||||
async def test_404_for_unknown_jail(self, config_client: AsyncClient) -> None:
|
||||
"""POST /api/config/jails/missing/deactivate returns 404."""
|
||||
from app.services.config_file_service import JailNotFoundInConfigError
|
||||
from app.services.jail_config_service import JailNotFoundInConfigError
|
||||
|
||||
with patch(
|
||||
"app.routers.config.config_file_service.deactivate_jail",
|
||||
"app.routers.config.jail_config_service.deactivate_jail",
|
||||
AsyncMock(side_effect=JailNotFoundInConfigError("missing")),
|
||||
):
|
||||
resp = await config_client.post(
|
||||
@@ -813,10 +933,10 @@ class TestDeactivateJail:
|
||||
|
||||
async def test_409_when_already_inactive(self, config_client: AsyncClient) -> None:
|
||||
"""POST /api/config/jails/apache-auth/deactivate returns 409 if already inactive."""
|
||||
from app.services.config_file_service import JailAlreadyInactiveError
|
||||
from app.services.jail_config_service import JailAlreadyInactiveError
|
||||
|
||||
with patch(
|
||||
"app.routers.config.config_file_service.deactivate_jail",
|
||||
"app.routers.config.jail_config_service.deactivate_jail",
|
||||
AsyncMock(side_effect=JailAlreadyInactiveError("apache-auth")),
|
||||
):
|
||||
resp = await config_client.post(
|
||||
@@ -827,10 +947,10 @@ class TestDeactivateJail:
|
||||
|
||||
async def test_400_for_invalid_jail_name(self, config_client: AsyncClient) -> None:
|
||||
"""POST /api/config/jails/.../deactivate with bad name returns 400."""
|
||||
from app.services.config_file_service import JailNameError
|
||||
from app.services.jail_config_service import JailNameError
|
||||
|
||||
with patch(
|
||||
"app.routers.config.config_file_service.deactivate_jail",
|
||||
"app.routers.config.jail_config_service.deactivate_jail",
|
||||
AsyncMock(side_effect=JailNameError("bad")),
|
||||
):
|
||||
resp = await config_client.post(
|
||||
@@ -858,7 +978,7 @@ class TestDeactivateJail:
|
||||
)
|
||||
with (
|
||||
patch(
|
||||
"app.routers.config.config_file_service.deactivate_jail",
|
||||
"app.routers.config.jail_config_service.deactivate_jail",
|
||||
AsyncMock(return_value=mock_response),
|
||||
),
|
||||
patch(
|
||||
@@ -909,7 +1029,7 @@ class TestListFilters:
|
||||
total=1,
|
||||
)
|
||||
with patch(
|
||||
"app.routers.config.config_file_service.list_filters",
|
||||
"app.routers.config.filter_config_service.list_filters",
|
||||
AsyncMock(return_value=mock_response),
|
||||
):
|
||||
resp = await config_client.get("/api/config/filters")
|
||||
@@ -925,7 +1045,7 @@ class TestListFilters:
|
||||
from app.models.config import FilterListResponse
|
||||
|
||||
with patch(
|
||||
"app.routers.config.config_file_service.list_filters",
|
||||
"app.routers.config.filter_config_service.list_filters",
|
||||
AsyncMock(return_value=FilterListResponse(filters=[], total=0)),
|
||||
):
|
||||
resp = await config_client.get("/api/config/filters")
|
||||
@@ -948,7 +1068,7 @@ class TestListFilters:
|
||||
total=2,
|
||||
)
|
||||
with patch(
|
||||
"app.routers.config.config_file_service.list_filters",
|
||||
"app.routers.config.filter_config_service.list_filters",
|
||||
AsyncMock(return_value=mock_response),
|
||||
):
|
||||
resp = await config_client.get("/api/config/filters")
|
||||
@@ -977,7 +1097,7 @@ class TestGetFilter:
|
||||
async def test_200_returns_filter(self, config_client: AsyncClient) -> None:
|
||||
"""GET /api/config/filters/sshd returns 200 with FilterConfig."""
|
||||
with patch(
|
||||
"app.routers.config.config_file_service.get_filter",
|
||||
"app.routers.config.filter_config_service.get_filter",
|
||||
AsyncMock(return_value=_make_filter_config("sshd")),
|
||||
):
|
||||
resp = await config_client.get("/api/config/filters/sshd")
|
||||
@@ -990,10 +1110,10 @@ class TestGetFilter:
|
||||
|
||||
async def test_404_for_unknown_filter(self, config_client: AsyncClient) -> None:
|
||||
"""GET /api/config/filters/missing returns 404."""
|
||||
from app.services.config_file_service import FilterNotFoundError
|
||||
from app.services.filter_config_service import FilterNotFoundError
|
||||
|
||||
with patch(
|
||||
"app.routers.config.config_file_service.get_filter",
|
||||
"app.routers.config.filter_config_service.get_filter",
|
||||
AsyncMock(side_effect=FilterNotFoundError("missing")),
|
||||
):
|
||||
resp = await config_client.get("/api/config/filters/missing")
|
||||
@@ -1020,7 +1140,7 @@ class TestUpdateFilter:
|
||||
async def test_200_returns_updated_filter(self, config_client: AsyncClient) -> None:
|
||||
"""PUT /api/config/filters/sshd returns 200 with updated FilterConfig."""
|
||||
with patch(
|
||||
"app.routers.config.config_file_service.update_filter",
|
||||
"app.routers.config.filter_config_service.update_filter",
|
||||
AsyncMock(return_value=_make_filter_config("sshd")),
|
||||
):
|
||||
resp = await config_client.put(
|
||||
@@ -1033,10 +1153,10 @@ class TestUpdateFilter:
|
||||
|
||||
async def test_404_for_unknown_filter(self, config_client: AsyncClient) -> None:
|
||||
"""PUT /api/config/filters/missing returns 404."""
|
||||
from app.services.config_file_service import FilterNotFoundError
|
||||
from app.services.filter_config_service import FilterNotFoundError
|
||||
|
||||
with patch(
|
||||
"app.routers.config.config_file_service.update_filter",
|
||||
"app.routers.config.filter_config_service.update_filter",
|
||||
AsyncMock(side_effect=FilterNotFoundError("missing")),
|
||||
):
|
||||
resp = await config_client.put(
|
||||
@@ -1048,10 +1168,10 @@ class TestUpdateFilter:
|
||||
|
||||
async def test_422_for_invalid_regex(self, config_client: AsyncClient) -> None:
|
||||
"""PUT /api/config/filters/sshd returns 422 for bad regex."""
|
||||
from app.services.config_file_service import FilterInvalidRegexError
|
||||
from app.services.filter_config_service import FilterInvalidRegexError
|
||||
|
||||
with patch(
|
||||
"app.routers.config.config_file_service.update_filter",
|
||||
"app.routers.config.filter_config_service.update_filter",
|
||||
AsyncMock(side_effect=FilterInvalidRegexError("[bad", "unterminated")),
|
||||
):
|
||||
resp = await config_client.put(
|
||||
@@ -1063,10 +1183,10 @@ class TestUpdateFilter:
|
||||
|
||||
async def test_400_for_invalid_name(self, config_client: AsyncClient) -> None:
|
||||
"""PUT /api/config/filters/... with bad name returns 400."""
|
||||
from app.services.config_file_service import FilterNameError
|
||||
from app.services.filter_config_service import FilterNameError
|
||||
|
||||
with patch(
|
||||
"app.routers.config.config_file_service.update_filter",
|
||||
"app.routers.config.filter_config_service.update_filter",
|
||||
AsyncMock(side_effect=FilterNameError("bad")),
|
||||
):
|
||||
resp = await config_client.put(
|
||||
@@ -1079,7 +1199,7 @@ class TestUpdateFilter:
|
||||
async def test_reload_query_param_passed(self, config_client: AsyncClient) -> None:
|
||||
"""PUT /api/config/filters/sshd?reload=true passes do_reload=True."""
|
||||
with patch(
|
||||
"app.routers.config.config_file_service.update_filter",
|
||||
"app.routers.config.filter_config_service.update_filter",
|
||||
AsyncMock(return_value=_make_filter_config("sshd")),
|
||||
) as mock_update:
|
||||
resp = await config_client.put(
|
||||
@@ -1110,7 +1230,7 @@ class TestCreateFilter:
|
||||
async def test_201_creates_filter(self, config_client: AsyncClient) -> None:
|
||||
"""POST /api/config/filters returns 201 with FilterConfig."""
|
||||
with patch(
|
||||
"app.routers.config.config_file_service.create_filter",
|
||||
"app.routers.config.filter_config_service.create_filter",
|
||||
AsyncMock(return_value=_make_filter_config("my-custom")),
|
||||
):
|
||||
resp = await config_client.post(
|
||||
@@ -1123,10 +1243,10 @@ class TestCreateFilter:
|
||||
|
||||
async def test_409_when_already_exists(self, config_client: AsyncClient) -> None:
|
||||
"""POST /api/config/filters returns 409 if filter exists."""
|
||||
from app.services.config_file_service import FilterAlreadyExistsError
|
||||
from app.services.filter_config_service import FilterAlreadyExistsError
|
||||
|
||||
with patch(
|
||||
"app.routers.config.config_file_service.create_filter",
|
||||
"app.routers.config.filter_config_service.create_filter",
|
||||
AsyncMock(side_effect=FilterAlreadyExistsError("sshd")),
|
||||
):
|
||||
resp = await config_client.post(
|
||||
@@ -1138,10 +1258,10 @@ class TestCreateFilter:
|
||||
|
||||
async def test_422_for_invalid_regex(self, config_client: AsyncClient) -> None:
|
||||
"""POST /api/config/filters returns 422 for bad regex."""
|
||||
from app.services.config_file_service import FilterInvalidRegexError
|
||||
from app.services.filter_config_service import FilterInvalidRegexError
|
||||
|
||||
with patch(
|
||||
"app.routers.config.config_file_service.create_filter",
|
||||
"app.routers.config.filter_config_service.create_filter",
|
||||
AsyncMock(side_effect=FilterInvalidRegexError("[bad", "unterminated")),
|
||||
):
|
||||
resp = await config_client.post(
|
||||
@@ -1153,10 +1273,10 @@ class TestCreateFilter:
|
||||
|
||||
async def test_400_for_invalid_name(self, config_client: AsyncClient) -> None:
|
||||
"""POST /api/config/filters returns 400 for invalid filter name."""
|
||||
from app.services.config_file_service import FilterNameError
|
||||
from app.services.filter_config_service import FilterNameError
|
||||
|
||||
with patch(
|
||||
"app.routers.config.config_file_service.create_filter",
|
||||
"app.routers.config.filter_config_service.create_filter",
|
||||
AsyncMock(side_effect=FilterNameError("bad")),
|
||||
):
|
||||
resp = await config_client.post(
|
||||
@@ -1186,7 +1306,7 @@ class TestDeleteFilter:
|
||||
async def test_204_deletes_filter(self, config_client: AsyncClient) -> None:
|
||||
"""DELETE /api/config/filters/my-custom returns 204."""
|
||||
with patch(
|
||||
"app.routers.config.config_file_service.delete_filter",
|
||||
"app.routers.config.filter_config_service.delete_filter",
|
||||
AsyncMock(return_value=None),
|
||||
):
|
||||
resp = await config_client.delete("/api/config/filters/my-custom")
|
||||
@@ -1195,10 +1315,10 @@ class TestDeleteFilter:
|
||||
|
||||
async def test_404_for_unknown_filter(self, config_client: AsyncClient) -> None:
|
||||
"""DELETE /api/config/filters/missing returns 404."""
|
||||
from app.services.config_file_service import FilterNotFoundError
|
||||
from app.services.filter_config_service import FilterNotFoundError
|
||||
|
||||
with patch(
|
||||
"app.routers.config.config_file_service.delete_filter",
|
||||
"app.routers.config.filter_config_service.delete_filter",
|
||||
AsyncMock(side_effect=FilterNotFoundError("missing")),
|
||||
):
|
||||
resp = await config_client.delete("/api/config/filters/missing")
|
||||
@@ -1207,10 +1327,10 @@ class TestDeleteFilter:
|
||||
|
||||
async def test_409_for_readonly_filter(self, config_client: AsyncClient) -> None:
|
||||
"""DELETE /api/config/filters/sshd returns 409 for shipped conf-only filter."""
|
||||
from app.services.config_file_service import FilterReadonlyError
|
||||
from app.services.filter_config_service import FilterReadonlyError
|
||||
|
||||
with patch(
|
||||
"app.routers.config.config_file_service.delete_filter",
|
||||
"app.routers.config.filter_config_service.delete_filter",
|
||||
AsyncMock(side_effect=FilterReadonlyError("sshd")),
|
||||
):
|
||||
resp = await config_client.delete("/api/config/filters/sshd")
|
||||
@@ -1219,10 +1339,10 @@ class TestDeleteFilter:
|
||||
|
||||
async def test_400_for_invalid_name(self, config_client: AsyncClient) -> None:
|
||||
"""DELETE /api/config/filters/... with bad name returns 400."""
|
||||
from app.services.config_file_service import FilterNameError
|
||||
from app.services.filter_config_service import FilterNameError
|
||||
|
||||
with patch(
|
||||
"app.routers.config.config_file_service.delete_filter",
|
||||
"app.routers.config.filter_config_service.delete_filter",
|
||||
AsyncMock(side_effect=FilterNameError("bad")),
|
||||
):
|
||||
resp = await config_client.delete("/api/config/filters/bad")
|
||||
@@ -1249,7 +1369,7 @@ class TestAssignFilterToJail:
|
||||
async def test_204_assigns_filter(self, config_client: AsyncClient) -> None:
|
||||
"""POST /api/config/jails/sshd/filter returns 204 on success."""
|
||||
with patch(
|
||||
"app.routers.config.config_file_service.assign_filter_to_jail",
|
||||
"app.routers.config.filter_config_service.assign_filter_to_jail",
|
||||
AsyncMock(return_value=None),
|
||||
):
|
||||
resp = await config_client.post(
|
||||
@@ -1261,10 +1381,10 @@ class TestAssignFilterToJail:
|
||||
|
||||
async def test_404_for_unknown_jail(self, config_client: AsyncClient) -> None:
|
||||
"""POST /api/config/jails/missing/filter returns 404."""
|
||||
from app.services.config_file_service import JailNotFoundInConfigError
|
||||
from app.services.jail_config_service import JailNotFoundInConfigError
|
||||
|
||||
with patch(
|
||||
"app.routers.config.config_file_service.assign_filter_to_jail",
|
||||
"app.routers.config.filter_config_service.assign_filter_to_jail",
|
||||
AsyncMock(side_effect=JailNotFoundInConfigError("missing")),
|
||||
):
|
||||
resp = await config_client.post(
|
||||
@@ -1276,10 +1396,10 @@ class TestAssignFilterToJail:
|
||||
|
||||
async def test_404_for_unknown_filter(self, config_client: AsyncClient) -> None:
|
||||
"""POST /api/config/jails/sshd/filter returns 404 when filter not found."""
|
||||
from app.services.config_file_service import FilterNotFoundError
|
||||
from app.services.filter_config_service import FilterNotFoundError
|
||||
|
||||
with patch(
|
||||
"app.routers.config.config_file_service.assign_filter_to_jail",
|
||||
"app.routers.config.filter_config_service.assign_filter_to_jail",
|
||||
AsyncMock(side_effect=FilterNotFoundError("missing-filter")),
|
||||
):
|
||||
resp = await config_client.post(
|
||||
@@ -1291,10 +1411,10 @@ class TestAssignFilterToJail:
|
||||
|
||||
async def test_400_for_invalid_jail_name(self, config_client: AsyncClient) -> None:
|
||||
"""POST /api/config/jails/.../filter with bad jail name returns 400."""
|
||||
from app.services.config_file_service import JailNameError
|
||||
from app.services.jail_config_service import JailNameError
|
||||
|
||||
with patch(
|
||||
"app.routers.config.config_file_service.assign_filter_to_jail",
|
||||
"app.routers.config.filter_config_service.assign_filter_to_jail",
|
||||
AsyncMock(side_effect=JailNameError("bad")),
|
||||
):
|
||||
resp = await config_client.post(
|
||||
@@ -1306,10 +1426,10 @@ class TestAssignFilterToJail:
|
||||
|
||||
async def test_400_for_invalid_filter_name(self, config_client: AsyncClient) -> None:
|
||||
"""POST /api/config/jails/sshd/filter with bad filter name returns 400."""
|
||||
from app.services.config_file_service import FilterNameError
|
||||
from app.services.filter_config_service import FilterNameError
|
||||
|
||||
with patch(
|
||||
"app.routers.config.config_file_service.assign_filter_to_jail",
|
||||
"app.routers.config.filter_config_service.assign_filter_to_jail",
|
||||
AsyncMock(side_effect=FilterNameError("bad")),
|
||||
):
|
||||
resp = await config_client.post(
|
||||
@@ -1322,7 +1442,7 @@ class TestAssignFilterToJail:
|
||||
async def test_reload_query_param_passed(self, config_client: AsyncClient) -> None:
|
||||
"""POST /api/config/jails/sshd/filter?reload=true passes do_reload=True."""
|
||||
with patch(
|
||||
"app.routers.config.config_file_service.assign_filter_to_jail",
|
||||
"app.routers.config.filter_config_service.assign_filter_to_jail",
|
||||
AsyncMock(return_value=None),
|
||||
) as mock_assign:
|
||||
resp = await config_client.post(
|
||||
@@ -1360,7 +1480,7 @@ class TestListActionsRouter:
|
||||
mock_response = ActionListResponse(actions=[mock_action], total=1)
|
||||
|
||||
with patch(
|
||||
"app.routers.config.config_file_service.list_actions",
|
||||
"app.routers.config.action_config_service.list_actions",
|
||||
AsyncMock(return_value=mock_response),
|
||||
):
|
||||
resp = await config_client.get("/api/config/actions")
|
||||
@@ -1378,7 +1498,7 @@ class TestListActionsRouter:
|
||||
mock_response = ActionListResponse(actions=[inactive, active], total=2)
|
||||
|
||||
with patch(
|
||||
"app.routers.config.config_file_service.list_actions",
|
||||
"app.routers.config.action_config_service.list_actions",
|
||||
AsyncMock(return_value=mock_response),
|
||||
):
|
||||
resp = await config_client.get("/api/config/actions")
|
||||
@@ -1406,7 +1526,7 @@ class TestGetActionRouter:
|
||||
)
|
||||
|
||||
with patch(
|
||||
"app.routers.config.config_file_service.get_action",
|
||||
"app.routers.config.action_config_service.get_action",
|
||||
AsyncMock(return_value=mock_action),
|
||||
):
|
||||
resp = await config_client.get("/api/config/actions/iptables")
|
||||
@@ -1415,10 +1535,10 @@ class TestGetActionRouter:
|
||||
assert resp.json()["name"] == "iptables"
|
||||
|
||||
async def test_404_when_not_found(self, config_client: AsyncClient) -> None:
|
||||
from app.services.config_file_service import ActionNotFoundError
|
||||
from app.services.action_config_service import ActionNotFoundError
|
||||
|
||||
with patch(
|
||||
"app.routers.config.config_file_service.get_action",
|
||||
"app.routers.config.action_config_service.get_action",
|
||||
AsyncMock(side_effect=ActionNotFoundError("missing")),
|
||||
):
|
||||
resp = await config_client.get("/api/config/actions/missing")
|
||||
@@ -1445,7 +1565,7 @@ class TestUpdateActionRouter:
|
||||
)
|
||||
|
||||
with patch(
|
||||
"app.routers.config.config_file_service.update_action",
|
||||
"app.routers.config.action_config_service.update_action",
|
||||
AsyncMock(return_value=updated),
|
||||
):
|
||||
resp = await config_client.put(
|
||||
@@ -1457,10 +1577,10 @@ class TestUpdateActionRouter:
|
||||
assert resp.json()["actionban"] == "echo ban"
|
||||
|
||||
async def test_404_when_not_found(self, config_client: AsyncClient) -> None:
|
||||
from app.services.config_file_service import ActionNotFoundError
|
||||
from app.services.action_config_service import ActionNotFoundError
|
||||
|
||||
with patch(
|
||||
"app.routers.config.config_file_service.update_action",
|
||||
"app.routers.config.action_config_service.update_action",
|
||||
AsyncMock(side_effect=ActionNotFoundError("missing")),
|
||||
):
|
||||
resp = await config_client.put(
|
||||
@@ -1470,10 +1590,10 @@ class TestUpdateActionRouter:
|
||||
assert resp.status_code == 404
|
||||
|
||||
async def test_400_for_bad_name(self, config_client: AsyncClient) -> None:
|
||||
from app.services.config_file_service import ActionNameError
|
||||
from app.services.action_config_service import ActionNameError
|
||||
|
||||
with patch(
|
||||
"app.routers.config.config_file_service.update_action",
|
||||
"app.routers.config.action_config_service.update_action",
|
||||
AsyncMock(side_effect=ActionNameError()),
|
||||
):
|
||||
resp = await config_client.put(
|
||||
@@ -1502,7 +1622,7 @@ class TestCreateActionRouter:
|
||||
)
|
||||
|
||||
with patch(
|
||||
"app.routers.config.config_file_service.create_action",
|
||||
"app.routers.config.action_config_service.create_action",
|
||||
AsyncMock(return_value=created),
|
||||
):
|
||||
resp = await config_client.post(
|
||||
@@ -1514,10 +1634,10 @@ class TestCreateActionRouter:
|
||||
assert resp.json()["name"] == "custom"
|
||||
|
||||
async def test_409_when_already_exists(self, config_client: AsyncClient) -> None:
|
||||
from app.services.config_file_service import ActionAlreadyExistsError
|
||||
from app.services.action_config_service import ActionAlreadyExistsError
|
||||
|
||||
with patch(
|
||||
"app.routers.config.config_file_service.create_action",
|
||||
"app.routers.config.action_config_service.create_action",
|
||||
AsyncMock(side_effect=ActionAlreadyExistsError("iptables")),
|
||||
):
|
||||
resp = await config_client.post(
|
||||
@@ -1528,10 +1648,10 @@ class TestCreateActionRouter:
|
||||
assert resp.status_code == 409
|
||||
|
||||
async def test_400_for_bad_name(self, config_client: AsyncClient) -> None:
|
||||
from app.services.config_file_service import ActionNameError
|
||||
from app.services.action_config_service import ActionNameError
|
||||
|
||||
with patch(
|
||||
"app.routers.config.config_file_service.create_action",
|
||||
"app.routers.config.action_config_service.create_action",
|
||||
AsyncMock(side_effect=ActionNameError()),
|
||||
):
|
||||
resp = await config_client.post(
|
||||
@@ -1553,7 +1673,7 @@ class TestCreateActionRouter:
|
||||
class TestDeleteActionRouter:
|
||||
async def test_204_on_delete(self, config_client: AsyncClient) -> None:
|
||||
with patch(
|
||||
"app.routers.config.config_file_service.delete_action",
|
||||
"app.routers.config.action_config_service.delete_action",
|
||||
AsyncMock(return_value=None),
|
||||
):
|
||||
resp = await config_client.delete("/api/config/actions/custom")
|
||||
@@ -1561,10 +1681,10 @@ class TestDeleteActionRouter:
|
||||
assert resp.status_code == 204
|
||||
|
||||
async def test_404_when_not_found(self, config_client: AsyncClient) -> None:
|
||||
from app.services.config_file_service import ActionNotFoundError
|
||||
from app.services.action_config_service import ActionNotFoundError
|
||||
|
||||
with patch(
|
||||
"app.routers.config.config_file_service.delete_action",
|
||||
"app.routers.config.action_config_service.delete_action",
|
||||
AsyncMock(side_effect=ActionNotFoundError("missing")),
|
||||
):
|
||||
resp = await config_client.delete("/api/config/actions/missing")
|
||||
@@ -1572,10 +1692,10 @@ class TestDeleteActionRouter:
|
||||
assert resp.status_code == 404
|
||||
|
||||
async def test_409_when_readonly(self, config_client: AsyncClient) -> None:
|
||||
from app.services.config_file_service import ActionReadonlyError
|
||||
from app.services.action_config_service import ActionReadonlyError
|
||||
|
||||
with patch(
|
||||
"app.routers.config.config_file_service.delete_action",
|
||||
"app.routers.config.action_config_service.delete_action",
|
||||
AsyncMock(side_effect=ActionReadonlyError("iptables")),
|
||||
):
|
||||
resp = await config_client.delete("/api/config/actions/iptables")
|
||||
@@ -1583,10 +1703,10 @@ class TestDeleteActionRouter:
|
||||
assert resp.status_code == 409
|
||||
|
||||
async def test_400_for_bad_name(self, config_client: AsyncClient) -> None:
|
||||
from app.services.config_file_service import ActionNameError
|
||||
from app.services.action_config_service import ActionNameError
|
||||
|
||||
with patch(
|
||||
"app.routers.config.config_file_service.delete_action",
|
||||
"app.routers.config.action_config_service.delete_action",
|
||||
AsyncMock(side_effect=ActionNameError()),
|
||||
):
|
||||
resp = await config_client.delete("/api/config/actions/badname")
|
||||
@@ -1605,7 +1725,7 @@ class TestDeleteActionRouter:
|
||||
class TestAssignActionToJailRouter:
|
||||
async def test_204_on_success(self, config_client: AsyncClient) -> None:
|
||||
with patch(
|
||||
"app.routers.config.config_file_service.assign_action_to_jail",
|
||||
"app.routers.config.action_config_service.assign_action_to_jail",
|
||||
AsyncMock(return_value=None),
|
||||
):
|
||||
resp = await config_client.post(
|
||||
@@ -1616,10 +1736,10 @@ class TestAssignActionToJailRouter:
|
||||
assert resp.status_code == 204
|
||||
|
||||
async def test_404_when_jail_not_found(self, config_client: AsyncClient) -> None:
|
||||
from app.services.config_file_service import JailNotFoundInConfigError
|
||||
from app.services.jail_config_service import JailNotFoundInConfigError
|
||||
|
||||
with patch(
|
||||
"app.routers.config.config_file_service.assign_action_to_jail",
|
||||
"app.routers.config.action_config_service.assign_action_to_jail",
|
||||
AsyncMock(side_effect=JailNotFoundInConfigError("missing")),
|
||||
):
|
||||
resp = await config_client.post(
|
||||
@@ -1630,10 +1750,10 @@ class TestAssignActionToJailRouter:
|
||||
assert resp.status_code == 404
|
||||
|
||||
async def test_404_when_action_not_found(self, config_client: AsyncClient) -> None:
|
||||
from app.services.config_file_service import ActionNotFoundError
|
||||
from app.services.action_config_service import ActionNotFoundError
|
||||
|
||||
with patch(
|
||||
"app.routers.config.config_file_service.assign_action_to_jail",
|
||||
"app.routers.config.action_config_service.assign_action_to_jail",
|
||||
AsyncMock(side_effect=ActionNotFoundError("missing")),
|
||||
):
|
||||
resp = await config_client.post(
|
||||
@@ -1644,10 +1764,10 @@ class TestAssignActionToJailRouter:
|
||||
assert resp.status_code == 404
|
||||
|
||||
async def test_400_for_bad_jail_name(self, config_client: AsyncClient) -> None:
|
||||
from app.services.config_file_service import JailNameError
|
||||
from app.services.jail_config_service import JailNameError
|
||||
|
||||
with patch(
|
||||
"app.routers.config.config_file_service.assign_action_to_jail",
|
||||
"app.routers.config.action_config_service.assign_action_to_jail",
|
||||
AsyncMock(side_effect=JailNameError()),
|
||||
):
|
||||
resp = await config_client.post(
|
||||
@@ -1658,10 +1778,10 @@ class TestAssignActionToJailRouter:
|
||||
assert resp.status_code == 400
|
||||
|
||||
async def test_400_for_bad_action_name(self, config_client: AsyncClient) -> None:
|
||||
from app.services.config_file_service import ActionNameError
|
||||
from app.services.action_config_service import ActionNameError
|
||||
|
||||
with patch(
|
||||
"app.routers.config.config_file_service.assign_action_to_jail",
|
||||
"app.routers.config.action_config_service.assign_action_to_jail",
|
||||
AsyncMock(side_effect=ActionNameError()),
|
||||
):
|
||||
resp = await config_client.post(
|
||||
@@ -1673,7 +1793,7 @@ class TestAssignActionToJailRouter:
|
||||
|
||||
async def test_reload_param_passed(self, config_client: AsyncClient) -> None:
|
||||
with patch(
|
||||
"app.routers.config.config_file_service.assign_action_to_jail",
|
||||
"app.routers.config.action_config_service.assign_action_to_jail",
|
||||
AsyncMock(return_value=None),
|
||||
) as mock_assign:
|
||||
resp = await config_client.post(
|
||||
@@ -1696,7 +1816,7 @@ class TestAssignActionToJailRouter:
|
||||
class TestRemoveActionFromJailRouter:
|
||||
async def test_204_on_success(self, config_client: AsyncClient) -> None:
|
||||
with patch(
|
||||
"app.routers.config.config_file_service.remove_action_from_jail",
|
||||
"app.routers.config.action_config_service.remove_action_from_jail",
|
||||
AsyncMock(return_value=None),
|
||||
):
|
||||
resp = await config_client.delete(
|
||||
@@ -1706,10 +1826,10 @@ class TestRemoveActionFromJailRouter:
|
||||
assert resp.status_code == 204
|
||||
|
||||
async def test_404_when_jail_not_found(self, config_client: AsyncClient) -> None:
|
||||
from app.services.config_file_service import JailNotFoundInConfigError
|
||||
from app.services.jail_config_service import JailNotFoundInConfigError
|
||||
|
||||
with patch(
|
||||
"app.routers.config.config_file_service.remove_action_from_jail",
|
||||
"app.routers.config.action_config_service.remove_action_from_jail",
|
||||
AsyncMock(side_effect=JailNotFoundInConfigError("missing")),
|
||||
):
|
||||
resp = await config_client.delete(
|
||||
@@ -1719,10 +1839,10 @@ class TestRemoveActionFromJailRouter:
|
||||
assert resp.status_code == 404
|
||||
|
||||
async def test_400_for_bad_jail_name(self, config_client: AsyncClient) -> None:
|
||||
from app.services.config_file_service import JailNameError
|
||||
from app.services.jail_config_service import JailNameError
|
||||
|
||||
with patch(
|
||||
"app.routers.config.config_file_service.remove_action_from_jail",
|
||||
"app.routers.config.action_config_service.remove_action_from_jail",
|
||||
AsyncMock(side_effect=JailNameError()),
|
||||
):
|
||||
resp = await config_client.delete(
|
||||
@@ -1732,10 +1852,10 @@ class TestRemoveActionFromJailRouter:
|
||||
assert resp.status_code == 400
|
||||
|
||||
async def test_400_for_bad_action_name(self, config_client: AsyncClient) -> None:
|
||||
from app.services.config_file_service import ActionNameError
|
||||
from app.services.action_config_service import ActionNameError
|
||||
|
||||
with patch(
|
||||
"app.routers.config.config_file_service.remove_action_from_jail",
|
||||
"app.routers.config.action_config_service.remove_action_from_jail",
|
||||
AsyncMock(side_effect=ActionNameError()),
|
||||
):
|
||||
resp = await config_client.delete(
|
||||
@@ -1746,7 +1866,7 @@ class TestRemoveActionFromJailRouter:
|
||||
|
||||
async def test_reload_param_passed(self, config_client: AsyncClient) -> None:
|
||||
with patch(
|
||||
"app.routers.config.config_file_service.remove_action_from_jail",
|
||||
"app.routers.config.action_config_service.remove_action_from_jail",
|
||||
AsyncMock(return_value=None),
|
||||
) as mock_rm:
|
||||
resp = await config_client.delete(
|
||||
@@ -1881,7 +2001,7 @@ class TestGetServiceStatus:
|
||||
def _mock_status(self, online: bool = True) -> ServiceStatusResponse:
|
||||
return ServiceStatusResponse(
|
||||
online=online,
|
||||
version="1.0.0" if online else None,
|
||||
version=app.__version__,
|
||||
jail_count=2 if online else 0,
|
||||
total_bans=10 if online else 0,
|
||||
total_failures=3 if online else 0,
|
||||
@@ -1900,6 +2020,7 @@ class TestGetServiceStatus:
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["online"] is True
|
||||
assert data["version"] == app.__version__
|
||||
assert data["jail_count"] == 2
|
||||
assert data["log_level"] == "INFO"
|
||||
|
||||
@@ -1913,6 +2034,7 @@ class TestGetServiceStatus:
|
||||
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["version"] == app.__version__
|
||||
assert data["online"] is False
|
||||
assert data["log_level"] == "UNKNOWN"
|
||||
|
||||
@@ -1942,7 +2064,7 @@ class TestValidateJailEndpoint:
|
||||
jail_name="sshd", valid=True, issues=[]
|
||||
)
|
||||
with patch(
|
||||
"app.routers.config.config_file_service.validate_jail_config",
|
||||
"app.routers.config.jail_config_service.validate_jail_config",
|
||||
AsyncMock(return_value=mock_result),
|
||||
):
|
||||
resp = await config_client.post("/api/config/jails/sshd/validate")
|
||||
@@ -1962,7 +2084,7 @@ class TestValidateJailEndpoint:
|
||||
jail_name="sshd", valid=False, issues=[issue]
|
||||
)
|
||||
with patch(
|
||||
"app.routers.config.config_file_service.validate_jail_config",
|
||||
"app.routers.config.jail_config_service.validate_jail_config",
|
||||
AsyncMock(return_value=mock_result),
|
||||
):
|
||||
resp = await config_client.post("/api/config/jails/sshd/validate")
|
||||
@@ -1975,10 +2097,10 @@ class TestValidateJailEndpoint:
|
||||
|
||||
async def test_400_for_invalid_jail_name(self, config_client: AsyncClient) -> None:
|
||||
"""POST /api/config/jails/bad-name/validate returns 400 on JailNameError."""
|
||||
from app.services.config_file_service import JailNameError
|
||||
from app.services.jail_config_service import JailNameError
|
||||
|
||||
with patch(
|
||||
"app.routers.config.config_file_service.validate_jail_config",
|
||||
"app.routers.config.jail_config_service.validate_jail_config",
|
||||
AsyncMock(side_effect=JailNameError("bad name")),
|
||||
):
|
||||
resp = await config_client.post("/api/config/jails/bad-name/validate")
|
||||
@@ -2070,7 +2192,7 @@ class TestRollbackEndpoint:
|
||||
message="Jail 'sshd' disabled and fail2ban restarted.",
|
||||
)
|
||||
with patch(
|
||||
"app.routers.config.config_file_service.rollback_jail",
|
||||
"app.routers.config.jail_config_service.rollback_jail",
|
||||
AsyncMock(return_value=mock_result),
|
||||
):
|
||||
resp = await config_client.post("/api/config/jails/sshd/rollback")
|
||||
@@ -2107,7 +2229,7 @@ class TestRollbackEndpoint:
|
||||
message="fail2ban did not come back online.",
|
||||
)
|
||||
with patch(
|
||||
"app.routers.config.config_file_service.rollback_jail",
|
||||
"app.routers.config.jail_config_service.rollback_jail",
|
||||
AsyncMock(return_value=mock_result),
|
||||
):
|
||||
resp = await config_client.post("/api/config/jails/sshd/rollback")
|
||||
@@ -2120,10 +2242,10 @@ class TestRollbackEndpoint:
|
||||
|
||||
async def test_400_for_invalid_jail_name(self, config_client: AsyncClient) -> None:
|
||||
"""POST /api/config/jails/bad/rollback returns 400 on JailNameError."""
|
||||
from app.services.config_file_service import JailNameError
|
||||
from app.services.jail_config_service import JailNameError
|
||||
|
||||
with patch(
|
||||
"app.routers.config.config_file_service.rollback_jail",
|
||||
"app.routers.config.jail_config_service.rollback_jail",
|
||||
AsyncMock(side_effect=JailNameError("bad")),
|
||||
):
|
||||
resp = await config_client.post("/api/config/jails/bad/rollback")
|
||||
|
||||
@@ -9,6 +9,8 @@ import aiosqlite
|
||||
import pytest
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
|
||||
import app
|
||||
|
||||
from app.config import Settings
|
||||
from app.db import init_db
|
||||
from app.main import create_app
|
||||
@@ -151,6 +153,7 @@ class TestDashboardStatus:
|
||||
body = response.json()
|
||||
|
||||
assert "status" in body
|
||||
|
||||
status = body["status"]
|
||||
assert "online" in status
|
||||
assert "version" in status
|
||||
@@ -163,10 +166,11 @@ class TestDashboardStatus:
|
||||
) -> None:
|
||||
"""Endpoint returns the exact values from ``app.state.server_status``."""
|
||||
response = await dashboard_client.get("/api/dashboard/status")
|
||||
status = response.json()["status"]
|
||||
body = response.json()
|
||||
status = body["status"]
|
||||
|
||||
assert status["online"] is True
|
||||
assert status["version"] == "1.0.2"
|
||||
assert status["version"] == app.__version__
|
||||
assert status["active_jails"] == 2
|
||||
assert status["total_bans"] == 10
|
||||
assert status["total_failures"] == 5
|
||||
@@ -177,10 +181,11 @@ class TestDashboardStatus:
|
||||
"""Endpoint returns online=False when the cache holds an offline snapshot."""
|
||||
response = await offline_dashboard_client.get("/api/dashboard/status")
|
||||
assert response.status_code == 200
|
||||
status = response.json()["status"]
|
||||
body = response.json()
|
||||
status = body["status"]
|
||||
|
||||
assert status["online"] is False
|
||||
assert status["version"] is None
|
||||
assert status["version"] == app.__version__
|
||||
assert status["active_jails"] == 0
|
||||
assert status["total_bans"] == 0
|
||||
assert status["total_failures"] == 0
|
||||
|
||||
@@ -26,7 +26,7 @@ from app.models.file_config import (
|
||||
JailConfigFileContent,
|
||||
JailConfigFilesResponse,
|
||||
)
|
||||
from app.services.file_config_service import (
|
||||
from app.services.raw_config_io_service import (
|
||||
ConfigDirError,
|
||||
ConfigFileExistsError,
|
||||
ConfigFileNameError,
|
||||
@@ -112,7 +112,7 @@ class TestListJailConfigFiles:
|
||||
self, file_config_client: AsyncClient
|
||||
) -> None:
|
||||
with patch(
|
||||
"app.routers.file_config.file_config_service.list_jail_config_files",
|
||||
"app.routers.file_config.raw_config_io_service.list_jail_config_files",
|
||||
AsyncMock(return_value=_jail_files_resp()),
|
||||
):
|
||||
resp = await file_config_client.get("/api/config/jail-files")
|
||||
@@ -126,7 +126,7 @@ class TestListJailConfigFiles:
|
||||
self, file_config_client: AsyncClient
|
||||
) -> None:
|
||||
with patch(
|
||||
"app.routers.file_config.file_config_service.list_jail_config_files",
|
||||
"app.routers.file_config.raw_config_io_service.list_jail_config_files",
|
||||
AsyncMock(side_effect=ConfigDirError("not found")),
|
||||
):
|
||||
resp = await file_config_client.get("/api/config/jail-files")
|
||||
@@ -157,7 +157,7 @@ class TestGetJailConfigFile:
|
||||
content="[sshd]\nenabled = true\n",
|
||||
)
|
||||
with patch(
|
||||
"app.routers.file_config.file_config_service.get_jail_config_file",
|
||||
"app.routers.file_config.raw_config_io_service.get_jail_config_file",
|
||||
AsyncMock(return_value=content),
|
||||
):
|
||||
resp = await file_config_client.get("/api/config/jail-files/sshd.conf")
|
||||
@@ -167,7 +167,7 @@ class TestGetJailConfigFile:
|
||||
|
||||
async def test_404_not_found(self, file_config_client: AsyncClient) -> None:
|
||||
with patch(
|
||||
"app.routers.file_config.file_config_service.get_jail_config_file",
|
||||
"app.routers.file_config.raw_config_io_service.get_jail_config_file",
|
||||
AsyncMock(side_effect=ConfigFileNotFoundError("missing.conf")),
|
||||
):
|
||||
resp = await file_config_client.get("/api/config/jail-files/missing.conf")
|
||||
@@ -178,7 +178,7 @@ class TestGetJailConfigFile:
|
||||
self, file_config_client: AsyncClient
|
||||
) -> None:
|
||||
with patch(
|
||||
"app.routers.file_config.file_config_service.get_jail_config_file",
|
||||
"app.routers.file_config.raw_config_io_service.get_jail_config_file",
|
||||
AsyncMock(side_effect=ConfigFileNameError("bad name")),
|
||||
):
|
||||
resp = await file_config_client.get("/api/config/jail-files/bad.txt")
|
||||
@@ -194,7 +194,7 @@ class TestGetJailConfigFile:
|
||||
class TestSetJailConfigEnabled:
|
||||
async def test_204_on_success(self, file_config_client: AsyncClient) -> None:
|
||||
with patch(
|
||||
"app.routers.file_config.file_config_service.set_jail_config_enabled",
|
||||
"app.routers.file_config.raw_config_io_service.set_jail_config_enabled",
|
||||
AsyncMock(return_value=None),
|
||||
):
|
||||
resp = await file_config_client.put(
|
||||
@@ -206,7 +206,7 @@ class TestSetJailConfigEnabled:
|
||||
|
||||
async def test_404_file_not_found(self, file_config_client: AsyncClient) -> None:
|
||||
with patch(
|
||||
"app.routers.file_config.file_config_service.set_jail_config_enabled",
|
||||
"app.routers.file_config.raw_config_io_service.set_jail_config_enabled",
|
||||
AsyncMock(side_effect=ConfigFileNotFoundError("missing.conf")),
|
||||
):
|
||||
resp = await file_config_client.put(
|
||||
@@ -232,7 +232,7 @@ class TestGetFilterFileRaw:
|
||||
|
||||
async def test_200_returns_content(self, file_config_client: AsyncClient) -> None:
|
||||
with patch(
|
||||
"app.routers.file_config.file_config_service.get_filter_file",
|
||||
"app.routers.file_config.raw_config_io_service.get_filter_file",
|
||||
AsyncMock(return_value=_conf_file_content("nginx")),
|
||||
):
|
||||
resp = await file_config_client.get("/api/config/filters/nginx/raw")
|
||||
@@ -242,7 +242,7 @@ class TestGetFilterFileRaw:
|
||||
|
||||
async def test_404_not_found(self, file_config_client: AsyncClient) -> None:
|
||||
with patch(
|
||||
"app.routers.file_config.file_config_service.get_filter_file",
|
||||
"app.routers.file_config.raw_config_io_service.get_filter_file",
|
||||
AsyncMock(side_effect=ConfigFileNotFoundError("missing")),
|
||||
):
|
||||
resp = await file_config_client.get("/api/config/filters/missing/raw")
|
||||
@@ -258,7 +258,7 @@ class TestGetFilterFileRaw:
|
||||
class TestUpdateFilterFile:
|
||||
async def test_204_on_success(self, file_config_client: AsyncClient) -> None:
|
||||
with patch(
|
||||
"app.routers.file_config.file_config_service.write_filter_file",
|
||||
"app.routers.file_config.raw_config_io_service.write_filter_file",
|
||||
AsyncMock(return_value=None),
|
||||
):
|
||||
resp = await file_config_client.put(
|
||||
@@ -270,7 +270,7 @@ class TestUpdateFilterFile:
|
||||
|
||||
async def test_400_write_error(self, file_config_client: AsyncClient) -> None:
|
||||
with patch(
|
||||
"app.routers.file_config.file_config_service.write_filter_file",
|
||||
"app.routers.file_config.raw_config_io_service.write_filter_file",
|
||||
AsyncMock(side_effect=ConfigFileWriteError("disk full")),
|
||||
):
|
||||
resp = await file_config_client.put(
|
||||
@@ -289,7 +289,7 @@ class TestUpdateFilterFile:
|
||||
class TestCreateFilterFile:
|
||||
async def test_201_creates_file(self, file_config_client: AsyncClient) -> None:
|
||||
with patch(
|
||||
"app.routers.file_config.file_config_service.create_filter_file",
|
||||
"app.routers.file_config.raw_config_io_service.create_filter_file",
|
||||
AsyncMock(return_value="myfilter.conf"),
|
||||
):
|
||||
resp = await file_config_client.post(
|
||||
@@ -302,7 +302,7 @@ class TestCreateFilterFile:
|
||||
|
||||
async def test_409_conflict(self, file_config_client: AsyncClient) -> None:
|
||||
with patch(
|
||||
"app.routers.file_config.file_config_service.create_filter_file",
|
||||
"app.routers.file_config.raw_config_io_service.create_filter_file",
|
||||
AsyncMock(side_effect=ConfigFileExistsError("myfilter.conf")),
|
||||
):
|
||||
resp = await file_config_client.post(
|
||||
@@ -314,7 +314,7 @@ class TestCreateFilterFile:
|
||||
|
||||
async def test_400_invalid_name(self, file_config_client: AsyncClient) -> None:
|
||||
with patch(
|
||||
"app.routers.file_config.file_config_service.create_filter_file",
|
||||
"app.routers.file_config.raw_config_io_service.create_filter_file",
|
||||
AsyncMock(side_effect=ConfigFileNameError("bad/../name")),
|
||||
):
|
||||
resp = await file_config_client.post(
|
||||
@@ -342,7 +342,7 @@ class TestListActionFiles:
|
||||
)
|
||||
resp_data = ActionListResponse(actions=[mock_action], total=1)
|
||||
with patch(
|
||||
"app.routers.config.config_file_service.list_actions",
|
||||
"app.routers.config.action_config_service.list_actions",
|
||||
AsyncMock(return_value=resp_data),
|
||||
):
|
||||
resp = await file_config_client.get("/api/config/actions")
|
||||
@@ -365,7 +365,7 @@ class TestCreateActionFile:
|
||||
actionban="echo ban <ip>",
|
||||
)
|
||||
with patch(
|
||||
"app.routers.config.config_file_service.create_action",
|
||||
"app.routers.config.action_config_service.create_action",
|
||||
AsyncMock(return_value=created),
|
||||
):
|
||||
resp = await file_config_client.post(
|
||||
@@ -377,6 +377,102 @@ class TestCreateActionFile:
|
||||
assert resp.json()["name"] == "myaction"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /api/config/actions/{name}/raw
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestGetActionFileRaw:
|
||||
"""Tests for ``GET /api/config/actions/{name}/raw``."""
|
||||
|
||||
async def test_200_returns_content(self, file_config_client: AsyncClient) -> None:
|
||||
with patch(
|
||||
"app.routers.file_config.raw_config_io_service.get_action_file",
|
||||
AsyncMock(return_value=_conf_file_content("iptables")),
|
||||
):
|
||||
resp = await file_config_client.get("/api/config/actions/iptables/raw")
|
||||
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["name"] == "iptables"
|
||||
|
||||
async def test_404_not_found(self, file_config_client: AsyncClient) -> None:
|
||||
with patch(
|
||||
"app.routers.file_config.raw_config_io_service.get_action_file",
|
||||
AsyncMock(side_effect=ConfigFileNotFoundError("missing")),
|
||||
):
|
||||
resp = await file_config_client.get("/api/config/actions/missing/raw")
|
||||
|
||||
assert resp.status_code == 404
|
||||
|
||||
async def test_503_on_config_dir_error(
|
||||
self, file_config_client: AsyncClient
|
||||
) -> None:
|
||||
with patch(
|
||||
"app.routers.file_config.raw_config_io_service.get_action_file",
|
||||
AsyncMock(side_effect=ConfigDirError("no dir")),
|
||||
):
|
||||
resp = await file_config_client.get("/api/config/actions/iptables/raw")
|
||||
|
||||
assert resp.status_code == 503
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# PUT /api/config/actions/{name}/raw
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestUpdateActionFileRaw:
|
||||
"""Tests for ``PUT /api/config/actions/{name}/raw``."""
|
||||
|
||||
async def test_204_on_success(self, file_config_client: AsyncClient) -> None:
|
||||
with patch(
|
||||
"app.routers.file_config.raw_config_io_service.write_action_file",
|
||||
AsyncMock(return_value=None),
|
||||
):
|
||||
resp = await file_config_client.put(
|
||||
"/api/config/actions/iptables/raw",
|
||||
json={"content": "[Definition]\nactionban = iptables -I INPUT -s <ip> -j DROP\n"},
|
||||
)
|
||||
|
||||
assert resp.status_code == 204
|
||||
|
||||
async def test_400_write_error(self, file_config_client: AsyncClient) -> None:
|
||||
with patch(
|
||||
"app.routers.file_config.raw_config_io_service.write_action_file",
|
||||
AsyncMock(side_effect=ConfigFileWriteError("disk full")),
|
||||
):
|
||||
resp = await file_config_client.put(
|
||||
"/api/config/actions/iptables/raw",
|
||||
json={"content": "x"},
|
||||
)
|
||||
|
||||
assert resp.status_code == 400
|
||||
|
||||
async def test_404_not_found(self, file_config_client: AsyncClient) -> None:
|
||||
with patch(
|
||||
"app.routers.file_config.raw_config_io_service.write_action_file",
|
||||
AsyncMock(side_effect=ConfigFileNotFoundError("missing")),
|
||||
):
|
||||
resp = await file_config_client.put(
|
||||
"/api/config/actions/missing/raw",
|
||||
json={"content": "x"},
|
||||
)
|
||||
|
||||
assert resp.status_code == 404
|
||||
|
||||
async def test_400_invalid_name(self, file_config_client: AsyncClient) -> None:
|
||||
with patch(
|
||||
"app.routers.file_config.raw_config_io_service.write_action_file",
|
||||
AsyncMock(side_effect=ConfigFileNameError("bad/../name")),
|
||||
):
|
||||
resp = await file_config_client.put(
|
||||
"/api/config/actions/escape/raw",
|
||||
json={"content": "x"},
|
||||
)
|
||||
|
||||
assert resp.status_code == 400
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# POST /api/config/jail-files
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -385,7 +481,7 @@ class TestCreateActionFile:
|
||||
class TestCreateJailConfigFile:
|
||||
async def test_201_creates_file(self, file_config_client: AsyncClient) -> None:
|
||||
with patch(
|
||||
"app.routers.file_config.file_config_service.create_jail_config_file",
|
||||
"app.routers.file_config.raw_config_io_service.create_jail_config_file",
|
||||
AsyncMock(return_value="myjail.conf"),
|
||||
):
|
||||
resp = await file_config_client.post(
|
||||
@@ -398,7 +494,7 @@ class TestCreateJailConfigFile:
|
||||
|
||||
async def test_409_conflict(self, file_config_client: AsyncClient) -> None:
|
||||
with patch(
|
||||
"app.routers.file_config.file_config_service.create_jail_config_file",
|
||||
"app.routers.file_config.raw_config_io_service.create_jail_config_file",
|
||||
AsyncMock(side_effect=ConfigFileExistsError("myjail.conf")),
|
||||
):
|
||||
resp = await file_config_client.post(
|
||||
@@ -410,7 +506,7 @@ class TestCreateJailConfigFile:
|
||||
|
||||
async def test_400_invalid_name(self, file_config_client: AsyncClient) -> None:
|
||||
with patch(
|
||||
"app.routers.file_config.file_config_service.create_jail_config_file",
|
||||
"app.routers.file_config.raw_config_io_service.create_jail_config_file",
|
||||
AsyncMock(side_effect=ConfigFileNameError("bad/../name")),
|
||||
):
|
||||
resp = await file_config_client.post(
|
||||
@@ -424,7 +520,7 @@ class TestCreateJailConfigFile:
|
||||
self, file_config_client: AsyncClient
|
||||
) -> None:
|
||||
with patch(
|
||||
"app.routers.file_config.file_config_service.create_jail_config_file",
|
||||
"app.routers.file_config.raw_config_io_service.create_jail_config_file",
|
||||
AsyncMock(side_effect=ConfigDirError("no dir")),
|
||||
):
|
||||
resp = await file_config_client.post(
|
||||
@@ -446,7 +542,7 @@ class TestGetParsedFilter:
|
||||
) -> None:
|
||||
cfg = FilterConfig(name="nginx", filename="nginx.conf")
|
||||
with patch(
|
||||
"app.routers.file_config.file_config_service.get_parsed_filter_file",
|
||||
"app.routers.file_config.raw_config_io_service.get_parsed_filter_file",
|
||||
AsyncMock(return_value=cfg),
|
||||
):
|
||||
resp = await file_config_client.get("/api/config/filters/nginx/parsed")
|
||||
@@ -458,7 +554,7 @@ class TestGetParsedFilter:
|
||||
|
||||
async def test_404_not_found(self, file_config_client: AsyncClient) -> None:
|
||||
with patch(
|
||||
"app.routers.file_config.file_config_service.get_parsed_filter_file",
|
||||
"app.routers.file_config.raw_config_io_service.get_parsed_filter_file",
|
||||
AsyncMock(side_effect=ConfigFileNotFoundError("missing")),
|
||||
):
|
||||
resp = await file_config_client.get(
|
||||
@@ -471,7 +567,7 @@ class TestGetParsedFilter:
|
||||
self, file_config_client: AsyncClient
|
||||
) -> None:
|
||||
with patch(
|
||||
"app.routers.file_config.file_config_service.get_parsed_filter_file",
|
||||
"app.routers.file_config.raw_config_io_service.get_parsed_filter_file",
|
||||
AsyncMock(side_effect=ConfigDirError("no dir")),
|
||||
):
|
||||
resp = await file_config_client.get("/api/config/filters/nginx/parsed")
|
||||
@@ -487,7 +583,7 @@ class TestGetParsedFilter:
|
||||
class TestUpdateParsedFilter:
|
||||
async def test_204_on_success(self, file_config_client: AsyncClient) -> None:
|
||||
with patch(
|
||||
"app.routers.file_config.file_config_service.update_parsed_filter_file",
|
||||
"app.routers.file_config.raw_config_io_service.update_parsed_filter_file",
|
||||
AsyncMock(return_value=None),
|
||||
):
|
||||
resp = await file_config_client.put(
|
||||
@@ -499,7 +595,7 @@ class TestUpdateParsedFilter:
|
||||
|
||||
async def test_404_not_found(self, file_config_client: AsyncClient) -> None:
|
||||
with patch(
|
||||
"app.routers.file_config.file_config_service.update_parsed_filter_file",
|
||||
"app.routers.file_config.raw_config_io_service.update_parsed_filter_file",
|
||||
AsyncMock(side_effect=ConfigFileNotFoundError("missing")),
|
||||
):
|
||||
resp = await file_config_client.put(
|
||||
@@ -511,7 +607,7 @@ class TestUpdateParsedFilter:
|
||||
|
||||
async def test_400_write_error(self, file_config_client: AsyncClient) -> None:
|
||||
with patch(
|
||||
"app.routers.file_config.file_config_service.update_parsed_filter_file",
|
||||
"app.routers.file_config.raw_config_io_service.update_parsed_filter_file",
|
||||
AsyncMock(side_effect=ConfigFileWriteError("disk full")),
|
||||
):
|
||||
resp = await file_config_client.put(
|
||||
@@ -533,7 +629,7 @@ class TestGetParsedAction:
|
||||
) -> None:
|
||||
cfg = ActionConfig(name="iptables", filename="iptables.conf")
|
||||
with patch(
|
||||
"app.routers.file_config.file_config_service.get_parsed_action_file",
|
||||
"app.routers.file_config.raw_config_io_service.get_parsed_action_file",
|
||||
AsyncMock(return_value=cfg),
|
||||
):
|
||||
resp = await file_config_client.get(
|
||||
@@ -547,7 +643,7 @@ class TestGetParsedAction:
|
||||
|
||||
async def test_404_not_found(self, file_config_client: AsyncClient) -> None:
|
||||
with patch(
|
||||
"app.routers.file_config.file_config_service.get_parsed_action_file",
|
||||
"app.routers.file_config.raw_config_io_service.get_parsed_action_file",
|
||||
AsyncMock(side_effect=ConfigFileNotFoundError("missing")),
|
||||
):
|
||||
resp = await file_config_client.get(
|
||||
@@ -560,7 +656,7 @@ class TestGetParsedAction:
|
||||
self, file_config_client: AsyncClient
|
||||
) -> None:
|
||||
with patch(
|
||||
"app.routers.file_config.file_config_service.get_parsed_action_file",
|
||||
"app.routers.file_config.raw_config_io_service.get_parsed_action_file",
|
||||
AsyncMock(side_effect=ConfigDirError("no dir")),
|
||||
):
|
||||
resp = await file_config_client.get(
|
||||
@@ -578,7 +674,7 @@ class TestGetParsedAction:
|
||||
class TestUpdateParsedAction:
|
||||
async def test_204_on_success(self, file_config_client: AsyncClient) -> None:
|
||||
with patch(
|
||||
"app.routers.file_config.file_config_service.update_parsed_action_file",
|
||||
"app.routers.file_config.raw_config_io_service.update_parsed_action_file",
|
||||
AsyncMock(return_value=None),
|
||||
):
|
||||
resp = await file_config_client.put(
|
||||
@@ -590,7 +686,7 @@ class TestUpdateParsedAction:
|
||||
|
||||
async def test_404_not_found(self, file_config_client: AsyncClient) -> None:
|
||||
with patch(
|
||||
"app.routers.file_config.file_config_service.update_parsed_action_file",
|
||||
"app.routers.file_config.raw_config_io_service.update_parsed_action_file",
|
||||
AsyncMock(side_effect=ConfigFileNotFoundError("missing")),
|
||||
):
|
||||
resp = await file_config_client.put(
|
||||
@@ -602,7 +698,7 @@ class TestUpdateParsedAction:
|
||||
|
||||
async def test_400_write_error(self, file_config_client: AsyncClient) -> None:
|
||||
with patch(
|
||||
"app.routers.file_config.file_config_service.update_parsed_action_file",
|
||||
"app.routers.file_config.raw_config_io_service.update_parsed_action_file",
|
||||
AsyncMock(side_effect=ConfigFileWriteError("disk full")),
|
||||
):
|
||||
resp = await file_config_client.put(
|
||||
@@ -625,7 +721,7 @@ class TestGetParsedJailFile:
|
||||
section = JailSectionConfig(enabled=True, port="ssh")
|
||||
cfg = JailFileConfig(filename="sshd.conf", jails={"sshd": section})
|
||||
with patch(
|
||||
"app.routers.file_config.file_config_service.get_parsed_jail_file",
|
||||
"app.routers.file_config.raw_config_io_service.get_parsed_jail_file",
|
||||
AsyncMock(return_value=cfg),
|
||||
):
|
||||
resp = await file_config_client.get(
|
||||
@@ -639,7 +735,7 @@ class TestGetParsedJailFile:
|
||||
|
||||
async def test_404_not_found(self, file_config_client: AsyncClient) -> None:
|
||||
with patch(
|
||||
"app.routers.file_config.file_config_service.get_parsed_jail_file",
|
||||
"app.routers.file_config.raw_config_io_service.get_parsed_jail_file",
|
||||
AsyncMock(side_effect=ConfigFileNotFoundError("missing.conf")),
|
||||
):
|
||||
resp = await file_config_client.get(
|
||||
@@ -652,7 +748,7 @@ class TestGetParsedJailFile:
|
||||
self, file_config_client: AsyncClient
|
||||
) -> None:
|
||||
with patch(
|
||||
"app.routers.file_config.file_config_service.get_parsed_jail_file",
|
||||
"app.routers.file_config.raw_config_io_service.get_parsed_jail_file",
|
||||
AsyncMock(side_effect=ConfigDirError("no dir")),
|
||||
):
|
||||
resp = await file_config_client.get(
|
||||
@@ -670,7 +766,7 @@ class TestGetParsedJailFile:
|
||||
class TestUpdateParsedJailFile:
|
||||
async def test_204_on_success(self, file_config_client: AsyncClient) -> None:
|
||||
with patch(
|
||||
"app.routers.file_config.file_config_service.update_parsed_jail_file",
|
||||
"app.routers.file_config.raw_config_io_service.update_parsed_jail_file",
|
||||
AsyncMock(return_value=None),
|
||||
):
|
||||
resp = await file_config_client.put(
|
||||
@@ -682,7 +778,7 @@ class TestUpdateParsedJailFile:
|
||||
|
||||
async def test_404_not_found(self, file_config_client: AsyncClient) -> None:
|
||||
with patch(
|
||||
"app.routers.file_config.file_config_service.update_parsed_jail_file",
|
||||
"app.routers.file_config.raw_config_io_service.update_parsed_jail_file",
|
||||
AsyncMock(side_effect=ConfigFileNotFoundError("missing.conf")),
|
||||
):
|
||||
resp = await file_config_client.put(
|
||||
@@ -694,7 +790,7 @@ class TestUpdateParsedJailFile:
|
||||
|
||||
async def test_400_write_error(self, file_config_client: AsyncClient) -> None:
|
||||
with patch(
|
||||
"app.routers.file_config.file_config_service.update_parsed_jail_file",
|
||||
"app.routers.file_config.raw_config_io_service.update_parsed_jail_file",
|
||||
AsyncMock(side_effect=ConfigFileWriteError("disk full")),
|
||||
):
|
||||
resp = await file_config_client.put(
|
||||
|
||||
@@ -12,7 +12,7 @@ from httpx import ASGITransport, AsyncClient
|
||||
from app.config import Settings
|
||||
from app.db import init_db
|
||||
from app.main import create_app
|
||||
from app.services.geo_service import GeoInfo
|
||||
from app.models.geo import GeoInfo
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
@@ -70,7 +70,7 @@ class TestGeoLookup:
|
||||
async def test_200_with_geo_info(self, geo_client: AsyncClient) -> None:
|
||||
"""GET /api/geo/lookup/{ip} returns 200 with enriched result."""
|
||||
geo = GeoInfo(country_code="DE", country_name="Germany", asn="12345", org="Acme")
|
||||
result = {
|
||||
result: dict[str, object] = {
|
||||
"ip": "1.2.3.4",
|
||||
"currently_banned_in": ["sshd"],
|
||||
"geo": geo,
|
||||
@@ -92,7 +92,7 @@ class TestGeoLookup:
|
||||
|
||||
async def test_200_when_not_banned(self, geo_client: AsyncClient) -> None:
|
||||
"""GET /api/geo/lookup/{ip} returns empty list when IP is not banned anywhere."""
|
||||
result = {
|
||||
result: dict[str, object] = {
|
||||
"ip": "8.8.8.8",
|
||||
"currently_banned_in": [],
|
||||
"geo": GeoInfo(country_code="US", country_name="United States", asn=None, org=None),
|
||||
@@ -108,7 +108,7 @@ class TestGeoLookup:
|
||||
|
||||
async def test_200_with_no_geo(self, geo_client: AsyncClient) -> None:
|
||||
"""GET /api/geo/lookup/{ip} returns null geo when enricher fails."""
|
||||
result = {
|
||||
result: dict[str, object] = {
|
||||
"ip": "1.2.3.4",
|
||||
"currently_banned_in": [],
|
||||
"geo": None,
|
||||
@@ -144,7 +144,7 @@ class TestGeoLookup:
|
||||
|
||||
async def test_ipv6_address(self, geo_client: AsyncClient) -> None:
|
||||
"""GET /api/geo/lookup/{ip} handles IPv6 addresses."""
|
||||
result = {
|
||||
result: dict[str, object] = {
|
||||
"ip": "2001:db8::1",
|
||||
"currently_banned_in": [],
|
||||
"geo": None,
|
||||
|
||||
@@ -213,6 +213,18 @@ class TestHistoryList:
|
||||
_args, kwargs = mock_fn.call_args
|
||||
assert kwargs.get("range_") == "7d"
|
||||
|
||||
async def test_forwards_origin_filter(self, history_client: AsyncClient) -> None:
|
||||
"""The ``origin`` query parameter is forwarded to the service."""
|
||||
mock_fn = AsyncMock(return_value=_make_history_list(n=0))
|
||||
with patch(
|
||||
"app.routers.history.history_service.list_history",
|
||||
new=mock_fn,
|
||||
):
|
||||
await history_client.get("/api/history?origin=blocklist")
|
||||
|
||||
_args, kwargs = mock_fn.call_args
|
||||
assert kwargs.get("origin") == "blocklist"
|
||||
|
||||
async def test_empty_result(self, history_client: AsyncClient) -> None:
|
||||
"""An empty history returns items=[] and total=0."""
|
||||
with patch(
|
||||
|
||||
@@ -12,6 +12,7 @@ from httpx import ASGITransport, AsyncClient
|
||||
from app.config import Settings
|
||||
from app.db import init_db
|
||||
from app.main import create_app
|
||||
from app.models.ban import JailBannedIpsResponse
|
||||
from app.models.jail import Jail, JailDetailResponse, JailListResponse, JailStatus, JailSummary
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -801,17 +802,17 @@ class TestGetJailBannedIps:
|
||||
def _mock_response(
|
||||
self,
|
||||
*,
|
||||
items: list[dict] | None = None,
|
||||
items: list[dict[str, str | None]] | None = None,
|
||||
total: int = 2,
|
||||
page: int = 1,
|
||||
page_size: int = 25,
|
||||
) -> "JailBannedIpsResponse": # type: ignore[name-defined]
|
||||
) -> JailBannedIpsResponse:
|
||||
from app.models.ban import ActiveBan, JailBannedIpsResponse
|
||||
|
||||
ban_items = (
|
||||
[
|
||||
ActiveBan(
|
||||
ip=item.get("ip", "1.2.3.4"),
|
||||
ip=item.get("ip") or "1.2.3.4",
|
||||
jail="sshd",
|
||||
banned_at=item.get("banned_at", "2025-01-01T10:00:00+00:00"),
|
||||
expires_at=item.get("expires_at", "2025-01-01T10:10:00+00:00"),
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import aiosqlite
|
||||
import pytest
|
||||
@@ -11,7 +11,7 @@ from httpx import ASGITransport, AsyncClient
|
||||
|
||||
from app.config import Settings
|
||||
from app.db import init_db
|
||||
from app.main import create_app
|
||||
from app.main import _lifespan, create_app
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Shared setup payload
|
||||
@@ -247,9 +247,9 @@ class TestSetupCompleteCaching:
|
||||
assert not getattr(app.state, "_setup_complete_cached", False)
|
||||
|
||||
# First non-exempt request — middleware queries DB and sets the flag.
|
||||
await client.post("/api/auth/login", json={"password": _SETUP_PAYLOAD["master_password"]}) # type: ignore[call-overload]
|
||||
await client.post("/api/auth/login", json={"password": _SETUP_PAYLOAD["master_password"]})
|
||||
|
||||
assert app.state._setup_complete_cached is True # type: ignore[attr-defined]
|
||||
assert app.state._setup_complete_cached is True
|
||||
|
||||
async def test_cached_path_skips_is_setup_complete(
|
||||
self,
|
||||
@@ -267,12 +267,12 @@ class TestSetupCompleteCaching:
|
||||
|
||||
# Do setup and warm the cache.
|
||||
await client.post("/api/setup", json=_SETUP_PAYLOAD)
|
||||
await client.post("/api/auth/login", json={"password": _SETUP_PAYLOAD["master_password"]}) # type: ignore[call-overload]
|
||||
assert app.state._setup_complete_cached is True # type: ignore[attr-defined]
|
||||
await client.post("/api/auth/login", json={"password": _SETUP_PAYLOAD["master_password"]})
|
||||
assert app.state._setup_complete_cached is True
|
||||
|
||||
call_count = 0
|
||||
|
||||
async def _counting(db): # type: ignore[no-untyped-def]
|
||||
async def _counting(db: aiosqlite.Connection) -> bool:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return True
|
||||
@@ -286,3 +286,151 @@ class TestSetupCompleteCaching:
|
||||
# Cache was warm — is_setup_complete must not have been called.
|
||||
assert call_count == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Task 0.1 — Lifespan creates the database parent directory (Task 0.1)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestLifespanDatabaseDirectoryCreation:
|
||||
"""App lifespan creates the database parent directory when it does not exist."""
|
||||
|
||||
async def test_creates_nested_database_directory(self, tmp_path: Path) -> None:
|
||||
"""Lifespan creates intermediate directories for the database path.
|
||||
|
||||
Verifies that a deeply-nested database path is handled correctly —
|
||||
the parent directories are created before ``aiosqlite.connect`` is
|
||||
called so the app does not crash on a fresh volume.
|
||||
"""
|
||||
nested_db = tmp_path / "deep" / "nested" / "bangui.db"
|
||||
assert not nested_db.parent.exists()
|
||||
|
||||
settings = Settings(
|
||||
database_path=str(nested_db),
|
||||
fail2ban_socket="/tmp/fake.sock",
|
||||
session_secret="test-lifespan-mkdir-secret",
|
||||
session_duration_minutes=60,
|
||||
timezone="UTC",
|
||||
log_level="debug",
|
||||
)
|
||||
app = create_app(settings=settings)
|
||||
|
||||
mock_scheduler = MagicMock()
|
||||
mock_scheduler.start = MagicMock()
|
||||
mock_scheduler.shutdown = MagicMock()
|
||||
|
||||
with (
|
||||
patch("app.services.geo_service.init_geoip"),
|
||||
patch(
|
||||
"app.services.geo_service.load_cache_from_db",
|
||||
new=AsyncMock(return_value=None),
|
||||
),
|
||||
patch("app.tasks.health_check.register"),
|
||||
patch("app.tasks.blocklist_import.register"),
|
||||
patch("app.tasks.geo_cache_flush.register"),
|
||||
patch("app.tasks.geo_re_resolve.register"),
|
||||
patch("app.main.AsyncIOScheduler", return_value=mock_scheduler),
|
||||
patch("app.main.ensure_jail_configs"),
|
||||
):
|
||||
async with _lifespan(app):
|
||||
assert nested_db.parent.exists(), (
|
||||
"Expected lifespan to create database parent directory"
|
||||
)
|
||||
|
||||
async def test_existing_database_directory_is_not_an_error(
|
||||
self, tmp_path: Path
|
||||
) -> None:
|
||||
"""Lifespan does not raise when the database directory already exists.
|
||||
|
||||
``mkdir(exist_ok=True)`` must be used so that re-starts on an existing
|
||||
volume do not fail.
|
||||
"""
|
||||
db_path = tmp_path / "bangui.db"
|
||||
# tmp_path already exists — this simulates a pre-existing volume.
|
||||
|
||||
settings = Settings(
|
||||
database_path=str(db_path),
|
||||
fail2ban_socket="/tmp/fake.sock",
|
||||
session_secret="test-lifespan-exist-ok-secret",
|
||||
session_duration_minutes=60,
|
||||
timezone="UTC",
|
||||
log_level="debug",
|
||||
)
|
||||
app = create_app(settings=settings)
|
||||
|
||||
mock_scheduler = MagicMock()
|
||||
mock_scheduler.start = MagicMock()
|
||||
mock_scheduler.shutdown = MagicMock()
|
||||
|
||||
with (
|
||||
patch("app.services.geo_service.init_geoip"),
|
||||
patch(
|
||||
"app.services.geo_service.load_cache_from_db",
|
||||
new=AsyncMock(return_value=None),
|
||||
),
|
||||
patch("app.tasks.health_check.register"),
|
||||
patch("app.tasks.blocklist_import.register"),
|
||||
patch("app.tasks.geo_cache_flush.register"),
|
||||
patch("app.tasks.geo_re_resolve.register"),
|
||||
patch("app.main.AsyncIOScheduler", return_value=mock_scheduler),
|
||||
patch("app.main.ensure_jail_configs"),
|
||||
):
|
||||
# Should not raise FileExistsError or similar.
|
||||
async with _lifespan(app):
|
||||
assert tmp_path.exists()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Task 0.2 — Middleware redirects when app.state.db is None
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSetupRedirectMiddlewareDbNone:
|
||||
"""SetupRedirectMiddleware redirects when the database is not yet available."""
|
||||
|
||||
async def test_redirects_to_setup_when_db_not_set(self, tmp_path: Path) -> None:
|
||||
"""A ``None`` db on app.state causes a 307 redirect to ``/api/setup``.
|
||||
|
||||
Simulates the race window where a request arrives before the lifespan
|
||||
has finished initialising the database connection.
|
||||
"""
|
||||
settings = Settings(
|
||||
database_path=str(tmp_path / "bangui.db"),
|
||||
fail2ban_socket="/tmp/fake_fail2ban.sock",
|
||||
session_secret="test-db-none-secret",
|
||||
session_duration_minutes=60,
|
||||
timezone="UTC",
|
||||
log_level="debug",
|
||||
)
|
||||
app = create_app(settings=settings)
|
||||
# Deliberately do NOT set app.state.db to simulate startup not complete.
|
||||
|
||||
transport = ASGITransport(app=app)
|
||||
async with AsyncClient(
|
||||
transport=transport, base_url="http://test"
|
||||
) as ac:
|
||||
response = await ac.get("/api/auth/login", follow_redirects=False)
|
||||
|
||||
assert response.status_code == 307
|
||||
assert response.headers["location"] == "/api/setup"
|
||||
|
||||
async def test_health_reachable_when_db_not_set(self, tmp_path: Path) -> None:
|
||||
"""Health endpoint is always reachable even when db is not initialised."""
|
||||
settings = Settings(
|
||||
database_path=str(tmp_path / "bangui.db"),
|
||||
fail2ban_socket="/tmp/fake_fail2ban.sock",
|
||||
session_secret="test-db-none-health-secret",
|
||||
session_duration_minutes=60,
|
||||
timezone="UTC",
|
||||
log_level="debug",
|
||||
)
|
||||
app = create_app(settings=settings)
|
||||
|
||||
transport = ASGITransport(app=app)
|
||||
async with AsyncClient(
|
||||
transport=transport, base_url="http://test"
|
||||
) as ac:
|
||||
response = await ac.get("/api/health")
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
@@ -73,7 +73,7 @@ class TestCheckPasswordAsync:
|
||||
auth_service._check_password("secret", hashed), # noqa: SLF001
|
||||
auth_service._check_password("wrong", hashed), # noqa: SLF001
|
||||
)
|
||||
assert results == [True, False]
|
||||
assert tuple(results) == (True, False)
|
||||
|
||||
|
||||
class TestLogin:
|
||||
|
||||
@@ -64,7 +64,7 @@ async def _create_f2b_db(path: str, rows: list[dict[str, Any]]) -> None:
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def f2b_db_path(tmp_path: Path) -> str: # type: ignore[misc]
|
||||
async def f2b_db_path(tmp_path: Path) -> str:
|
||||
"""Return the path to a test fail2ban SQLite database with several bans."""
|
||||
path = str(tmp_path / "fail2ban_test.sqlite3")
|
||||
await _create_f2b_db(
|
||||
@@ -103,7 +103,7 @@ async def f2b_db_path(tmp_path: Path) -> str: # type: ignore[misc]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def mixed_origin_db_path(tmp_path: Path) -> str: # type: ignore[misc]
|
||||
async def mixed_origin_db_path(tmp_path: Path) -> str:
|
||||
"""Return a database with bans from both blocklist-import and organic jails."""
|
||||
path = str(tmp_path / "fail2ban_mixed_origin.sqlite3")
|
||||
await _create_f2b_db(
|
||||
@@ -136,7 +136,7 @@ async def mixed_origin_db_path(tmp_path: Path) -> str: # type: ignore[misc]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def empty_f2b_db_path(tmp_path: Path) -> str: # type: ignore[misc]
|
||||
async def empty_f2b_db_path(tmp_path: Path) -> str:
|
||||
"""Return the path to a fail2ban SQLite database with no ban records."""
|
||||
path = str(tmp_path / "fail2ban_empty.sqlite3")
|
||||
await _create_f2b_db(path, [])
|
||||
@@ -154,7 +154,7 @@ class TestListBansHappyPath:
|
||||
async def test_returns_bans_in_range(self, f2b_db_path: str) -> None:
|
||||
"""Only bans within the selected range are returned."""
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=f2b_db_path),
|
||||
):
|
||||
result = await ban_service.list_bans("/fake/sock", "24h")
|
||||
@@ -166,7 +166,7 @@ class TestListBansHappyPath:
|
||||
async def test_results_sorted_newest_first(self, f2b_db_path: str) -> None:
|
||||
"""Items are ordered by ``banned_at`` descending (newest first)."""
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=f2b_db_path),
|
||||
):
|
||||
result = await ban_service.list_bans("/fake/sock", "24h")
|
||||
@@ -177,7 +177,7 @@ class TestListBansHappyPath:
|
||||
async def test_ban_fields_present(self, f2b_db_path: str) -> None:
|
||||
"""Each item contains ip, jail, banned_at, ban_count."""
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=f2b_db_path),
|
||||
):
|
||||
result = await ban_service.list_bans("/fake/sock", "24h")
|
||||
@@ -191,7 +191,7 @@ class TestListBansHappyPath:
|
||||
async def test_service_extracted_from_first_match(self, f2b_db_path: str) -> None:
|
||||
"""``service`` field is the first element of ``data.matches``."""
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=f2b_db_path),
|
||||
):
|
||||
result = await ban_service.list_bans("/fake/sock", "24h")
|
||||
@@ -203,7 +203,7 @@ class TestListBansHappyPath:
|
||||
async def test_service_is_none_when_no_matches(self, f2b_db_path: str) -> None:
|
||||
"""``service`` is ``None`` when the ban has no stored matches."""
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=f2b_db_path),
|
||||
):
|
||||
# Use 7d to include the older ban with no matches.
|
||||
@@ -215,7 +215,7 @@ class TestListBansHappyPath:
|
||||
async def test_empty_db_returns_zero(self, empty_f2b_db_path: str) -> None:
|
||||
"""When no bans exist the result has total=0 and no items."""
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=empty_f2b_db_path),
|
||||
):
|
||||
result = await ban_service.list_bans("/fake/sock", "24h")
|
||||
@@ -226,7 +226,7 @@ class TestListBansHappyPath:
|
||||
async def test_365d_range_includes_old_bans(self, f2b_db_path: str) -> None:
|
||||
"""The ``365d`` range includes bans that are 2 days old."""
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=f2b_db_path),
|
||||
):
|
||||
result = await ban_service.list_bans("/fake/sock", "365d")
|
||||
@@ -246,7 +246,7 @@ class TestListBansGeoEnrichment:
|
||||
self, f2b_db_path: str
|
||||
) -> None:
|
||||
"""Geo fields are populated when an enricher returns data."""
|
||||
from app.services.geo_service import GeoInfo
|
||||
from app.models.geo import GeoInfo
|
||||
|
||||
async def fake_enricher(ip: str) -> GeoInfo:
|
||||
return GeoInfo(
|
||||
@@ -257,7 +257,7 @@ class TestListBansGeoEnrichment:
|
||||
)
|
||||
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=f2b_db_path),
|
||||
):
|
||||
result = await ban_service.list_bans(
|
||||
@@ -278,7 +278,7 @@ class TestListBansGeoEnrichment:
|
||||
raise RuntimeError("geo service down")
|
||||
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=f2b_db_path),
|
||||
):
|
||||
result = await ban_service.list_bans(
|
||||
@@ -304,25 +304,27 @@ class TestListBansBatchGeoEnrichment:
|
||||
"""Geo fields are populated via lookup_batch when http_session is given."""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from app.services.geo_service import GeoInfo
|
||||
from app.models.geo import GeoInfo
|
||||
|
||||
fake_session = MagicMock()
|
||||
fake_geo_map = {
|
||||
"1.2.3.4": GeoInfo(country_code="DE", country_name="Germany", asn="AS3320", org="Deutsche Telekom"),
|
||||
"5.6.7.8": GeoInfo(country_code="US", country_name="United States", asn="AS15169", org="Google"),
|
||||
}
|
||||
fake_geo_batch = AsyncMock(return_value=fake_geo_map)
|
||||
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=f2b_db_path),
|
||||
), patch(
|
||||
"app.services.geo_service.lookup_batch",
|
||||
new=AsyncMock(return_value=fake_geo_map),
|
||||
):
|
||||
result = await ban_service.list_bans(
|
||||
"/fake/sock", "24h", http_session=fake_session
|
||||
"/fake/sock",
|
||||
"24h",
|
||||
http_session=fake_session,
|
||||
geo_batch_lookup=fake_geo_batch,
|
||||
)
|
||||
|
||||
fake_geo_batch.assert_awaited_once_with(["1.2.3.4", "5.6.7.8"], fake_session, db=None)
|
||||
assert result.total == 2
|
||||
de_item = next(i for i in result.items if i.ip == "1.2.3.4")
|
||||
us_item = next(i for i in result.items if i.ip == "5.6.7.8")
|
||||
@@ -339,15 +341,17 @@ class TestListBansBatchGeoEnrichment:
|
||||
|
||||
fake_session = MagicMock()
|
||||
|
||||
failing_geo_batch = AsyncMock(side_effect=RuntimeError("batch geo down"))
|
||||
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=f2b_db_path),
|
||||
), patch(
|
||||
"app.services.geo_service.lookup_batch",
|
||||
new=AsyncMock(side_effect=RuntimeError("batch geo down")),
|
||||
):
|
||||
result = await ban_service.list_bans(
|
||||
"/fake/sock", "24h", http_session=fake_session
|
||||
"/fake/sock",
|
||||
"24h",
|
||||
http_session=fake_session,
|
||||
geo_batch_lookup=failing_geo_batch,
|
||||
)
|
||||
|
||||
assert result.total == 2
|
||||
@@ -360,28 +364,27 @@ class TestListBansBatchGeoEnrichment:
|
||||
"""When both http_session and geo_enricher are provided, batch wins."""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from app.services.geo_service import GeoInfo
|
||||
from app.models.geo import GeoInfo
|
||||
|
||||
fake_session = MagicMock()
|
||||
fake_geo_map = {
|
||||
"1.2.3.4": GeoInfo(country_code="DE", country_name="Germany", asn=None, org=None),
|
||||
"5.6.7.8": GeoInfo(country_code="DE", country_name="Germany", asn=None, org=None),
|
||||
}
|
||||
fake_geo_batch = AsyncMock(return_value=fake_geo_map)
|
||||
|
||||
async def enricher_should_not_be_called(ip: str) -> GeoInfo:
|
||||
raise AssertionError(f"geo_enricher was called for {ip!r} — should not happen")
|
||||
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=f2b_db_path),
|
||||
), patch(
|
||||
"app.services.geo_service.lookup_batch",
|
||||
new=AsyncMock(return_value=fake_geo_map),
|
||||
):
|
||||
result = await ban_service.list_bans(
|
||||
"/fake/sock",
|
||||
"24h",
|
||||
http_session=fake_session,
|
||||
geo_batch_lookup=fake_geo_batch,
|
||||
geo_enricher=enricher_should_not_be_called,
|
||||
)
|
||||
|
||||
@@ -401,7 +404,7 @@ class TestListBansPagination:
|
||||
async def test_page_size_respected(self, f2b_db_path: str) -> None:
|
||||
"""``page_size=1`` returns at most one item."""
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=f2b_db_path),
|
||||
):
|
||||
result = await ban_service.list_bans("/fake/sock", "7d", page_size=1)
|
||||
@@ -412,7 +415,7 @@ class TestListBansPagination:
|
||||
async def test_page_2_returns_remaining_items(self, f2b_db_path: str) -> None:
|
||||
"""The second page returns items not on the first page."""
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=f2b_db_path),
|
||||
):
|
||||
page1 = await ban_service.list_bans("/fake/sock", "7d", page=1, page_size=1)
|
||||
@@ -426,7 +429,7 @@ class TestListBansPagination:
|
||||
) -> None:
|
||||
"""``total`` reports all matching records regardless of pagination."""
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=f2b_db_path),
|
||||
):
|
||||
result = await ban_service.list_bans("/fake/sock", "7d", page_size=1)
|
||||
@@ -447,7 +450,7 @@ class TestBanOriginDerivation:
|
||||
) -> None:
|
||||
"""Bans from ``blocklist-import`` jail carry ``origin == "blocklist"``."""
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=mixed_origin_db_path),
|
||||
):
|
||||
result = await ban_service.list_bans("/fake/sock", "24h")
|
||||
@@ -461,7 +464,7 @@ class TestBanOriginDerivation:
|
||||
) -> None:
|
||||
"""Bans from organic jails (sshd, nginx, …) carry ``origin == "selfblock"``."""
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=mixed_origin_db_path),
|
||||
):
|
||||
result = await ban_service.list_bans("/fake/sock", "24h")
|
||||
@@ -476,7 +479,7 @@ class TestBanOriginDerivation:
|
||||
) -> None:
|
||||
"""Every returned item has an ``origin`` field with a valid value."""
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=mixed_origin_db_path),
|
||||
):
|
||||
result = await ban_service.list_bans("/fake/sock", "24h")
|
||||
@@ -489,7 +492,7 @@ class TestBanOriginDerivation:
|
||||
) -> None:
|
||||
"""``bans_by_country`` also derives origin correctly for blocklist bans."""
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=mixed_origin_db_path),
|
||||
):
|
||||
result = await ban_service.bans_by_country("/fake/sock", "24h")
|
||||
@@ -503,7 +506,7 @@ class TestBanOriginDerivation:
|
||||
) -> None:
|
||||
"""``bans_by_country`` derives origin correctly for organic jails."""
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=mixed_origin_db_path),
|
||||
):
|
||||
result = await ban_service.bans_by_country("/fake/sock", "24h")
|
||||
@@ -527,7 +530,7 @@ class TestOriginFilter:
|
||||
) -> None:
|
||||
"""``origin='blocklist'`` returns only blocklist-import jail bans."""
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=mixed_origin_db_path),
|
||||
):
|
||||
result = await ban_service.list_bans(
|
||||
@@ -544,7 +547,7 @@ class TestOriginFilter:
|
||||
) -> None:
|
||||
"""``origin='selfblock'`` excludes the blocklist-import jail."""
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=mixed_origin_db_path),
|
||||
):
|
||||
result = await ban_service.list_bans(
|
||||
@@ -562,7 +565,7 @@ class TestOriginFilter:
|
||||
) -> None:
|
||||
"""``origin=None`` applies no jail restriction — all bans returned."""
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=mixed_origin_db_path),
|
||||
):
|
||||
result = await ban_service.list_bans("/fake/sock", "24h", origin=None)
|
||||
@@ -574,7 +577,7 @@ class TestOriginFilter:
|
||||
) -> None:
|
||||
"""``bans_by_country`` with ``origin='blocklist'`` counts only blocklist bans."""
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=mixed_origin_db_path),
|
||||
):
|
||||
result = await ban_service.bans_by_country(
|
||||
@@ -589,7 +592,7 @@ class TestOriginFilter:
|
||||
) -> None:
|
||||
"""``bans_by_country`` with ``origin='selfblock'`` excludes blocklist jails."""
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=mixed_origin_db_path),
|
||||
):
|
||||
result = await ban_service.bans_by_country(
|
||||
@@ -604,7 +607,7 @@ class TestOriginFilter:
|
||||
) -> None:
|
||||
"""``bans_by_country`` with ``origin=None`` returns all bans."""
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=mixed_origin_db_path),
|
||||
):
|
||||
result = await ban_service.bans_by_country(
|
||||
@@ -632,19 +635,19 @@ class TestBansbyCountryBackground:
|
||||
from app.services import geo_service
|
||||
|
||||
# Pre-populate the cache for all three IPs in the fixture.
|
||||
geo_service._cache["10.0.0.1"] = geo_service.GeoInfo( # type: ignore[attr-defined]
|
||||
geo_service._cache["10.0.0.1"] = geo_service.GeoInfo(
|
||||
country_code="DE", country_name="Germany", asn=None, org=None
|
||||
)
|
||||
geo_service._cache["10.0.0.2"] = geo_service.GeoInfo( # type: ignore[attr-defined]
|
||||
geo_service._cache["10.0.0.2"] = geo_service.GeoInfo(
|
||||
country_code="US", country_name="United States", asn=None, org=None
|
||||
)
|
||||
geo_service._cache["10.0.0.3"] = geo_service.GeoInfo( # type: ignore[attr-defined]
|
||||
geo_service._cache["10.0.0.3"] = geo_service.GeoInfo(
|
||||
country_code="JP", country_name="Japan", asn=None, org=None
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=mixed_origin_db_path),
|
||||
),
|
||||
patch(
|
||||
@@ -652,8 +655,13 @@ class TestBansbyCountryBackground:
|
||||
) as mock_create_task,
|
||||
):
|
||||
mock_session = AsyncMock()
|
||||
mock_batch = AsyncMock(return_value={})
|
||||
result = await ban_service.bans_by_country(
|
||||
"/fake/sock", "24h", http_session=mock_session
|
||||
"/fake/sock",
|
||||
"24h",
|
||||
http_session=mock_session,
|
||||
geo_cache_lookup=geo_service.lookup_cached_only,
|
||||
geo_batch_lookup=mock_batch,
|
||||
)
|
||||
|
||||
# All countries resolved from cache — no background task needed.
|
||||
@@ -674,7 +682,7 @@ class TestBansbyCountryBackground:
|
||||
|
||||
with (
|
||||
patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=mixed_origin_db_path),
|
||||
),
|
||||
patch(
|
||||
@@ -682,8 +690,13 @@ class TestBansbyCountryBackground:
|
||||
) as mock_create_task,
|
||||
):
|
||||
mock_session = AsyncMock()
|
||||
mock_batch = AsyncMock(return_value={})
|
||||
result = await ban_service.bans_by_country(
|
||||
"/fake/sock", "24h", http_session=mock_session
|
||||
"/fake/sock",
|
||||
"24h",
|
||||
http_session=mock_session,
|
||||
geo_cache_lookup=geo_service.lookup_cached_only,
|
||||
geo_batch_lookup=mock_batch,
|
||||
)
|
||||
|
||||
# Background task must have been scheduled for uncached IPs.
|
||||
@@ -701,7 +714,7 @@ class TestBansbyCountryBackground:
|
||||
|
||||
with (
|
||||
patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=mixed_origin_db_path),
|
||||
),
|
||||
patch(
|
||||
@@ -727,7 +740,7 @@ class TestBanTrend:
|
||||
async def test_24h_returns_24_buckets(self, empty_f2b_db_path: str) -> None:
|
||||
"""``range_='24h'`` always yields exactly 24 buckets."""
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=empty_f2b_db_path),
|
||||
):
|
||||
result = await ban_service.ban_trend("/fake/sock", "24h")
|
||||
@@ -738,7 +751,7 @@ class TestBanTrend:
|
||||
async def test_7d_returns_28_buckets(self, empty_f2b_db_path: str) -> None:
|
||||
"""``range_='7d'`` yields 28 six-hour buckets."""
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=empty_f2b_db_path),
|
||||
):
|
||||
result = await ban_service.ban_trend("/fake/sock", "7d")
|
||||
@@ -749,7 +762,7 @@ class TestBanTrend:
|
||||
async def test_30d_returns_30_buckets(self, empty_f2b_db_path: str) -> None:
|
||||
"""``range_='30d'`` yields 30 daily buckets."""
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=empty_f2b_db_path),
|
||||
):
|
||||
result = await ban_service.ban_trend("/fake/sock", "30d")
|
||||
@@ -760,7 +773,7 @@ class TestBanTrend:
|
||||
async def test_365d_bucket_size_label(self, empty_f2b_db_path: str) -> None:
|
||||
"""``range_='365d'`` uses '7d' as the bucket size label."""
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=empty_f2b_db_path),
|
||||
):
|
||||
result = await ban_service.ban_trend("/fake/sock", "365d")
|
||||
@@ -771,7 +784,7 @@ class TestBanTrend:
|
||||
async def test_empty_db_all_buckets_zero(self, empty_f2b_db_path: str) -> None:
|
||||
"""All bucket counts are zero when the database has no bans."""
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=empty_f2b_db_path),
|
||||
):
|
||||
result = await ban_service.ban_trend("/fake/sock", "24h")
|
||||
@@ -781,7 +794,7 @@ class TestBanTrend:
|
||||
async def test_buckets_are_time_ordered(self, empty_f2b_db_path: str) -> None:
|
||||
"""Buckets are ordered chronologically (ascending timestamps)."""
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=empty_f2b_db_path),
|
||||
):
|
||||
result = await ban_service.ban_trend("/fake/sock", "7d")
|
||||
@@ -804,7 +817,7 @@ class TestBanTrend:
|
||||
)
|
||||
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=path),
|
||||
):
|
||||
result = await ban_service.ban_trend("/fake/sock", "24h")
|
||||
@@ -828,7 +841,7 @@ class TestBanTrend:
|
||||
)
|
||||
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=path),
|
||||
):
|
||||
result = await ban_service.ban_trend(
|
||||
@@ -854,7 +867,7 @@ class TestBanTrend:
|
||||
)
|
||||
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=path),
|
||||
):
|
||||
result = await ban_service.ban_trend(
|
||||
@@ -868,7 +881,7 @@ class TestBanTrend:
|
||||
from datetime import datetime
|
||||
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=empty_f2b_db_path),
|
||||
):
|
||||
result = await ban_service.ban_trend("/fake/sock", "24h")
|
||||
@@ -904,7 +917,7 @@ class TestBansByJail:
|
||||
)
|
||||
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=path),
|
||||
):
|
||||
result = await ban_service.bans_by_jail("/fake/sock", "24h")
|
||||
@@ -931,7 +944,7 @@ class TestBansByJail:
|
||||
)
|
||||
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=path),
|
||||
):
|
||||
result = await ban_service.bans_by_jail("/fake/sock", "24h")
|
||||
@@ -942,7 +955,7 @@ class TestBansByJail:
|
||||
async def test_empty_db_returns_empty_list(self, empty_f2b_db_path: str) -> None:
|
||||
"""An empty database returns an empty jails list with total zero."""
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=empty_f2b_db_path),
|
||||
):
|
||||
result = await ban_service.bans_by_jail("/fake/sock", "24h")
|
||||
@@ -954,7 +967,7 @@ class TestBansByJail:
|
||||
"""Bans older than the time window are not counted."""
|
||||
# f2b_db_path has one ban from _TWO_DAYS_AGO, which is outside "24h".
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=f2b_db_path),
|
||||
):
|
||||
result = await ban_service.bans_by_jail("/fake/sock", "24h")
|
||||
@@ -965,7 +978,7 @@ class TestBansByJail:
|
||||
async def test_origin_filter_blocklist(self, mixed_origin_db_path: str) -> None:
|
||||
"""``origin='blocklist'`` returns only the blocklist-import jail."""
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=mixed_origin_db_path),
|
||||
):
|
||||
result = await ban_service.bans_by_jail(
|
||||
@@ -979,7 +992,7 @@ class TestBansByJail:
|
||||
async def test_origin_filter_selfblock(self, mixed_origin_db_path: str) -> None:
|
||||
"""``origin='selfblock'`` excludes the blocklist-import jail."""
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=mixed_origin_db_path),
|
||||
):
|
||||
result = await ban_service.bans_by_jail(
|
||||
@@ -995,7 +1008,7 @@ class TestBansByJail:
|
||||
) -> None:
|
||||
"""``origin=None`` returns bans from all jails."""
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=mixed_origin_db_path),
|
||||
):
|
||||
result = await ban_service.bans_by_jail(
|
||||
@@ -1023,7 +1036,7 @@ class TestBansByJail:
|
||||
|
||||
with (
|
||||
patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=path),
|
||||
),
|
||||
patch("app.services.ban_service.log") as mock_log,
|
||||
|
||||
@@ -19,8 +19,8 @@ from unittest.mock import AsyncMock, patch
|
||||
import aiosqlite
|
||||
import pytest
|
||||
|
||||
from app.models.geo import GeoInfo
|
||||
from app.services import ban_service, geo_service
|
||||
from app.services.geo_service import GeoInfo
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Constants
|
||||
@@ -114,13 +114,13 @@ async def _seed_f2b_db(path: str, n: int) -> list[str]:
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def event_loop_policy() -> None: # type: ignore[misc]
|
||||
def event_loop_policy() -> None:
|
||||
"""Use the default event loop policy for module-scoped fixtures."""
|
||||
return None
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
async def perf_db_path(tmp_path_factory: Any) -> str: # type: ignore[misc]
|
||||
async def perf_db_path(tmp_path_factory: Any) -> str:
|
||||
"""Return the path to a fail2ban DB seeded with 10 000 synthetic bans.
|
||||
|
||||
Module-scoped so the database is created only once for all perf tests.
|
||||
@@ -161,7 +161,7 @@ class TestBanServicePerformance:
|
||||
return geo_service._cache.get(ip) # noqa: SLF001
|
||||
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=perf_db_path),
|
||||
):
|
||||
start = time.perf_counter()
|
||||
@@ -191,7 +191,7 @@ class TestBanServicePerformance:
|
||||
return geo_service._cache.get(ip) # noqa: SLF001
|
||||
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=perf_db_path),
|
||||
):
|
||||
start = time.perf_counter()
|
||||
@@ -217,7 +217,7 @@ class TestBanServicePerformance:
|
||||
return geo_service._cache.get(ip) # noqa: SLF001
|
||||
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=perf_db_path),
|
||||
):
|
||||
result = await ban_service.list_bans(
|
||||
@@ -241,7 +241,7 @@ class TestBanServicePerformance:
|
||||
return geo_service._cache.get(ip) # noqa: SLF001
|
||||
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=perf_db_path),
|
||||
):
|
||||
result = await ban_service.bans_by_country(
|
||||
|
||||
@@ -203,9 +203,15 @@ class TestImport:
|
||||
call_count += 1
|
||||
raise JailNotFoundError(jail)
|
||||
|
||||
with patch("app.services.jail_service.ban_ip", side_effect=_raise_jail_not_found):
|
||||
with patch("app.services.jail_service.ban_ip", side_effect=_raise_jail_not_found) as mocked_ban_ip:
|
||||
from app.services import jail_service
|
||||
|
||||
result = await blocklist_service.import_source(
|
||||
source, session, "/tmp/fake.sock", db
|
||||
source,
|
||||
session,
|
||||
"/tmp/fake.sock",
|
||||
db,
|
||||
ban_ip=jail_service.ban_ip,
|
||||
)
|
||||
|
||||
# Must abort after the first JailNotFoundError — only one ban attempt.
|
||||
@@ -226,7 +232,14 @@ class TestImport:
|
||||
with patch(
|
||||
"app.services.jail_service.ban_ip", new_callable=AsyncMock
|
||||
):
|
||||
result = await blocklist_service.import_all(db, session, "/tmp/fake.sock")
|
||||
from app.services import jail_service
|
||||
|
||||
result = await blocklist_service.import_all(
|
||||
db,
|
||||
session,
|
||||
"/tmp/fake.sock",
|
||||
ban_ip=jail_service.ban_ip,
|
||||
)
|
||||
|
||||
# Only S1 is enabled, S2 is disabled.
|
||||
assert len(result.results) == 1
|
||||
@@ -315,20 +328,15 @@ class TestGeoPrewarmCacheFilter:
|
||||
def _mock_is_cached(ip: str) -> bool:
|
||||
return ip == "1.2.3.4"
|
||||
|
||||
with (
|
||||
patch("app.services.jail_service.ban_ip", new_callable=AsyncMock),
|
||||
patch(
|
||||
"app.services.geo_service.is_cached",
|
||||
side_effect=_mock_is_cached,
|
||||
),
|
||||
patch(
|
||||
"app.services.geo_service.lookup_batch",
|
||||
new_callable=AsyncMock,
|
||||
return_value={},
|
||||
) as mock_batch,
|
||||
):
|
||||
mock_batch = AsyncMock(return_value={})
|
||||
with patch("app.services.jail_service.ban_ip", new_callable=AsyncMock):
|
||||
result = await blocklist_service.import_source(
|
||||
source, session, "/tmp/fake.sock", db
|
||||
source,
|
||||
session,
|
||||
"/tmp/fake.sock",
|
||||
db,
|
||||
geo_is_cached=_mock_is_cached,
|
||||
geo_batch_lookup=mock_batch,
|
||||
)
|
||||
|
||||
assert result.ips_imported == 3
|
||||
@@ -337,3 +345,40 @@ class TestGeoPrewarmCacheFilter:
|
||||
call_ips = mock_batch.call_args[0][0]
|
||||
assert "1.2.3.4" not in call_ips
|
||||
assert set(call_ips) == {"5.6.7.8", "9.10.11.12"}
|
||||
|
||||
|
||||
class TestImportLogPagination:
|
||||
async def test_list_import_logs_empty(self, db: aiosqlite.Connection) -> None:
|
||||
"""list_import_logs returns an empty page when no logs exist."""
|
||||
resp = await blocklist_service.list_import_logs(
|
||||
db, source_id=None, page=1, page_size=10
|
||||
)
|
||||
assert resp.items == []
|
||||
assert resp.total == 0
|
||||
assert resp.page == 1
|
||||
assert resp.page_size == 10
|
||||
assert resp.total_pages == 1
|
||||
|
||||
async def test_list_import_logs_paginates(self, db: aiosqlite.Connection) -> None:
|
||||
"""list_import_logs computes total pages and returns the correct subset."""
|
||||
from app.repositories import import_log_repo
|
||||
|
||||
for i in range(3):
|
||||
await import_log_repo.add_log(
|
||||
db,
|
||||
source_id=None,
|
||||
source_url=f"https://example{i}.test/ips.txt",
|
||||
ips_imported=1,
|
||||
ips_skipped=0,
|
||||
errors=None,
|
||||
)
|
||||
|
||||
resp = await blocklist_service.list_import_logs(
|
||||
db, source_id=None, page=2, page_size=2
|
||||
)
|
||||
assert resp.total == 3
|
||||
assert resp.total_pages == 2
|
||||
assert resp.page == 2
|
||||
assert resp.page_size == 2
|
||||
assert len(resp.items) == 1
|
||||
assert resp.items[0].source_url == "https://example0.test/ips.txt"
|
||||
|
||||
@@ -6,7 +6,7 @@ from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.conffile_parser import (
|
||||
from app.utils.conffile_parser import (
|
||||
merge_action_update,
|
||||
merge_filter_update,
|
||||
parse_action_file,
|
||||
@@ -451,7 +451,7 @@ class TestParseJailFile:
|
||||
"""Unit tests for parse_jail_file."""
|
||||
|
||||
def test_minimal_parses_correctly(self) -> None:
|
||||
from app.services.conffile_parser import parse_jail_file
|
||||
from app.utils.conffile_parser import parse_jail_file
|
||||
|
||||
cfg = parse_jail_file(MINIMAL_JAIL, filename="sshd.conf")
|
||||
assert cfg.filename == "sshd.conf"
|
||||
@@ -463,7 +463,7 @@ class TestParseJailFile:
|
||||
assert jail.logpath == ["/var/log/auth.log"]
|
||||
|
||||
def test_full_parses_multiple_jails(self) -> None:
|
||||
from app.services.conffile_parser import parse_jail_file
|
||||
from app.utils.conffile_parser import parse_jail_file
|
||||
|
||||
cfg = parse_jail_file(FULL_JAIL)
|
||||
assert len(cfg.jails) == 2
|
||||
@@ -471,7 +471,7 @@ class TestParseJailFile:
|
||||
assert "nginx-botsearch" in cfg.jails
|
||||
|
||||
def test_full_jail_numeric_fields(self) -> None:
|
||||
from app.services.conffile_parser import parse_jail_file
|
||||
from app.utils.conffile_parser import parse_jail_file
|
||||
|
||||
jail = parse_jail_file(FULL_JAIL).jails["sshd"]
|
||||
assert jail.maxretry == 3
|
||||
@@ -479,7 +479,7 @@ class TestParseJailFile:
|
||||
assert jail.bantime == 3600
|
||||
|
||||
def test_full_jail_multiline_logpath(self) -> None:
|
||||
from app.services.conffile_parser import parse_jail_file
|
||||
from app.utils.conffile_parser import parse_jail_file
|
||||
|
||||
jail = parse_jail_file(FULL_JAIL).jails["sshd"]
|
||||
assert len(jail.logpath) == 2
|
||||
@@ -487,53 +487,53 @@ class TestParseJailFile:
|
||||
assert "/var/log/syslog" in jail.logpath
|
||||
|
||||
def test_full_jail_multiline_action(self) -> None:
|
||||
from app.services.conffile_parser import parse_jail_file
|
||||
from app.utils.conffile_parser import parse_jail_file
|
||||
|
||||
jail = parse_jail_file(FULL_JAIL).jails["nginx-botsearch"]
|
||||
assert len(jail.action) == 2
|
||||
assert "sendmail-whois" in jail.action
|
||||
|
||||
def test_enabled_true(self) -> None:
|
||||
from app.services.conffile_parser import parse_jail_file
|
||||
from app.utils.conffile_parser import parse_jail_file
|
||||
|
||||
jail = parse_jail_file(FULL_JAIL).jails["sshd"]
|
||||
assert jail.enabled is True
|
||||
|
||||
def test_enabled_false(self) -> None:
|
||||
from app.services.conffile_parser import parse_jail_file
|
||||
from app.utils.conffile_parser import parse_jail_file
|
||||
|
||||
jail = parse_jail_file(FULL_JAIL).jails["nginx-botsearch"]
|
||||
assert jail.enabled is False
|
||||
|
||||
def test_extra_keys_captured(self) -> None:
|
||||
from app.services.conffile_parser import parse_jail_file
|
||||
from app.utils.conffile_parser import parse_jail_file
|
||||
|
||||
jail = parse_jail_file(JAIL_WITH_EXTRA).jails["sshd"]
|
||||
assert jail.extra["custom_key"] == "custom_value"
|
||||
assert jail.extra["another_key"] == "42"
|
||||
|
||||
def test_extra_keys_not_in_named_fields(self) -> None:
|
||||
from app.services.conffile_parser import parse_jail_file
|
||||
from app.utils.conffile_parser import parse_jail_file
|
||||
|
||||
jail = parse_jail_file(JAIL_WITH_EXTRA).jails["sshd"]
|
||||
assert "enabled" not in jail.extra
|
||||
assert "logpath" not in jail.extra
|
||||
|
||||
def test_empty_file_yields_no_jails(self) -> None:
|
||||
from app.services.conffile_parser import parse_jail_file
|
||||
from app.utils.conffile_parser import parse_jail_file
|
||||
|
||||
cfg = parse_jail_file("")
|
||||
assert cfg.jails == {}
|
||||
|
||||
def test_invalid_ini_does_not_raise(self) -> None:
|
||||
from app.services.conffile_parser import parse_jail_file
|
||||
from app.utils.conffile_parser import parse_jail_file
|
||||
|
||||
# Should not raise; just parse what it can.
|
||||
cfg = parse_jail_file("@@@ not valid ini @@@", filename="bad.conf")
|
||||
assert isinstance(cfg.jails, dict)
|
||||
|
||||
def test_default_section_ignored(self) -> None:
|
||||
from app.services.conffile_parser import parse_jail_file
|
||||
from app.utils.conffile_parser import parse_jail_file
|
||||
|
||||
content = "[DEFAULT]\nignoreip = 127.0.0.1\n\n[sshd]\nenabled = true\n"
|
||||
cfg = parse_jail_file(content)
|
||||
@@ -550,7 +550,7 @@ class TestJailFileRoundTrip:
|
||||
"""Tests that parse → serialize → parse preserves values."""
|
||||
|
||||
def test_minimal_round_trip(self) -> None:
|
||||
from app.services.conffile_parser import parse_jail_file, serialize_jail_file_config
|
||||
from app.utils.conffile_parser import parse_jail_file, serialize_jail_file_config
|
||||
|
||||
original = parse_jail_file(MINIMAL_JAIL, filename="sshd.conf")
|
||||
serialized = serialize_jail_file_config(original)
|
||||
@@ -560,7 +560,7 @@ class TestJailFileRoundTrip:
|
||||
assert restored.jails["sshd"].logpath == original.jails["sshd"].logpath
|
||||
|
||||
def test_full_round_trip(self) -> None:
|
||||
from app.services.conffile_parser import parse_jail_file, serialize_jail_file_config
|
||||
from app.utils.conffile_parser import parse_jail_file, serialize_jail_file_config
|
||||
|
||||
original = parse_jail_file(FULL_JAIL)
|
||||
serialized = serialize_jail_file_config(original)
|
||||
@@ -573,7 +573,7 @@ class TestJailFileRoundTrip:
|
||||
assert sorted(restored_jail.action) == sorted(jail.action)
|
||||
|
||||
def test_extra_keys_round_trip(self) -> None:
|
||||
from app.services.conffile_parser import parse_jail_file, serialize_jail_file_config
|
||||
from app.utils.conffile_parser import parse_jail_file, serialize_jail_file_config
|
||||
|
||||
original = parse_jail_file(JAIL_WITH_EXTRA)
|
||||
serialized = serialize_jail_file_config(original)
|
||||
@@ -591,7 +591,7 @@ class TestMergeJailFileUpdate:
|
||||
|
||||
def test_none_update_returns_original(self) -> None:
|
||||
from app.models.config import JailFileConfigUpdate
|
||||
from app.services.conffile_parser import merge_jail_file_update, parse_jail_file
|
||||
from app.utils.conffile_parser import merge_jail_file_update, parse_jail_file
|
||||
|
||||
cfg = parse_jail_file(FULL_JAIL)
|
||||
update = JailFileConfigUpdate()
|
||||
@@ -600,7 +600,7 @@ class TestMergeJailFileUpdate:
|
||||
|
||||
def test_update_replaces_jail(self) -> None:
|
||||
from app.models.config import JailFileConfigUpdate, JailSectionConfig
|
||||
from app.services.conffile_parser import merge_jail_file_update, parse_jail_file
|
||||
from app.utils.conffile_parser import merge_jail_file_update, parse_jail_file
|
||||
|
||||
cfg = parse_jail_file(FULL_JAIL)
|
||||
new_sshd = JailSectionConfig(enabled=False, port="2222")
|
||||
@@ -613,7 +613,7 @@ class TestMergeJailFileUpdate:
|
||||
|
||||
def test_update_adds_new_jail(self) -> None:
|
||||
from app.models.config import JailFileConfigUpdate, JailSectionConfig
|
||||
from app.services.conffile_parser import merge_jail_file_update, parse_jail_file
|
||||
from app.utils.conffile_parser import merge_jail_file_update, parse_jail_file
|
||||
|
||||
cfg = parse_jail_file(MINIMAL_JAIL)
|
||||
new_jail = JailSectionConfig(enabled=True, port="443")
|
||||
|
||||
@@ -13,14 +13,19 @@ from app.services.config_file_service import (
|
||||
JailNameError,
|
||||
JailNotFoundInConfigError,
|
||||
_build_inactive_jail,
|
||||
_extract_action_base_name,
|
||||
_extract_filter_base_name,
|
||||
_ordered_config_files,
|
||||
_parse_jails_sync,
|
||||
_resolve_filter,
|
||||
_safe_jail_name,
|
||||
_validate_jail_config_sync,
|
||||
_write_local_override_sync,
|
||||
activate_jail,
|
||||
deactivate_jail,
|
||||
list_inactive_jails,
|
||||
rollback_jail,
|
||||
validate_jail_config,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -289,6 +294,24 @@ class TestBuildInactiveJail:
|
||||
jail = _build_inactive_jail("active-jail", settings, "/etc/fail2ban/jail.conf")
|
||||
assert jail.enabled is True
|
||||
|
||||
def test_has_local_override_absent(self, tmp_path: Path) -> None:
|
||||
"""has_local_override is False when no .local file exists."""
|
||||
jail = _build_inactive_jail("sshd", {}, "/etc/fail2ban/jail.d/sshd.conf", config_dir=tmp_path)
|
||||
assert jail.has_local_override is False
|
||||
|
||||
def test_has_local_override_present(self, tmp_path: Path) -> None:
|
||||
"""has_local_override is True when jail.d/{name}.local exists."""
|
||||
local = tmp_path / "jail.d" / "sshd.local"
|
||||
local.parent.mkdir(parents=True, exist_ok=True)
|
||||
local.write_text("[sshd]\nenabled = false\n")
|
||||
jail = _build_inactive_jail("sshd", {}, "/etc/fail2ban/jail.d/sshd.conf", config_dir=tmp_path)
|
||||
assert jail.has_local_override is True
|
||||
|
||||
def test_has_local_override_no_config_dir(self) -> None:
|
||||
"""has_local_override is False when config_dir is not provided."""
|
||||
jail = _build_inactive_jail("sshd", {}, "/etc/fail2ban/jail.conf")
|
||||
assert jail.has_local_override is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _write_local_override_sync
|
||||
@@ -340,9 +363,7 @@ class TestWriteLocalOverrideSync:
|
||||
assert "2222" in content
|
||||
|
||||
def test_override_logpath_list(self, tmp_path: Path) -> None:
|
||||
_write_local_override_sync(
|
||||
tmp_path, "sshd", True, {"logpath": ["/var/log/auth.log", "/var/log/secure"]}
|
||||
)
|
||||
_write_local_override_sync(tmp_path, "sshd", True, {"logpath": ["/var/log/auth.log", "/var/log/secure"]})
|
||||
content = (tmp_path / "jail.d" / "sshd.local").read_text()
|
||||
assert "/var/log/auth.log" in content
|
||||
assert "/var/log/secure" in content
|
||||
@@ -424,6 +445,117 @@ class TestListInactiveJails:
|
||||
assert "sshd" in names
|
||||
assert "apache-auth" in names
|
||||
|
||||
async def test_has_local_override_true_when_local_file_exists(self, tmp_path: Path) -> None:
|
||||
"""has_local_override is True for a jail whose jail.d .local file exists."""
|
||||
_write(tmp_path / "jail.conf", JAIL_CONF)
|
||||
local = tmp_path / "jail.d" / "apache-auth.local"
|
||||
local.parent.mkdir(parents=True, exist_ok=True)
|
||||
local.write_text("[apache-auth]\nenabled = false\n")
|
||||
with patch(
|
||||
"app.services.config_file_service._get_active_jail_names",
|
||||
new=AsyncMock(return_value=set()),
|
||||
):
|
||||
result = await list_inactive_jails(str(tmp_path), "/fake.sock")
|
||||
jail = next(j for j in result.jails if j.name == "apache-auth")
|
||||
assert jail.has_local_override is True
|
||||
|
||||
async def test_has_local_override_false_when_no_local_file(self, tmp_path: Path) -> None:
|
||||
"""has_local_override is False when no jail.d .local file exists."""
|
||||
_write(tmp_path / "jail.conf", JAIL_CONF)
|
||||
with patch(
|
||||
"app.services.config_file_service._get_active_jail_names",
|
||||
new=AsyncMock(return_value=set()),
|
||||
):
|
||||
result = await list_inactive_jails(str(tmp_path), "/fake.sock")
|
||||
jail = next(j for j in result.jails if j.name == "apache-auth")
|
||||
assert jail.has_local_override is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# delete_jail_local_override
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestDeleteJailLocalOverride:
|
||||
"""Tests for :func:`~app.services.config_file_service.delete_jail_local_override`."""
|
||||
|
||||
async def test_deletes_local_file(self, tmp_path: Path) -> None:
|
||||
"""delete_jail_local_override removes the jail.d/.local file."""
|
||||
from app.services.config_file_service import delete_jail_local_override
|
||||
|
||||
_write(tmp_path / "jail.conf", JAIL_CONF)
|
||||
local = tmp_path / "jail.d" / "apache-auth.local"
|
||||
local.parent.mkdir(parents=True, exist_ok=True)
|
||||
local.write_text("[apache-auth]\nenabled = false\n")
|
||||
|
||||
with patch(
|
||||
"app.services.config_file_service._get_active_jail_names",
|
||||
new=AsyncMock(return_value=set()),
|
||||
):
|
||||
await delete_jail_local_override(str(tmp_path), "/fake.sock", "apache-auth")
|
||||
|
||||
assert not local.exists()
|
||||
|
||||
async def test_no_error_when_local_file_missing(self, tmp_path: Path) -> None:
|
||||
"""delete_jail_local_override succeeds silently when no .local file exists."""
|
||||
from app.services.config_file_service import delete_jail_local_override
|
||||
|
||||
_write(tmp_path / "jail.conf", JAIL_CONF)
|
||||
with patch(
|
||||
"app.services.config_file_service._get_active_jail_names",
|
||||
new=AsyncMock(return_value=set()),
|
||||
):
|
||||
# Must not raise even though there is no .local file.
|
||||
await delete_jail_local_override(str(tmp_path), "/fake.sock", "apache-auth")
|
||||
|
||||
async def test_raises_jail_not_found(self, tmp_path: Path) -> None:
|
||||
"""delete_jail_local_override raises JailNotFoundInConfigError for unknown jail."""
|
||||
from app.services.config_file_service import (
|
||||
JailNotFoundInConfigError,
|
||||
delete_jail_local_override,
|
||||
)
|
||||
|
||||
_write(tmp_path / "jail.conf", JAIL_CONF)
|
||||
with (
|
||||
patch(
|
||||
"app.services.config_file_service._get_active_jail_names",
|
||||
new=AsyncMock(return_value=set()),
|
||||
),
|
||||
pytest.raises(JailNotFoundInConfigError),
|
||||
):
|
||||
await delete_jail_local_override(str(tmp_path), "/fake.sock", "nonexistent")
|
||||
|
||||
async def test_raises_jail_already_active(self, tmp_path: Path) -> None:
|
||||
"""delete_jail_local_override raises JailAlreadyActiveError when jail is running."""
|
||||
from app.services.config_file_service import (
|
||||
JailAlreadyActiveError,
|
||||
delete_jail_local_override,
|
||||
)
|
||||
|
||||
_write(tmp_path / "jail.conf", JAIL_CONF)
|
||||
local = tmp_path / "jail.d" / "sshd.local"
|
||||
local.parent.mkdir(parents=True, exist_ok=True)
|
||||
local.write_text("[sshd]\nenabled = false\n")
|
||||
with (
|
||||
patch(
|
||||
"app.services.config_file_service._get_active_jail_names",
|
||||
new=AsyncMock(return_value={"sshd"}),
|
||||
),
|
||||
pytest.raises(JailAlreadyActiveError),
|
||||
):
|
||||
await delete_jail_local_override(str(tmp_path), "/fake.sock", "sshd")
|
||||
|
||||
async def test_raises_jail_name_error(self, tmp_path: Path) -> None:
|
||||
"""delete_jail_local_override raises JailNameError for invalid jail names."""
|
||||
from app.services.config_file_service import (
|
||||
JailNameError,
|
||||
delete_jail_local_override,
|
||||
)
|
||||
|
||||
with pytest.raises(JailNameError):
|
||||
await delete_jail_local_override(str(tmp_path), "/fake.sock", "../evil")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# activate_jail
|
||||
@@ -470,7 +602,8 @@ class TestActivateJail:
|
||||
patch(
|
||||
"app.services.config_file_service._get_active_jail_names",
|
||||
new=AsyncMock(return_value=set()),
|
||||
),pytest.raises(JailNotFoundInConfigError)
|
||||
),
|
||||
pytest.raises(JailNotFoundInConfigError),
|
||||
):
|
||||
await activate_jail(str(tmp_path), "/fake.sock", "nonexistent", req)
|
||||
|
||||
@@ -483,7 +616,8 @@ class TestActivateJail:
|
||||
patch(
|
||||
"app.services.config_file_service._get_active_jail_names",
|
||||
new=AsyncMock(return_value={"sshd"}),
|
||||
),pytest.raises(JailAlreadyActiveError)
|
||||
),
|
||||
pytest.raises(JailAlreadyActiveError),
|
||||
):
|
||||
await activate_jail(str(tmp_path), "/fake.sock", "sshd", req)
|
||||
|
||||
@@ -553,7 +687,8 @@ class TestDeactivateJail:
|
||||
patch(
|
||||
"app.services.config_file_service._get_active_jail_names",
|
||||
new=AsyncMock(return_value={"sshd"}),
|
||||
),pytest.raises(JailNotFoundInConfigError)
|
||||
),
|
||||
pytest.raises(JailNotFoundInConfigError),
|
||||
):
|
||||
await deactivate_jail(str(tmp_path), "/fake.sock", "nonexistent")
|
||||
|
||||
@@ -563,7 +698,8 @@ class TestDeactivateJail:
|
||||
patch(
|
||||
"app.services.config_file_service._get_active_jail_names",
|
||||
new=AsyncMock(return_value=set()),
|
||||
),pytest.raises(JailAlreadyInactiveError)
|
||||
),
|
||||
pytest.raises(JailAlreadyInactiveError),
|
||||
):
|
||||
await deactivate_jail(str(tmp_path), "/fake.sock", "apache-auth")
|
||||
|
||||
@@ -572,38 +708,6 @@ class TestDeactivateJail:
|
||||
await deactivate_jail(str(tmp_path), "/fake.sock", "a/b")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _extract_filter_base_name
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestExtractFilterBaseName:
|
||||
def test_simple_name(self) -> None:
|
||||
from app.services.config_file_service import _extract_filter_base_name
|
||||
|
||||
assert _extract_filter_base_name("sshd") == "sshd"
|
||||
|
||||
def test_name_with_mode(self) -> None:
|
||||
from app.services.config_file_service import _extract_filter_base_name
|
||||
|
||||
assert _extract_filter_base_name("sshd[mode=aggressive]") == "sshd"
|
||||
|
||||
def test_name_with_variable_mode(self) -> None:
|
||||
from app.services.config_file_service import _extract_filter_base_name
|
||||
|
||||
assert _extract_filter_base_name("sshd[mode=%(mode)s]") == "sshd"
|
||||
|
||||
def test_whitespace_stripped(self) -> None:
|
||||
from app.services.config_file_service import _extract_filter_base_name
|
||||
|
||||
assert _extract_filter_base_name(" nginx ") == "nginx"
|
||||
|
||||
def test_empty_string(self) -> None:
|
||||
from app.services.config_file_service import _extract_filter_base_name
|
||||
|
||||
assert _extract_filter_base_name("") == ""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _build_filter_to_jails_map
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -619,9 +723,7 @@ class TestBuildFilterToJailsMap:
|
||||
def test_inactive_jail_not_included(self) -> None:
|
||||
from app.services.config_file_service import _build_filter_to_jails_map
|
||||
|
||||
result = _build_filter_to_jails_map(
|
||||
{"apache-auth": {"filter": "apache-auth"}}, set()
|
||||
)
|
||||
result = _build_filter_to_jails_map({"apache-auth": {"filter": "apache-auth"}}, set())
|
||||
assert result == {}
|
||||
|
||||
def test_multiple_jails_sharing_filter(self) -> None:
|
||||
@@ -637,9 +739,7 @@ class TestBuildFilterToJailsMap:
|
||||
def test_mode_suffix_stripped(self) -> None:
|
||||
from app.services.config_file_service import _build_filter_to_jails_map
|
||||
|
||||
result = _build_filter_to_jails_map(
|
||||
{"sshd": {"filter": "sshd[mode=aggressive]"}}, {"sshd"}
|
||||
)
|
||||
result = _build_filter_to_jails_map({"sshd": {"filter": "sshd[mode=aggressive]"}}, {"sshd"})
|
||||
assert "sshd" in result
|
||||
|
||||
def test_missing_filter_key_falls_back_to_jail_name(self) -> None:
|
||||
@@ -850,10 +950,13 @@ class TestGetFilter:
|
||||
async def test_raises_filter_not_found(self, tmp_path: Path) -> None:
|
||||
from app.services.config_file_service import FilterNotFoundError, get_filter
|
||||
|
||||
with patch(
|
||||
"app.services.config_file_service._get_active_jail_names",
|
||||
new=AsyncMock(return_value=set()),
|
||||
), pytest.raises(FilterNotFoundError):
|
||||
with (
|
||||
patch(
|
||||
"app.services.config_file_service._get_active_jail_names",
|
||||
new=AsyncMock(return_value=set()),
|
||||
),
|
||||
pytest.raises(FilterNotFoundError),
|
||||
):
|
||||
await get_filter(str(tmp_path), "/fake.sock", "nonexistent")
|
||||
|
||||
async def test_has_local_override_detected(self, tmp_path: Path) -> None:
|
||||
@@ -955,10 +1058,13 @@ class TestGetFilterLocalOnly:
|
||||
async def test_raises_when_neither_conf_nor_local(self, tmp_path: Path) -> None:
|
||||
from app.services.config_file_service import FilterNotFoundError, get_filter
|
||||
|
||||
with patch(
|
||||
"app.services.config_file_service._get_active_jail_names",
|
||||
new=AsyncMock(return_value=set()),
|
||||
), pytest.raises(FilterNotFoundError):
|
||||
with (
|
||||
patch(
|
||||
"app.services.config_file_service._get_active_jail_names",
|
||||
new=AsyncMock(return_value=set()),
|
||||
),
|
||||
pytest.raises(FilterNotFoundError),
|
||||
):
|
||||
await get_filter(str(tmp_path), "/fake.sock", "nonexistent")
|
||||
|
||||
async def test_accepts_local_extension(self, tmp_path: Path) -> None:
|
||||
@@ -1074,9 +1180,7 @@ class TestSetJailLocalKeySync:
|
||||
|
||||
jail_d = tmp_path / "jail.d"
|
||||
jail_d.mkdir()
|
||||
(jail_d / "sshd.local").write_text(
|
||||
"[sshd]\nenabled = true\n"
|
||||
)
|
||||
(jail_d / "sshd.local").write_text("[sshd]\nenabled = true\n")
|
||||
|
||||
_set_jail_local_key_sync(tmp_path, "sshd", "filter", "newfilter")
|
||||
|
||||
@@ -1162,10 +1266,13 @@ class TestUpdateFilter:
|
||||
from app.models.config import FilterUpdateRequest
|
||||
from app.services.config_file_service import FilterNotFoundError, update_filter
|
||||
|
||||
with patch(
|
||||
"app.services.config_file_service._get_active_jail_names",
|
||||
new=AsyncMock(return_value=set()),
|
||||
), pytest.raises(FilterNotFoundError):
|
||||
with (
|
||||
patch(
|
||||
"app.services.config_file_service._get_active_jail_names",
|
||||
new=AsyncMock(return_value=set()),
|
||||
),
|
||||
pytest.raises(FilterNotFoundError),
|
||||
):
|
||||
await update_filter(
|
||||
str(tmp_path),
|
||||
"/fake.sock",
|
||||
@@ -1183,10 +1290,13 @@ class TestUpdateFilter:
|
||||
filter_d = tmp_path / "filter.d"
|
||||
_write(filter_d / "sshd.conf", _FILTER_CONF_WITH_REGEX)
|
||||
|
||||
with patch(
|
||||
"app.services.config_file_service._get_active_jail_names",
|
||||
new=AsyncMock(return_value=set()),
|
||||
), pytest.raises(FilterInvalidRegexError):
|
||||
with (
|
||||
patch(
|
||||
"app.services.config_file_service._get_active_jail_names",
|
||||
new=AsyncMock(return_value=set()),
|
||||
),
|
||||
pytest.raises(FilterInvalidRegexError),
|
||||
):
|
||||
await update_filter(
|
||||
str(tmp_path),
|
||||
"/fake.sock",
|
||||
@@ -1213,13 +1323,16 @@ class TestUpdateFilter:
|
||||
filter_d = tmp_path / "filter.d"
|
||||
_write(filter_d / "sshd.conf", _FILTER_CONF)
|
||||
|
||||
with patch(
|
||||
"app.services.config_file_service._get_active_jail_names",
|
||||
new=AsyncMock(return_value=set()),
|
||||
), patch(
|
||||
"app.services.config_file_service.jail_service.reload_all",
|
||||
new=AsyncMock(),
|
||||
) as mock_reload:
|
||||
with (
|
||||
patch(
|
||||
"app.services.config_file_service._get_active_jail_names",
|
||||
new=AsyncMock(return_value=set()),
|
||||
),
|
||||
patch(
|
||||
"app.services.config_file_service.jail_service.reload_all",
|
||||
new=AsyncMock(),
|
||||
) as mock_reload,
|
||||
):
|
||||
await update_filter(
|
||||
str(tmp_path),
|
||||
"/fake.sock",
|
||||
@@ -1267,10 +1380,13 @@ class TestCreateFilter:
|
||||
filter_d = tmp_path / "filter.d"
|
||||
_write(filter_d / "sshd.conf", _FILTER_CONF)
|
||||
|
||||
with patch(
|
||||
"app.services.config_file_service._get_active_jail_names",
|
||||
new=AsyncMock(return_value=set()),
|
||||
), pytest.raises(FilterAlreadyExistsError):
|
||||
with (
|
||||
patch(
|
||||
"app.services.config_file_service._get_active_jail_names",
|
||||
new=AsyncMock(return_value=set()),
|
||||
),
|
||||
pytest.raises(FilterAlreadyExistsError),
|
||||
):
|
||||
await create_filter(
|
||||
str(tmp_path),
|
||||
"/fake.sock",
|
||||
@@ -1284,10 +1400,13 @@ class TestCreateFilter:
|
||||
filter_d = tmp_path / "filter.d"
|
||||
_write(filter_d / "custom.local", "[Definition]\n")
|
||||
|
||||
with patch(
|
||||
"app.services.config_file_service._get_active_jail_names",
|
||||
new=AsyncMock(return_value=set()),
|
||||
), pytest.raises(FilterAlreadyExistsError):
|
||||
with (
|
||||
patch(
|
||||
"app.services.config_file_service._get_active_jail_names",
|
||||
new=AsyncMock(return_value=set()),
|
||||
),
|
||||
pytest.raises(FilterAlreadyExistsError),
|
||||
):
|
||||
await create_filter(
|
||||
str(tmp_path),
|
||||
"/fake.sock",
|
||||
@@ -1298,10 +1417,13 @@ class TestCreateFilter:
|
||||
from app.models.config import FilterCreateRequest
|
||||
from app.services.config_file_service import FilterInvalidRegexError, create_filter
|
||||
|
||||
with patch(
|
||||
"app.services.config_file_service._get_active_jail_names",
|
||||
new=AsyncMock(return_value=set()),
|
||||
), pytest.raises(FilterInvalidRegexError):
|
||||
with (
|
||||
patch(
|
||||
"app.services.config_file_service._get_active_jail_names",
|
||||
new=AsyncMock(return_value=set()),
|
||||
),
|
||||
pytest.raises(FilterInvalidRegexError),
|
||||
):
|
||||
await create_filter(
|
||||
str(tmp_path),
|
||||
"/fake.sock",
|
||||
@@ -1323,13 +1445,16 @@ class TestCreateFilter:
|
||||
from app.models.config import FilterCreateRequest
|
||||
from app.services.config_file_service import create_filter
|
||||
|
||||
with patch(
|
||||
"app.services.config_file_service._get_active_jail_names",
|
||||
new=AsyncMock(return_value=set()),
|
||||
), patch(
|
||||
"app.services.config_file_service.jail_service.reload_all",
|
||||
new=AsyncMock(),
|
||||
) as mock_reload:
|
||||
with (
|
||||
patch(
|
||||
"app.services.config_file_service._get_active_jail_names",
|
||||
new=AsyncMock(return_value=set()),
|
||||
),
|
||||
patch(
|
||||
"app.services.config_file_service.jail_service.reload_all",
|
||||
new=AsyncMock(),
|
||||
) as mock_reload,
|
||||
):
|
||||
await create_filter(
|
||||
str(tmp_path),
|
||||
"/fake.sock",
|
||||
@@ -1347,9 +1472,7 @@ class TestCreateFilter:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestDeleteFilter:
|
||||
async def test_deletes_local_file_when_conf_and_local_exist(
|
||||
self, tmp_path: Path
|
||||
) -> None:
|
||||
async def test_deletes_local_file_when_conf_and_local_exist(self, tmp_path: Path) -> None:
|
||||
from app.services.config_file_service import delete_filter
|
||||
|
||||
filter_d = tmp_path / "filter.d"
|
||||
@@ -1386,9 +1509,7 @@ class TestDeleteFilter:
|
||||
with pytest.raises(FilterNotFoundError):
|
||||
await delete_filter(str(tmp_path), "nonexistent")
|
||||
|
||||
async def test_accepts_filter_name_error_for_invalid_name(
|
||||
self, tmp_path: Path
|
||||
) -> None:
|
||||
async def test_accepts_filter_name_error_for_invalid_name(self, tmp_path: Path) -> None:
|
||||
from app.services.config_file_service import FilterNameError, delete_filter
|
||||
|
||||
with pytest.raises(FilterNameError):
|
||||
@@ -1469,9 +1590,7 @@ class TestAssignFilterToJail:
|
||||
AssignFilterRequest(filter_name="sshd"),
|
||||
)
|
||||
|
||||
async def test_raises_filter_name_error_for_invalid_filter(
|
||||
self, tmp_path: Path
|
||||
) -> None:
|
||||
async def test_raises_filter_name_error_for_invalid_filter(self, tmp_path: Path) -> None:
|
||||
from app.models.config import AssignFilterRequest
|
||||
from app.services.config_file_service import FilterNameError, assign_filter_to_jail
|
||||
|
||||
@@ -1581,34 +1700,26 @@ class TestBuildActionToJailsMap:
|
||||
def test_active_jail_maps_to_action(self) -> None:
|
||||
from app.services.config_file_service import _build_action_to_jails_map
|
||||
|
||||
result = _build_action_to_jails_map(
|
||||
{"sshd": {"action": "iptables-multiport"}}, {"sshd"}
|
||||
)
|
||||
result = _build_action_to_jails_map({"sshd": {"action": "iptables-multiport"}}, {"sshd"})
|
||||
assert result == {"iptables-multiport": ["sshd"]}
|
||||
|
||||
def test_inactive_jail_not_included(self) -> None:
|
||||
from app.services.config_file_service import _build_action_to_jails_map
|
||||
|
||||
result = _build_action_to_jails_map(
|
||||
{"sshd": {"action": "iptables-multiport"}}, set()
|
||||
)
|
||||
result = _build_action_to_jails_map({"sshd": {"action": "iptables-multiport"}}, set())
|
||||
assert result == {}
|
||||
|
||||
def test_multiple_actions_per_jail(self) -> None:
|
||||
from app.services.config_file_service import _build_action_to_jails_map
|
||||
|
||||
result = _build_action_to_jails_map(
|
||||
{"sshd": {"action": "iptables-multiport\niptables-ipset"}}, {"sshd"}
|
||||
)
|
||||
result = _build_action_to_jails_map({"sshd": {"action": "iptables-multiport\niptables-ipset"}}, {"sshd"})
|
||||
assert "iptables-multiport" in result
|
||||
assert "iptables-ipset" in result
|
||||
|
||||
def test_parameter_block_stripped(self) -> None:
|
||||
from app.services.config_file_service import _build_action_to_jails_map
|
||||
|
||||
result = _build_action_to_jails_map(
|
||||
{"sshd": {"action": "iptables[port=ssh, protocol=tcp]"}}, {"sshd"}
|
||||
)
|
||||
result = _build_action_to_jails_map({"sshd": {"action": "iptables[port=ssh, protocol=tcp]"}}, {"sshd"})
|
||||
assert "iptables" in result
|
||||
|
||||
def test_multiple_jails_sharing_action(self) -> None:
|
||||
@@ -1863,10 +1974,13 @@ class TestGetAction:
|
||||
async def test_raises_for_unknown_action(self, tmp_path: Path) -> None:
|
||||
from app.services.config_file_service import ActionNotFoundError, get_action
|
||||
|
||||
with patch(
|
||||
"app.services.config_file_service._get_active_jail_names",
|
||||
new=AsyncMock(return_value=set()),
|
||||
), pytest.raises(ActionNotFoundError):
|
||||
with (
|
||||
patch(
|
||||
"app.services.config_file_service._get_active_jail_names",
|
||||
new=AsyncMock(return_value=set()),
|
||||
),
|
||||
pytest.raises(ActionNotFoundError),
|
||||
):
|
||||
await get_action(str(tmp_path), "/fake.sock", "nonexistent")
|
||||
|
||||
async def test_local_only_action_returned(self, tmp_path: Path) -> None:
|
||||
@@ -1980,10 +2094,13 @@ class TestUpdateAction:
|
||||
from app.models.config import ActionUpdateRequest
|
||||
from app.services.config_file_service import ActionNotFoundError, update_action
|
||||
|
||||
with patch(
|
||||
"app.services.config_file_service._get_active_jail_names",
|
||||
new=AsyncMock(return_value=set()),
|
||||
), pytest.raises(ActionNotFoundError):
|
||||
with (
|
||||
patch(
|
||||
"app.services.config_file_service._get_active_jail_names",
|
||||
new=AsyncMock(return_value=set()),
|
||||
),
|
||||
pytest.raises(ActionNotFoundError),
|
||||
):
|
||||
await update_action(
|
||||
str(tmp_path),
|
||||
"/fake.sock",
|
||||
@@ -2449,9 +2566,7 @@ class TestRemoveActionFromJail:
|
||||
"app.services.config_file_service._get_active_jail_names",
|
||||
new=AsyncMock(return_value=set()),
|
||||
):
|
||||
await remove_action_from_jail(
|
||||
str(tmp_path), "/fake.sock", "sshd", "iptables-multiport"
|
||||
)
|
||||
await remove_action_from_jail(str(tmp_path), "/fake.sock", "sshd", "iptables-multiport")
|
||||
|
||||
content = (jail_d / "sshd.local").read_text()
|
||||
assert "iptables-multiport" not in content
|
||||
@@ -2463,17 +2578,13 @@ class TestRemoveActionFromJail:
|
||||
)
|
||||
|
||||
with pytest.raises(JailNotFoundInConfigError):
|
||||
await remove_action_from_jail(
|
||||
str(tmp_path), "/fake.sock", "nonexistent", "iptables"
|
||||
)
|
||||
await remove_action_from_jail(str(tmp_path), "/fake.sock", "nonexistent", "iptables")
|
||||
|
||||
async def test_raises_jail_name_error(self, tmp_path: Path) -> None:
|
||||
from app.services.config_file_service import JailNameError, remove_action_from_jail
|
||||
|
||||
with pytest.raises(JailNameError):
|
||||
await remove_action_from_jail(
|
||||
str(tmp_path), "/fake.sock", "../evil", "iptables"
|
||||
)
|
||||
await remove_action_from_jail(str(tmp_path), "/fake.sock", "../evil", "iptables")
|
||||
|
||||
async def test_raises_action_name_error(self, tmp_path: Path) -> None:
|
||||
from app.services.config_file_service import ActionNameError, remove_action_from_jail
|
||||
@@ -2481,9 +2592,7 @@ class TestRemoveActionFromJail:
|
||||
_write(tmp_path / "jail.conf", JAIL_CONF)
|
||||
|
||||
with pytest.raises(ActionNameError):
|
||||
await remove_action_from_jail(
|
||||
str(tmp_path), "/fake.sock", "sshd", "../evil"
|
||||
)
|
||||
await remove_action_from_jail(str(tmp_path), "/fake.sock", "sshd", "../evil")
|
||||
|
||||
async def test_triggers_reload_when_requested(self, tmp_path: Path) -> None:
|
||||
from app.services.config_file_service import remove_action_from_jail
|
||||
@@ -2502,9 +2611,7 @@ class TestRemoveActionFromJail:
|
||||
new=AsyncMock(),
|
||||
) as mock_reload,
|
||||
):
|
||||
await remove_action_from_jail(
|
||||
str(tmp_path), "/fake.sock", "sshd", "iptables", do_reload=True
|
||||
)
|
||||
await remove_action_from_jail(str(tmp_path), "/fake.sock", "sshd", "iptables", do_reload=True)
|
||||
|
||||
mock_reload.assert_awaited_once()
|
||||
|
||||
@@ -2542,13 +2649,9 @@ class TestActivateJailReloadArgs:
|
||||
mock_js.reload_all = AsyncMock()
|
||||
await activate_jail(str(tmp_path), "/fake.sock", "apache-auth", req)
|
||||
|
||||
mock_js.reload_all.assert_awaited_once_with(
|
||||
"/fake.sock", include_jails=["apache-auth"]
|
||||
)
|
||||
mock_js.reload_all.assert_awaited_once_with("/fake.sock", include_jails=["apache-auth"])
|
||||
|
||||
async def test_activate_returns_active_true_when_jail_starts(
|
||||
self, tmp_path: Path
|
||||
) -> None:
|
||||
async def test_activate_returns_active_true_when_jail_starts(self, tmp_path: Path) -> None:
|
||||
"""activate_jail returns active=True when the jail appears in post-reload names."""
|
||||
_write(tmp_path / "jail.conf", JAIL_CONF)
|
||||
from app.models.config import ActivateJailRequest, JailValidationResult
|
||||
@@ -2570,16 +2673,12 @@ class TestActivateJailReloadArgs:
|
||||
),
|
||||
):
|
||||
mock_js.reload_all = AsyncMock()
|
||||
result = await activate_jail(
|
||||
str(tmp_path), "/fake.sock", "apache-auth", req
|
||||
)
|
||||
result = await activate_jail(str(tmp_path), "/fake.sock", "apache-auth", req)
|
||||
|
||||
assert result.active is True
|
||||
assert "activated" in result.message.lower()
|
||||
|
||||
async def test_activate_returns_active_false_when_jail_does_not_start(
|
||||
self, tmp_path: Path
|
||||
) -> None:
|
||||
async def test_activate_returns_active_false_when_jail_does_not_start(self, tmp_path: Path) -> None:
|
||||
"""activate_jail returns active=False when the jail is absent after reload.
|
||||
|
||||
This covers the Stage 3.1 requirement: if the jail config is invalid
|
||||
@@ -2608,9 +2707,7 @@ class TestActivateJailReloadArgs:
|
||||
),
|
||||
):
|
||||
mock_js.reload_all = AsyncMock()
|
||||
result = await activate_jail(
|
||||
str(tmp_path), "/fake.sock", "apache-auth", req
|
||||
)
|
||||
result = await activate_jail(str(tmp_path), "/fake.sock", "apache-auth", req)
|
||||
|
||||
assert result.active is False
|
||||
assert "apache-auth" in result.name
|
||||
@@ -2638,23 +2735,13 @@ class TestDeactivateJailReloadArgs:
|
||||
mock_js.reload_all = AsyncMock()
|
||||
await deactivate_jail(str(tmp_path), "/fake.sock", "sshd")
|
||||
|
||||
mock_js.reload_all.assert_awaited_once_with(
|
||||
"/fake.sock", exclude_jails=["sshd"]
|
||||
)
|
||||
mock_js.reload_all.assert_awaited_once_with("/fake.sock", exclude_jails=["sshd"])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _validate_jail_config_sync (Task 3)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
from app.services.config_file_service import ( # noqa: E402 (added after block)
|
||||
_validate_jail_config_sync,
|
||||
_extract_filter_base_name,
|
||||
_extract_action_base_name,
|
||||
validate_jail_config,
|
||||
rollback_jail,
|
||||
)
|
||||
|
||||
|
||||
class TestExtractFilterBaseName:
|
||||
def test_plain_name(self) -> None:
|
||||
@@ -2800,11 +2887,11 @@ class TestRollbackJail:
|
||||
|
||||
with (
|
||||
patch(
|
||||
"app.services.config_file_service._start_daemon",
|
||||
"app.services.config_file_service.start_daemon",
|
||||
new=AsyncMock(return_value=True),
|
||||
),
|
||||
patch(
|
||||
"app.services.config_file_service._wait_for_fail2ban",
|
||||
"app.services.config_file_service.wait_for_fail2ban",
|
||||
new=AsyncMock(return_value=True),
|
||||
),
|
||||
patch(
|
||||
@@ -2812,9 +2899,7 @@ class TestRollbackJail:
|
||||
new=AsyncMock(return_value=set()),
|
||||
),
|
||||
):
|
||||
result = await rollback_jail(
|
||||
str(tmp_path), "/fake.sock", "sshd", ["fail2ban-client", "start"]
|
||||
)
|
||||
result = await rollback_jail(str(tmp_path), "/fake.sock", "sshd", ["fail2ban-client", "start"])
|
||||
|
||||
assert result.disabled is True
|
||||
assert result.fail2ban_running is True
|
||||
@@ -2830,26 +2915,22 @@ class TestRollbackJail:
|
||||
|
||||
with (
|
||||
patch(
|
||||
"app.services.config_file_service._start_daemon",
|
||||
"app.services.config_file_service.start_daemon",
|
||||
new=AsyncMock(return_value=False),
|
||||
),
|
||||
patch(
|
||||
"app.services.config_file_service._wait_for_fail2ban",
|
||||
"app.services.config_file_service.wait_for_fail2ban",
|
||||
new=AsyncMock(return_value=False),
|
||||
),
|
||||
):
|
||||
result = await rollback_jail(
|
||||
str(tmp_path), "/fake.sock", "sshd", ["fail2ban-client", "start"]
|
||||
)
|
||||
result = await rollback_jail(str(tmp_path), "/fake.sock", "sshd", ["fail2ban-client", "start"])
|
||||
|
||||
assert result.fail2ban_running is False
|
||||
assert result.disabled is True
|
||||
|
||||
async def test_rollback_rejects_unsafe_name(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(JailNameError):
|
||||
await rollback_jail(
|
||||
str(tmp_path), "/fake.sock", "../evil", ["fail2ban-client", "start"]
|
||||
)
|
||||
await rollback_jail(str(tmp_path), "/fake.sock", "../evil", ["fail2ban-client", "start"])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -2958,9 +3039,7 @@ class TestActivateJailBlocking:
|
||||
class TestActivateJailRollback:
|
||||
"""Rollback logic in activate_jail restores the .local file and recovers."""
|
||||
|
||||
async def test_activate_jail_rollback_on_reload_failure(
|
||||
self, tmp_path: Path
|
||||
) -> None:
|
||||
async def test_activate_jail_rollback_on_reload_failure(self, tmp_path: Path) -> None:
|
||||
"""Rollback when reload_all raises on the activation reload.
|
||||
|
||||
Expects:
|
||||
@@ -2997,23 +3076,17 @@ class TestActivateJailRollback:
|
||||
),
|
||||
patch(
|
||||
"app.services.config_file_service._validate_jail_config_sync",
|
||||
return_value=JailValidationResult(
|
||||
jail_name="apache-auth", valid=True
|
||||
),
|
||||
return_value=JailValidationResult(jail_name="apache-auth", valid=True),
|
||||
),
|
||||
):
|
||||
mock_js.reload_all = AsyncMock(side_effect=reload_side_effect)
|
||||
result = await activate_jail(
|
||||
str(tmp_path), "/fake.sock", "apache-auth", req
|
||||
)
|
||||
result = await activate_jail(str(tmp_path), "/fake.sock", "apache-auth", req)
|
||||
|
||||
assert result.active is False
|
||||
assert result.recovered is True
|
||||
assert local_path.read_text() == original_local
|
||||
|
||||
async def test_activate_jail_rollback_on_health_check_failure(
|
||||
self, tmp_path: Path
|
||||
) -> None:
|
||||
async def test_activate_jail_rollback_on_health_check_failure(self, tmp_path: Path) -> None:
|
||||
"""Rollback when fail2ban is unreachable after the activation reload.
|
||||
|
||||
Expects:
|
||||
@@ -3052,15 +3125,11 @@ class TestActivateJailRollback:
|
||||
),
|
||||
patch(
|
||||
"app.services.config_file_service._validate_jail_config_sync",
|
||||
return_value=JailValidationResult(
|
||||
jail_name="apache-auth", valid=True
|
||||
),
|
||||
return_value=JailValidationResult(jail_name="apache-auth", valid=True),
|
||||
),
|
||||
):
|
||||
mock_js.reload_all = AsyncMock()
|
||||
result = await activate_jail(
|
||||
str(tmp_path), "/fake.sock", "apache-auth", req
|
||||
)
|
||||
result = await activate_jail(str(tmp_path), "/fake.sock", "apache-auth", req)
|
||||
|
||||
assert result.active is False
|
||||
assert result.recovered is True
|
||||
@@ -3094,25 +3163,17 @@ class TestActivateJailRollback:
|
||||
),
|
||||
patch(
|
||||
"app.services.config_file_service._validate_jail_config_sync",
|
||||
return_value=JailValidationResult(
|
||||
jail_name="apache-auth", valid=True
|
||||
),
|
||||
return_value=JailValidationResult(jail_name="apache-auth", valid=True),
|
||||
),
|
||||
):
|
||||
# Both the activation reload and the recovery reload fail.
|
||||
mock_js.reload_all = AsyncMock(
|
||||
side_effect=RuntimeError("fail2ban unavailable")
|
||||
)
|
||||
result = await activate_jail(
|
||||
str(tmp_path), "/fake.sock", "apache-auth", req
|
||||
)
|
||||
mock_js.reload_all = AsyncMock(side_effect=RuntimeError("fail2ban unavailable"))
|
||||
result = await activate_jail(str(tmp_path), "/fake.sock", "apache-auth", req)
|
||||
|
||||
assert result.active is False
|
||||
assert result.recovered is False
|
||||
|
||||
async def test_activate_jail_rollback_on_jail_not_found_error(
|
||||
self, tmp_path: Path
|
||||
) -> None:
|
||||
async def test_activate_jail_rollback_on_jail_not_found_error(self, tmp_path: Path) -> None:
|
||||
"""Rollback when reload_all raises JailNotFoundError (invalid config).
|
||||
|
||||
When fail2ban cannot create a jail due to invalid configuration
|
||||
@@ -3156,16 +3217,12 @@ class TestActivateJailRollback:
|
||||
),
|
||||
patch(
|
||||
"app.services.config_file_service._validate_jail_config_sync",
|
||||
return_value=JailValidationResult(
|
||||
jail_name="apache-auth", valid=True
|
||||
),
|
||||
return_value=JailValidationResult(jail_name="apache-auth", valid=True),
|
||||
),
|
||||
):
|
||||
mock_js.reload_all = AsyncMock(side_effect=reload_side_effect)
|
||||
mock_js.JailNotFoundError = JailNotFoundError
|
||||
result = await activate_jail(
|
||||
str(tmp_path), "/fake.sock", "apache-auth", req
|
||||
)
|
||||
result = await activate_jail(str(tmp_path), "/fake.sock", "apache-auth", req)
|
||||
|
||||
assert result.active is False
|
||||
assert result.recovered is True
|
||||
@@ -3173,5 +3230,184 @@ class TestActivateJailRollback:
|
||||
# Verify the error message mentions logpath issues.
|
||||
assert "logpath" in result.message.lower() or "check that all logpath" in result.message.lower()
|
||||
|
||||
async def test_activate_jail_rollback_deletes_file_when_no_prior_local(self, tmp_path: Path) -> None:
|
||||
"""Rollback deletes the .local file when none existed before activation.
|
||||
|
||||
When a jail had no .local override before activation, activate_jail
|
||||
creates one with enabled = true. If reload then crashes, rollback must
|
||||
delete that file (leaving the jail in the same state as before the
|
||||
activation attempt).
|
||||
|
||||
Expects:
|
||||
- The .local file is absent after rollback.
|
||||
- The response indicates recovered=True.
|
||||
"""
|
||||
from app.models.config import ActivateJailRequest, JailValidationResult
|
||||
|
||||
_write(tmp_path / "jail.conf", JAIL_CONF)
|
||||
(tmp_path / "jail.d").mkdir(parents=True, exist_ok=True)
|
||||
local_path = tmp_path / "jail.d" / "apache-auth.local"
|
||||
# No .local file exists before activation.
|
||||
assert not local_path.exists()
|
||||
|
||||
req = ActivateJailRequest()
|
||||
reload_call_count = 0
|
||||
|
||||
async def reload_side_effect(socket_path: str, **kwargs: object) -> None:
|
||||
nonlocal reload_call_count
|
||||
reload_call_count += 1
|
||||
if reload_call_count == 1:
|
||||
raise RuntimeError("fail2ban crashed")
|
||||
# Recovery reload succeeds.
|
||||
|
||||
with (
|
||||
patch(
|
||||
"app.services.config_file_service._get_active_jail_names",
|
||||
new=AsyncMock(return_value=set()),
|
||||
),
|
||||
patch("app.services.config_file_service.jail_service") as mock_js,
|
||||
patch(
|
||||
"app.services.config_file_service._probe_fail2ban_running",
|
||||
new=AsyncMock(return_value=True),
|
||||
),
|
||||
patch(
|
||||
"app.services.config_file_service._validate_jail_config_sync",
|
||||
return_value=JailValidationResult(jail_name="apache-auth", valid=True),
|
||||
),
|
||||
):
|
||||
mock_js.reload_all = AsyncMock(side_effect=reload_side_effect)
|
||||
result = await activate_jail(str(tmp_path), "/fake.sock", "apache-auth", req)
|
||||
|
||||
assert result.active is False
|
||||
assert result.recovered is True
|
||||
assert not local_path.exists()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# rollback_jail
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestRollbackJailIntegration:
|
||||
"""Integration tests for :func:`~app.services.config_file_service.rollback_jail`."""
|
||||
|
||||
async def test_local_file_written_enabled_false(self, tmp_path: Path) -> None:
|
||||
"""rollback_jail writes enabled=false to jail.d/{name}.local before any socket call."""
|
||||
(tmp_path / "jail.d").mkdir()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"app.services.config_file_service.start_daemon",
|
||||
AsyncMock(return_value=True),
|
||||
),
|
||||
patch(
|
||||
"app.services.config_file_service.wait_for_fail2ban",
|
||||
AsyncMock(return_value=True),
|
||||
),
|
||||
patch(
|
||||
"app.services.config_file_service._get_active_jail_names",
|
||||
AsyncMock(return_value={"sshd"}),
|
||||
),
|
||||
):
|
||||
await rollback_jail(str(tmp_path), "/fake.sock", "sshd", ["fail2ban-client", "start"])
|
||||
|
||||
local = tmp_path / "jail.d" / "sshd.local"
|
||||
assert local.is_file(), "jail.d/sshd.local must be written"
|
||||
content = local.read_text()
|
||||
assert "enabled = false" in content
|
||||
|
||||
async def test_start_command_invoked_via_subprocess(self, tmp_path: Path) -> None:
|
||||
"""rollback_jail invokes the daemon start command via start_daemon, not via socket."""
|
||||
mock_start = AsyncMock(return_value=True)
|
||||
|
||||
with (
|
||||
patch("app.services.config_file_service.start_daemon", mock_start),
|
||||
patch(
|
||||
"app.services.config_file_service.wait_for_fail2ban",
|
||||
AsyncMock(return_value=True),
|
||||
),
|
||||
patch(
|
||||
"app.services.config_file_service._get_active_jail_names",
|
||||
AsyncMock(return_value={"other"}),
|
||||
),
|
||||
):
|
||||
await rollback_jail(str(tmp_path), "/fake.sock", "sshd", ["fail2ban-client", "start"])
|
||||
|
||||
mock_start.assert_awaited_once_with(["fail2ban-client", "start"])
|
||||
|
||||
async def test_fail2ban_running_reflects_socket_probe_not_subprocess_exit(self, tmp_path: Path) -> None:
|
||||
"""fail2ban_running in the response reflects the socket probe result.
|
||||
|
||||
Even when start_daemon returns True (subprocess exit 0), if the socket
|
||||
probe returns False the response must report fail2ban_running=False.
|
||||
"""
|
||||
with (
|
||||
patch(
|
||||
"app.services.config_file_service.start_daemon",
|
||||
AsyncMock(return_value=True),
|
||||
),
|
||||
patch(
|
||||
"app.services.config_file_service.wait_for_fail2ban",
|
||||
AsyncMock(return_value=False), # socket still unresponsive
|
||||
),
|
||||
):
|
||||
result = await rollback_jail(str(tmp_path), "/fake.sock", "sshd", ["fail2ban-client", "start"])
|
||||
|
||||
assert result.fail2ban_running is False
|
||||
|
||||
async def test_active_jails_zero_when_fail2ban_not_running(self, tmp_path: Path) -> None:
|
||||
"""active_jails is 0 in the response when fail2ban_running is False."""
|
||||
with (
|
||||
patch(
|
||||
"app.services.config_file_service.start_daemon",
|
||||
AsyncMock(return_value=False),
|
||||
),
|
||||
patch(
|
||||
"app.services.config_file_service.wait_for_fail2ban",
|
||||
AsyncMock(return_value=False),
|
||||
),
|
||||
):
|
||||
result = await rollback_jail(str(tmp_path), "/fake.sock", "sshd", ["fail2ban-client", "start"])
|
||||
|
||||
assert result.active_jails == 0
|
||||
|
||||
async def test_active_jails_count_from_socket_when_running(self, tmp_path: Path) -> None:
|
||||
"""active_jails reflects the actual jail count from the socket when fail2ban is up."""
|
||||
with (
|
||||
patch(
|
||||
"app.services.config_file_service.start_daemon",
|
||||
AsyncMock(return_value=True),
|
||||
),
|
||||
patch(
|
||||
"app.services.config_file_service.wait_for_fail2ban",
|
||||
AsyncMock(return_value=True),
|
||||
),
|
||||
patch(
|
||||
"app.services.config_file_service._get_active_jail_names",
|
||||
AsyncMock(return_value={"sshd", "nginx", "apache-auth"}),
|
||||
),
|
||||
):
|
||||
result = await rollback_jail(str(tmp_path), "/fake.sock", "sshd", ["fail2ban-client", "start"])
|
||||
|
||||
assert result.active_jails == 3
|
||||
|
||||
async def test_fail2ban_down_at_start_still_succeeds_file_write(self, tmp_path: Path) -> None:
|
||||
"""rollback_jail writes the local file even when fail2ban is down at call time."""
|
||||
# fail2ban is down: start_daemon fails and wait_for_fail2ban returns False.
|
||||
with (
|
||||
patch(
|
||||
"app.services.config_file_service.start_daemon",
|
||||
AsyncMock(return_value=False),
|
||||
),
|
||||
patch(
|
||||
"app.services.config_file_service.wait_for_fail2ban",
|
||||
AsyncMock(return_value=False),
|
||||
),
|
||||
):
|
||||
result = await rollback_jail(str(tmp_path), "/fake.sock", "sshd", ["fail2ban-client", "start"])
|
||||
|
||||
local = tmp_path / "jail.d" / "sshd.local"
|
||||
assert local.is_file(), "local file must be written even when fail2ban is down"
|
||||
assert result.disabled is True
|
||||
assert result.fail2ban_running is False
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
@@ -256,6 +257,27 @@ class TestUpdateJailConfig:
|
||||
assert "bantime" in keys
|
||||
assert "maxretry" in keys
|
||||
|
||||
async def test_ignores_backend_field(self) -> None:
|
||||
"""update_jail_config does not send a set command for backend."""
|
||||
sent_commands: list[list[Any]] = []
|
||||
|
||||
async def _send(command: list[Any]) -> Any:
|
||||
sent_commands.append(command)
|
||||
return (0, "OK")
|
||||
|
||||
class _FakeClient:
|
||||
def __init__(self, **_kw: Any) -> None:
|
||||
self.send = AsyncMock(side_effect=_send)
|
||||
|
||||
from app.models.config import JailConfigUpdate
|
||||
|
||||
update = JailConfigUpdate(backend="polling")
|
||||
with patch("app.services.config_service.Fail2BanClient", _FakeClient):
|
||||
await config_service.update_jail_config(_SOCKET, "sshd", update)
|
||||
|
||||
keys = [cmd[2] for cmd in sent_commands if len(cmd) >= 3 and cmd[0] == "set"]
|
||||
assert "backend" not in keys
|
||||
|
||||
async def test_raises_validation_error_on_bad_regex(self) -> None:
|
||||
"""update_jail_config raises ConfigValidationError for invalid regex."""
|
||||
from app.models.config import JailConfigUpdate
|
||||
@@ -721,12 +743,16 @@ class TestGetServiceStatus:
|
||||
def __init__(self, **_kw: Any) -> None:
|
||||
self.send = AsyncMock(side_effect=_send)
|
||||
|
||||
with patch("app.services.config_service.Fail2BanClient", _FakeClient), \
|
||||
patch("app.services.health_service.probe", AsyncMock(return_value=online_status)):
|
||||
result = await config_service.get_service_status(_SOCKET)
|
||||
with patch("app.services.config_service.Fail2BanClient", _FakeClient):
|
||||
result = await config_service.get_service_status(
|
||||
_SOCKET,
|
||||
probe_fn=AsyncMock(return_value=online_status),
|
||||
)
|
||||
|
||||
from app import __version__
|
||||
|
||||
assert result.online is True
|
||||
assert result.version == "1.0.0"
|
||||
assert result.version == __version__
|
||||
assert result.jail_count == 2
|
||||
assert result.total_bans == 5
|
||||
assert result.total_failures == 3
|
||||
@@ -739,10 +765,71 @@ class TestGetServiceStatus:
|
||||
|
||||
offline_status = ServerStatus(online=False)
|
||||
|
||||
with patch("app.services.health_service.probe", AsyncMock(return_value=offline_status)):
|
||||
result = await config_service.get_service_status(_SOCKET)
|
||||
result = await config_service.get_service_status(
|
||||
_SOCKET,
|
||||
probe_fn=AsyncMock(return_value=offline_status),
|
||||
)
|
||||
|
||||
assert result.online is False
|
||||
assert result.jail_count == 0
|
||||
assert result.log_level == "UNKNOWN"
|
||||
assert result.log_target == "UNKNOWN"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestConfigModuleIntegration:
|
||||
async def test_jail_config_service_list_inactive_jails_uses_imports(self, tmp_path: Any) -> None:
|
||||
from app.services.jail_config_service import list_inactive_jails
|
||||
|
||||
# Arrange: fake parse_jails output with one active and one inactive
|
||||
def fake_parse_jails_sync(path: Path) -> tuple[dict[str, dict[str, str]], dict[str, str]]:
|
||||
return (
|
||||
{
|
||||
"sshd": {
|
||||
"enabled": "true",
|
||||
"filter": "sshd",
|
||||
"logpath": "/var/log/auth.log",
|
||||
},
|
||||
"apache-auth": {
|
||||
"enabled": "false",
|
||||
"filter": "apache-auth",
|
||||
"logpath": "/var/log/apache2/error.log",
|
||||
},
|
||||
},
|
||||
{
|
||||
"sshd": str(path / "jail.conf"),
|
||||
"apache-auth": str(path / "jail.conf"),
|
||||
},
|
||||
)
|
||||
|
||||
with patch(
|
||||
"app.services.jail_config_service._parse_jails_sync",
|
||||
new=fake_parse_jails_sync,
|
||||
), patch(
|
||||
"app.services.jail_config_service._get_active_jail_names",
|
||||
new=AsyncMock(return_value={"sshd"}),
|
||||
):
|
||||
result = await list_inactive_jails(str(tmp_path), "/fake.sock")
|
||||
|
||||
names = {j.name for j in result.jails}
|
||||
assert "apache-auth" in names
|
||||
assert "sshd" not in names
|
||||
|
||||
async def test_filter_config_service_list_filters_uses_imports(self, tmp_path: Any) -> None:
|
||||
from app.services.filter_config_service import list_filters
|
||||
|
||||
# Arrange minimal filter and jail config files
|
||||
filter_d = tmp_path / "filter.d"
|
||||
filter_d.mkdir(parents=True)
|
||||
(filter_d / "sshd.conf").write_text("[Definition]\nfailregex = ^%(__prefix_line)s.*$\n")
|
||||
(tmp_path / "jail.conf").write_text("[sshd]\nfilter = sshd\nenabled = true\n")
|
||||
|
||||
with patch(
|
||||
"app.services.filter_config_service._get_active_jail_names",
|
||||
new=AsyncMock(return_value={"sshd"}),
|
||||
):
|
||||
result = await list_filters(str(tmp_path), "/fake.sock")
|
||||
|
||||
assert result.total == 1
|
||||
assert result.filters[0].name == "sshd"
|
||||
assert result.filters[0].active is True
|
||||
|
||||
@@ -8,7 +8,7 @@ import pytest
|
||||
|
||||
from app.models.config import ActionConfigUpdate, FilterConfigUpdate, JailFileConfigUpdate
|
||||
from app.models.file_config import ConfFileCreateRequest, ConfFileUpdateRequest
|
||||
from app.services.file_config_service import (
|
||||
from app.services.raw_config_io_service import (
|
||||
ConfigDirError,
|
||||
ConfigFileExistsError,
|
||||
ConfigFileNameError,
|
||||
|
||||
@@ -2,12 +2,13 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Mapping, Sequence
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.models.geo import GeoInfo
|
||||
from app.services import geo_service
|
||||
from app.services.geo_service import GeoInfo
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
@@ -44,7 +45,7 @@ def _make_session(response_json: dict[str, object], status: int = 200) -> MagicM
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def clear_geo_cache() -> None: # type: ignore[misc]
|
||||
def clear_geo_cache() -> None:
|
||||
"""Flush the module-level geo cache before every test."""
|
||||
geo_service.clear_cache()
|
||||
|
||||
@@ -68,7 +69,7 @@ class TestLookupSuccess:
|
||||
"org": "AS3320 Deutsche Telekom AG",
|
||||
}
|
||||
)
|
||||
result = await geo_service.lookup("1.2.3.4", session) # type: ignore[arg-type]
|
||||
result = await geo_service.lookup("1.2.3.4", session)
|
||||
|
||||
assert result is not None
|
||||
assert result.country_code == "DE"
|
||||
@@ -84,7 +85,7 @@ class TestLookupSuccess:
|
||||
"org": "Google LLC",
|
||||
}
|
||||
)
|
||||
result = await geo_service.lookup("8.8.8.8", session) # type: ignore[arg-type]
|
||||
result = await geo_service.lookup("8.8.8.8", session)
|
||||
|
||||
assert result is not None
|
||||
assert result.country_name == "United States"
|
||||
@@ -100,7 +101,7 @@ class TestLookupSuccess:
|
||||
"org": "Deutsche Telekom",
|
||||
}
|
||||
)
|
||||
result = await geo_service.lookup("1.2.3.4", session) # type: ignore[arg-type]
|
||||
result = await geo_service.lookup("1.2.3.4", session)
|
||||
|
||||
assert result is not None
|
||||
assert result.asn == "AS3320"
|
||||
@@ -116,7 +117,7 @@ class TestLookupSuccess:
|
||||
"org": "Google LLC",
|
||||
}
|
||||
)
|
||||
result = await geo_service.lookup("8.8.8.8", session) # type: ignore[arg-type]
|
||||
result = await geo_service.lookup("8.8.8.8", session)
|
||||
|
||||
assert result is not None
|
||||
assert result.org == "Google LLC"
|
||||
@@ -142,8 +143,8 @@ class TestLookupCaching:
|
||||
}
|
||||
)
|
||||
|
||||
await geo_service.lookup("1.2.3.4", session) # type: ignore[arg-type]
|
||||
await geo_service.lookup("1.2.3.4", session) # type: ignore[arg-type]
|
||||
await geo_service.lookup("1.2.3.4", session)
|
||||
await geo_service.lookup("1.2.3.4", session)
|
||||
|
||||
# The session.get() should only have been called once.
|
||||
assert session.get.call_count == 1
|
||||
@@ -160,9 +161,9 @@ class TestLookupCaching:
|
||||
}
|
||||
)
|
||||
|
||||
await geo_service.lookup("2.3.4.5", session) # type: ignore[arg-type]
|
||||
await geo_service.lookup("2.3.4.5", session)
|
||||
geo_service.clear_cache()
|
||||
await geo_service.lookup("2.3.4.5", session) # type: ignore[arg-type]
|
||||
await geo_service.lookup("2.3.4.5", session)
|
||||
|
||||
assert session.get.call_count == 2
|
||||
|
||||
@@ -172,8 +173,8 @@ class TestLookupCaching:
|
||||
{"status": "fail", "message": "reserved range"}
|
||||
)
|
||||
|
||||
await geo_service.lookup("192.168.1.1", session) # type: ignore[arg-type]
|
||||
await geo_service.lookup("192.168.1.1", session) # type: ignore[arg-type]
|
||||
await geo_service.lookup("192.168.1.1", session)
|
||||
await geo_service.lookup("192.168.1.1", session)
|
||||
|
||||
# Second call is blocked by the negative cache — only one API hit.
|
||||
assert session.get.call_count == 1
|
||||
@@ -190,7 +191,7 @@ class TestLookupFailures:
|
||||
async def test_non_200_response_returns_null_geo_info(self) -> None:
|
||||
"""A 429 or 500 status returns GeoInfo with null fields (not None)."""
|
||||
session = _make_session({}, status=429)
|
||||
result = await geo_service.lookup("1.2.3.4", session) # type: ignore[arg-type]
|
||||
result = await geo_service.lookup("1.2.3.4", session)
|
||||
assert result is not None
|
||||
assert isinstance(result, GeoInfo)
|
||||
assert result.country_code is None
|
||||
@@ -203,7 +204,7 @@ class TestLookupFailures:
|
||||
mock_ctx.__aexit__ = AsyncMock(return_value=False)
|
||||
session.get = MagicMock(return_value=mock_ctx)
|
||||
|
||||
result = await geo_service.lookup("10.0.0.1", session) # type: ignore[arg-type]
|
||||
result = await geo_service.lookup("10.0.0.1", session)
|
||||
assert result is not None
|
||||
assert isinstance(result, GeoInfo)
|
||||
assert result.country_code is None
|
||||
@@ -211,7 +212,7 @@ class TestLookupFailures:
|
||||
async def test_failed_status_returns_geo_info_with_nulls(self) -> None:
|
||||
"""When ip-api returns ``status=fail`` a GeoInfo with null fields is returned (but not cached)."""
|
||||
session = _make_session({"status": "fail", "message": "private range"})
|
||||
result = await geo_service.lookup("10.0.0.1", session) # type: ignore[arg-type]
|
||||
result = await geo_service.lookup("10.0.0.1", session)
|
||||
|
||||
assert result is not None
|
||||
assert isinstance(result, GeoInfo)
|
||||
@@ -231,8 +232,8 @@ class TestNegativeCache:
|
||||
"""After a failed lookup the second call is served from the neg cache."""
|
||||
session = _make_session({"status": "fail", "message": "private range"})
|
||||
|
||||
r1 = await geo_service.lookup("192.0.2.1", session) # type: ignore[arg-type]
|
||||
r2 = await geo_service.lookup("192.0.2.1", session) # type: ignore[arg-type]
|
||||
r1 = await geo_service.lookup("192.0.2.1", session)
|
||||
r2 = await geo_service.lookup("192.0.2.1", session)
|
||||
|
||||
# Only one HTTP call should have been made; second served from neg cache.
|
||||
assert session.get.call_count == 1
|
||||
@@ -243,12 +244,12 @@ class TestNegativeCache:
|
||||
"""When the neg-cache entry is older than the TTL a new API call is made."""
|
||||
session = _make_session({"status": "fail", "message": "private range"})
|
||||
|
||||
await geo_service.lookup("192.0.2.2", session) # type: ignore[arg-type]
|
||||
await geo_service.lookup("192.0.2.2", session)
|
||||
|
||||
# Manually expire the neg-cache entry.
|
||||
geo_service._neg_cache["192.0.2.2"] -= geo_service._NEG_CACHE_TTL + 1 # type: ignore[attr-defined]
|
||||
geo_service._neg_cache["192.0.2.2"] -= geo_service._NEG_CACHE_TTL + 1
|
||||
|
||||
await geo_service.lookup("192.0.2.2", session) # type: ignore[arg-type]
|
||||
await geo_service.lookup("192.0.2.2", session)
|
||||
|
||||
# Both calls should have hit the API.
|
||||
assert session.get.call_count == 2
|
||||
@@ -257,9 +258,9 @@ class TestNegativeCache:
|
||||
"""After clearing the neg cache the IP is eligible for a new API call."""
|
||||
session = _make_session({"status": "fail", "message": "private range"})
|
||||
|
||||
await geo_service.lookup("192.0.2.3", session) # type: ignore[arg-type]
|
||||
await geo_service.lookup("192.0.2.3", session)
|
||||
geo_service.clear_neg_cache()
|
||||
await geo_service.lookup("192.0.2.3", session) # type: ignore[arg-type]
|
||||
await geo_service.lookup("192.0.2.3", session)
|
||||
|
||||
assert session.get.call_count == 2
|
||||
|
||||
@@ -275,9 +276,9 @@ class TestNegativeCache:
|
||||
}
|
||||
)
|
||||
|
||||
await geo_service.lookup("1.2.3.4", session) # type: ignore[arg-type]
|
||||
await geo_service.lookup("1.2.3.4", session)
|
||||
|
||||
assert "1.2.3.4" not in geo_service._neg_cache # type: ignore[attr-defined]
|
||||
assert "1.2.3.4" not in geo_service._neg_cache
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -307,7 +308,7 @@ class TestGeoipFallback:
|
||||
mock_reader = self._make_geoip_reader("DE", "Germany")
|
||||
|
||||
with patch.object(geo_service, "_geoip_reader", mock_reader):
|
||||
result = await geo_service.lookup("1.2.3.4", session) # type: ignore[arg-type]
|
||||
result = await geo_service.lookup("1.2.3.4", session)
|
||||
|
||||
mock_reader.country.assert_called_once_with("1.2.3.4")
|
||||
assert result is not None
|
||||
@@ -320,12 +321,12 @@ class TestGeoipFallback:
|
||||
mock_reader = self._make_geoip_reader("US", "United States")
|
||||
|
||||
with patch.object(geo_service, "_geoip_reader", mock_reader):
|
||||
await geo_service.lookup("8.8.8.8", session) # type: ignore[arg-type]
|
||||
await geo_service.lookup("8.8.8.8", session)
|
||||
# Second call must be served from positive cache without hitting API.
|
||||
await geo_service.lookup("8.8.8.8", session) # type: ignore[arg-type]
|
||||
await geo_service.lookup("8.8.8.8", session)
|
||||
|
||||
assert session.get.call_count == 1
|
||||
assert "8.8.8.8" in geo_service._cache # type: ignore[attr-defined]
|
||||
assert "8.8.8.8" in geo_service._cache
|
||||
|
||||
async def test_geoip_fallback_not_called_on_api_success(self) -> None:
|
||||
"""When ip-api succeeds, the geoip2 reader must not be consulted."""
|
||||
@@ -341,7 +342,7 @@ class TestGeoipFallback:
|
||||
mock_reader = self._make_geoip_reader("XX", "Nowhere")
|
||||
|
||||
with patch.object(geo_service, "_geoip_reader", mock_reader):
|
||||
result = await geo_service.lookup("1.2.3.4", session) # type: ignore[arg-type]
|
||||
result = await geo_service.lookup("1.2.3.4", session)
|
||||
|
||||
mock_reader.country.assert_not_called()
|
||||
assert result is not None
|
||||
@@ -352,7 +353,7 @@ class TestGeoipFallback:
|
||||
session = _make_session({"status": "fail", "message": "private range"})
|
||||
|
||||
with patch.object(geo_service, "_geoip_reader", None):
|
||||
result = await geo_service.lookup("10.0.0.1", session) # type: ignore[arg-type]
|
||||
result = await geo_service.lookup("10.0.0.1", session)
|
||||
|
||||
assert result is not None
|
||||
assert result.country_code is None
|
||||
@@ -363,7 +364,7 @@ class TestGeoipFallback:
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_batch_session(batch_response: list[dict[str, object]]) -> MagicMock:
|
||||
def _make_batch_session(batch_response: Sequence[Mapping[str, object]]) -> MagicMock:
|
||||
"""Build a mock aiohttp.ClientSession for batch POST calls.
|
||||
|
||||
Args:
|
||||
@@ -412,7 +413,7 @@ class TestLookupBatchSingleCommit:
|
||||
session = _make_batch_session(batch_response)
|
||||
db = _make_async_db()
|
||||
|
||||
await geo_service.lookup_batch(ips, session, db=db) # type: ignore[arg-type]
|
||||
await geo_service.lookup_batch(ips, session, db=db)
|
||||
|
||||
db.commit.assert_awaited_once()
|
||||
|
||||
@@ -426,7 +427,7 @@ class TestLookupBatchSingleCommit:
|
||||
session = _make_batch_session(batch_response)
|
||||
db = _make_async_db()
|
||||
|
||||
await geo_service.lookup_batch(ips, session, db=db) # type: ignore[arg-type]
|
||||
await geo_service.lookup_batch(ips, session, db=db)
|
||||
|
||||
db.commit.assert_awaited_once()
|
||||
|
||||
@@ -452,13 +453,13 @@ class TestLookupBatchSingleCommit:
|
||||
|
||||
async def test_no_commit_for_all_cached_ips(self) -> None:
|
||||
"""When all IPs are already cached, no HTTP call and no commit occur."""
|
||||
geo_service._cache["5.5.5.5"] = GeoInfo( # type: ignore[attr-defined]
|
||||
geo_service._cache["5.5.5.5"] = GeoInfo(
|
||||
country_code="FR", country_name="France", asn="AS1", org="ISP"
|
||||
)
|
||||
db = _make_async_db()
|
||||
session = _make_batch_session([])
|
||||
|
||||
result = await geo_service.lookup_batch(["5.5.5.5"], session, db=db) # type: ignore[arg-type]
|
||||
result = await geo_service.lookup_batch(["5.5.5.5"], session, db=db)
|
||||
|
||||
assert result["5.5.5.5"].country_code == "FR"
|
||||
db.commit.assert_not_awaited()
|
||||
@@ -476,26 +477,26 @@ class TestDirtySetTracking:
|
||||
def test_successful_resolution_adds_to_dirty(self) -> None:
|
||||
"""Storing a GeoInfo with a country_code adds the IP to _dirty."""
|
||||
info = GeoInfo(country_code="DE", country_name="Germany", asn="AS1", org="ISP")
|
||||
geo_service._store("1.2.3.4", info) # type: ignore[attr-defined]
|
||||
geo_service._store("1.2.3.4", info)
|
||||
|
||||
assert "1.2.3.4" in geo_service._dirty # type: ignore[attr-defined]
|
||||
assert "1.2.3.4" in geo_service._dirty
|
||||
|
||||
def test_null_country_does_not_add_to_dirty(self) -> None:
|
||||
"""Storing a GeoInfo with country_code=None must not pollute _dirty."""
|
||||
info = GeoInfo(country_code=None, country_name=None, asn=None, org=None)
|
||||
geo_service._store("10.0.0.1", info) # type: ignore[attr-defined]
|
||||
geo_service._store("10.0.0.1", info)
|
||||
|
||||
assert "10.0.0.1" not in geo_service._dirty # type: ignore[attr-defined]
|
||||
assert "10.0.0.1" not in geo_service._dirty
|
||||
|
||||
def test_clear_cache_also_clears_dirty(self) -> None:
|
||||
"""clear_cache() must discard any pending dirty entries."""
|
||||
info = GeoInfo(country_code="US", country_name="United States", asn="AS1", org="ISP")
|
||||
geo_service._store("8.8.8.8", info) # type: ignore[attr-defined]
|
||||
assert geo_service._dirty # type: ignore[attr-defined]
|
||||
geo_service._store("8.8.8.8", info)
|
||||
assert geo_service._dirty
|
||||
|
||||
geo_service.clear_cache()
|
||||
|
||||
assert not geo_service._dirty # type: ignore[attr-defined]
|
||||
assert not geo_service._dirty
|
||||
|
||||
async def test_lookup_batch_populates_dirty(self) -> None:
|
||||
"""After lookup_batch() with db=None, resolved IPs appear in _dirty."""
|
||||
@@ -509,7 +510,7 @@ class TestDirtySetTracking:
|
||||
await geo_service.lookup_batch(ips, session, db=None)
|
||||
|
||||
for ip in ips:
|
||||
assert ip in geo_service._dirty # type: ignore[attr-defined]
|
||||
assert ip in geo_service._dirty
|
||||
|
||||
|
||||
class TestFlushDirty:
|
||||
@@ -518,8 +519,8 @@ class TestFlushDirty:
|
||||
async def test_flush_writes_and_clears_dirty(self) -> None:
|
||||
"""flush_dirty() inserts all dirty IPs and clears _dirty afterwards."""
|
||||
info = GeoInfo(country_code="GB", country_name="United Kingdom", asn="AS2856", org="BT")
|
||||
geo_service._store("100.0.0.1", info) # type: ignore[attr-defined]
|
||||
assert "100.0.0.1" in geo_service._dirty # type: ignore[attr-defined]
|
||||
geo_service._store("100.0.0.1", info)
|
||||
assert "100.0.0.1" in geo_service._dirty
|
||||
|
||||
db = _make_async_db()
|
||||
count = await geo_service.flush_dirty(db)
|
||||
@@ -527,7 +528,7 @@ class TestFlushDirty:
|
||||
assert count == 1
|
||||
db.executemany.assert_awaited_once()
|
||||
db.commit.assert_awaited_once()
|
||||
assert "100.0.0.1" not in geo_service._dirty # type: ignore[attr-defined]
|
||||
assert "100.0.0.1" not in geo_service._dirty
|
||||
|
||||
async def test_flush_returns_zero_when_nothing_dirty(self) -> None:
|
||||
"""flush_dirty() returns 0 and makes no DB calls when _dirty is empty."""
|
||||
@@ -541,7 +542,7 @@ class TestFlushDirty:
|
||||
async def test_flush_re_adds_to_dirty_on_db_error(self) -> None:
|
||||
"""When the DB write fails, entries are re-added to _dirty for retry."""
|
||||
info = GeoInfo(country_code="AU", country_name="Australia", asn="AS1", org="ISP")
|
||||
geo_service._store("200.0.0.1", info) # type: ignore[attr-defined]
|
||||
geo_service._store("200.0.0.1", info)
|
||||
|
||||
db = _make_async_db()
|
||||
db.executemany = AsyncMock(side_effect=OSError("disk full"))
|
||||
@@ -549,7 +550,7 @@ class TestFlushDirty:
|
||||
count = await geo_service.flush_dirty(db)
|
||||
|
||||
assert count == 0
|
||||
assert "200.0.0.1" in geo_service._dirty # type: ignore[attr-defined]
|
||||
assert "200.0.0.1" in geo_service._dirty
|
||||
|
||||
async def test_flush_batch_and_lookup_batch_integration(self) -> None:
|
||||
"""lookup_batch() populates _dirty; flush_dirty() then persists them."""
|
||||
@@ -562,14 +563,14 @@ class TestFlushDirty:
|
||||
|
||||
# Resolve without DB to populate only in-memory cache and _dirty.
|
||||
await geo_service.lookup_batch(ips, session, db=None)
|
||||
assert geo_service._dirty == set(ips) # type: ignore[attr-defined]
|
||||
assert geo_service._dirty == set(ips)
|
||||
|
||||
# Now flush to the DB.
|
||||
db = _make_async_db()
|
||||
count = await geo_service.flush_dirty(db)
|
||||
|
||||
assert count == 2
|
||||
assert not geo_service._dirty # type: ignore[attr-defined]
|
||||
assert not geo_service._dirty
|
||||
db.commit.assert_awaited_once()
|
||||
|
||||
|
||||
@@ -585,7 +586,7 @@ class TestLookupBatchThrottling:
|
||||
"""When more than _BATCH_SIZE IPs are sent, asyncio.sleep is called
|
||||
between consecutive batch HTTP calls with at least _BATCH_DELAY."""
|
||||
# Generate _BATCH_SIZE + 1 IPs so we get exactly 2 batch calls.
|
||||
batch_size: int = geo_service._BATCH_SIZE # type: ignore[attr-defined]
|
||||
batch_size: int = geo_service._BATCH_SIZE
|
||||
ips = [f"10.0.{i // 256}.{i % 256}" for i in range(batch_size + 1)]
|
||||
|
||||
def _make_result(chunk: list[str], _session: object) -> dict[str, GeoInfo]:
|
||||
@@ -608,7 +609,7 @@ class TestLookupBatchThrottling:
|
||||
assert mock_batch.call_count == 2
|
||||
mock_sleep.assert_awaited_once()
|
||||
delay_arg: float = mock_sleep.call_args[0][0]
|
||||
assert delay_arg >= geo_service._BATCH_DELAY # type: ignore[attr-defined]
|
||||
assert delay_arg >= geo_service._BATCH_DELAY
|
||||
|
||||
async def test_lookup_batch_retries_on_full_chunk_failure(self) -> None:
|
||||
"""When a chunk returns all-None on first try, it retries and succeeds."""
|
||||
@@ -650,7 +651,7 @@ class TestLookupBatchThrottling:
|
||||
_empty = GeoInfo(country_code=None, country_name=None, asn=None, org=None)
|
||||
_failure: dict[str, GeoInfo] = dict.fromkeys(ips, _empty)
|
||||
|
||||
max_retries: int = geo_service._BATCH_MAX_RETRIES # type: ignore[attr-defined]
|
||||
max_retries: int = geo_service._BATCH_MAX_RETRIES
|
||||
|
||||
with (
|
||||
patch(
|
||||
@@ -667,11 +668,11 @@ class TestLookupBatchThrottling:
|
||||
# IP should have no country.
|
||||
assert result["9.9.9.9"].country_code is None
|
||||
# Negative cache should contain the IP.
|
||||
assert "9.9.9.9" in geo_service._neg_cache # type: ignore[attr-defined]
|
||||
assert "9.9.9.9" in geo_service._neg_cache
|
||||
# Sleep called for each retry with exponential backoff.
|
||||
assert mock_sleep.call_count == max_retries
|
||||
backoff_values = [call.args[0] for call in mock_sleep.call_args_list]
|
||||
batch_delay: float = geo_service._BATCH_DELAY # type: ignore[attr-defined]
|
||||
batch_delay: float = geo_service._BATCH_DELAY
|
||||
for i, val in enumerate(backoff_values):
|
||||
expected = batch_delay * (2 ** (i + 1))
|
||||
assert val == pytest.approx(expected)
|
||||
@@ -709,7 +710,7 @@ class TestErrorLogging:
|
||||
import structlog.testing
|
||||
|
||||
with structlog.testing.capture_logs() as captured:
|
||||
result = await geo_service.lookup("197.221.98.153", session) # type: ignore[arg-type]
|
||||
result = await geo_service.lookup("197.221.98.153", session)
|
||||
|
||||
assert result is not None
|
||||
assert result.country_code is None
|
||||
@@ -733,7 +734,7 @@ class TestErrorLogging:
|
||||
import structlog.testing
|
||||
|
||||
with structlog.testing.capture_logs() as captured:
|
||||
await geo_service.lookup("10.0.0.1", session) # type: ignore[arg-type]
|
||||
await geo_service.lookup("10.0.0.1", session)
|
||||
|
||||
request_failed = [e for e in captured if e.get("event") == "geo_lookup_request_failed"]
|
||||
assert len(request_failed) == 1
|
||||
@@ -757,7 +758,7 @@ class TestErrorLogging:
|
||||
import structlog.testing
|
||||
|
||||
with structlog.testing.capture_logs() as captured:
|
||||
result = await geo_service._batch_api_call(["1.2.3.4"], session) # type: ignore[attr-defined]
|
||||
result = await geo_service._batch_api_call(["1.2.3.4"], session)
|
||||
|
||||
assert result["1.2.3.4"].country_code is None
|
||||
|
||||
@@ -778,7 +779,7 @@ class TestLookupCachedOnly:
|
||||
|
||||
def test_returns_cached_ips(self) -> None:
|
||||
"""IPs already in the cache are returned in the geo_map."""
|
||||
geo_service._cache["1.1.1.1"] = GeoInfo( # type: ignore[attr-defined]
|
||||
geo_service._cache["1.1.1.1"] = GeoInfo(
|
||||
country_code="AU", country_name="Australia", asn="AS13335", org="Cloudflare"
|
||||
)
|
||||
geo_map, uncached = geo_service.lookup_cached_only(["1.1.1.1"])
|
||||
@@ -798,7 +799,7 @@ class TestLookupCachedOnly:
|
||||
"""IPs in the negative cache within TTL are not re-queued as uncached."""
|
||||
import time
|
||||
|
||||
geo_service._neg_cache["10.0.0.1"] = time.monotonic() # type: ignore[attr-defined]
|
||||
geo_service._neg_cache["10.0.0.1"] = time.monotonic()
|
||||
|
||||
geo_map, uncached = geo_service.lookup_cached_only(["10.0.0.1"])
|
||||
|
||||
@@ -807,7 +808,7 @@ class TestLookupCachedOnly:
|
||||
|
||||
def test_expired_neg_cache_requeued(self) -> None:
|
||||
"""IPs whose neg-cache entry has expired are listed as uncached."""
|
||||
geo_service._neg_cache["10.0.0.2"] = 0.0 # epoch 0 → expired # type: ignore[attr-defined]
|
||||
geo_service._neg_cache["10.0.0.2"] = 0.0 # epoch 0 → expired
|
||||
|
||||
_geo_map, uncached = geo_service.lookup_cached_only(["10.0.0.2"])
|
||||
|
||||
@@ -815,12 +816,12 @@ class TestLookupCachedOnly:
|
||||
|
||||
def test_mixed_ips(self) -> None:
|
||||
"""A mix of cached, neg-cached, and unknown IPs is split correctly."""
|
||||
geo_service._cache["1.2.3.4"] = GeoInfo( # type: ignore[attr-defined]
|
||||
geo_service._cache["1.2.3.4"] = GeoInfo(
|
||||
country_code="DE", country_name="Germany", asn=None, org=None
|
||||
)
|
||||
import time
|
||||
|
||||
geo_service._neg_cache["5.5.5.5"] = time.monotonic() # type: ignore[attr-defined]
|
||||
geo_service._neg_cache["5.5.5.5"] = time.monotonic()
|
||||
|
||||
geo_map, uncached = geo_service.lookup_cached_only(["1.2.3.4", "5.5.5.5", "9.9.9.9"])
|
||||
|
||||
@@ -829,7 +830,7 @@ class TestLookupCachedOnly:
|
||||
|
||||
def test_deduplication(self) -> None:
|
||||
"""Duplicate IPs in the input appear at most once in the output."""
|
||||
geo_service._cache["1.2.3.4"] = GeoInfo( # type: ignore[attr-defined]
|
||||
geo_service._cache["1.2.3.4"] = GeoInfo(
|
||||
country_code="US", country_name="United States", asn=None, org=None
|
||||
)
|
||||
|
||||
@@ -866,7 +867,7 @@ class TestLookupBatchBulkWrites:
|
||||
session = _make_batch_session(batch_response)
|
||||
db = _make_async_db()
|
||||
|
||||
await geo_service.lookup_batch(ips, session, db=db) # type: ignore[arg-type]
|
||||
await geo_service.lookup_batch(ips, session, db=db)
|
||||
|
||||
# One executemany for the positive rows.
|
||||
assert db.executemany.await_count >= 1
|
||||
@@ -883,7 +884,7 @@ class TestLookupBatchBulkWrites:
|
||||
session = _make_batch_session(batch_response)
|
||||
db = _make_async_db()
|
||||
|
||||
await geo_service.lookup_batch(ips, session, db=db) # type: ignore[arg-type]
|
||||
await geo_service.lookup_batch(ips, session, db=db)
|
||||
|
||||
assert db.executemany.await_count >= 1
|
||||
db.execute.assert_not_awaited()
|
||||
@@ -905,7 +906,7 @@ class TestLookupBatchBulkWrites:
|
||||
session = _make_batch_session(batch_response)
|
||||
db = _make_async_db()
|
||||
|
||||
await geo_service.lookup_batch(ips, session, db=db) # type: ignore[arg-type]
|
||||
await geo_service.lookup_batch(ips, session, db=db)
|
||||
|
||||
# One executemany for positives, one for negatives.
|
||||
assert db.executemany.await_count == 2
|
||||
|
||||
@@ -64,7 +64,7 @@ async def _create_f2b_db(path: str, rows: list[dict[str, Any]]) -> None:
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def f2b_db_path(tmp_path: Path) -> str: # type: ignore[misc]
|
||||
async def f2b_db_path(tmp_path: Path) -> str:
|
||||
"""Return the path to a test fail2ban SQLite database."""
|
||||
path = str(tmp_path / "fail2ban_test.sqlite3")
|
||||
await _create_f2b_db(
|
||||
@@ -123,7 +123,7 @@ class TestListHistory:
|
||||
) -> None:
|
||||
"""No filter returns every record in the database."""
|
||||
with patch(
|
||||
"app.services.history_service._get_fail2ban_db_path",
|
||||
"app.services.history_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=f2b_db_path),
|
||||
):
|
||||
result = await history_service.list_history("fake_socket")
|
||||
@@ -135,7 +135,7 @@ class TestListHistory:
|
||||
) -> None:
|
||||
"""The ``range_`` filter excludes bans older than the window."""
|
||||
with patch(
|
||||
"app.services.history_service._get_fail2ban_db_path",
|
||||
"app.services.history_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=f2b_db_path),
|
||||
):
|
||||
# "24h" window should include only the two recent bans
|
||||
@@ -147,7 +147,7 @@ class TestListHistory:
|
||||
async def test_jail_filter(self, f2b_db_path: str) -> None:
|
||||
"""Jail filter restricts results to bans from that jail."""
|
||||
with patch(
|
||||
"app.services.history_service._get_fail2ban_db_path",
|
||||
"app.services.history_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=f2b_db_path),
|
||||
):
|
||||
result = await history_service.list_history("fake_socket", jail="nginx")
|
||||
@@ -157,7 +157,7 @@ class TestListHistory:
|
||||
async def test_ip_prefix_filter(self, f2b_db_path: str) -> None:
|
||||
"""IP prefix filter restricts results to matching IPs."""
|
||||
with patch(
|
||||
"app.services.history_service._get_fail2ban_db_path",
|
||||
"app.services.history_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=f2b_db_path),
|
||||
):
|
||||
result = await history_service.list_history(
|
||||
@@ -170,7 +170,7 @@ class TestListHistory:
|
||||
async def test_combined_filters(self, f2b_db_path: str) -> None:
|
||||
"""Jail + IP prefix filters applied together narrow the result set."""
|
||||
with patch(
|
||||
"app.services.history_service._get_fail2ban_db_path",
|
||||
"app.services.history_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=f2b_db_path),
|
||||
):
|
||||
result = await history_service.list_history(
|
||||
@@ -179,10 +179,23 @@ class TestListHistory:
|
||||
# 2 sshd bans for 1.2.3.4
|
||||
assert result.total == 2
|
||||
|
||||
async def test_origin_filter_selfblock(self, f2b_db_path: str) -> None:
|
||||
"""Origin filter should include only selfblock entries."""
|
||||
with patch(
|
||||
"app.services.history_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=f2b_db_path),
|
||||
):
|
||||
result = await history_service.list_history(
|
||||
"fake_socket", origin="selfblock"
|
||||
)
|
||||
|
||||
assert result.total == 4
|
||||
assert all(item.jail != "blocklist-import" for item in result.items)
|
||||
|
||||
async def test_unknown_ip_returns_empty(self, f2b_db_path: str) -> None:
|
||||
"""Filtering by a non-existent IP returns an empty result set."""
|
||||
with patch(
|
||||
"app.services.history_service._get_fail2ban_db_path",
|
||||
"app.services.history_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=f2b_db_path),
|
||||
):
|
||||
result = await history_service.list_history(
|
||||
@@ -196,7 +209,7 @@ class TestListHistory:
|
||||
) -> None:
|
||||
"""``failures`` field is parsed from the JSON ``data`` column."""
|
||||
with patch(
|
||||
"app.services.history_service._get_fail2ban_db_path",
|
||||
"app.services.history_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=f2b_db_path),
|
||||
):
|
||||
result = await history_service.list_history(
|
||||
@@ -210,7 +223,7 @@ class TestListHistory:
|
||||
) -> None:
|
||||
"""``matches`` list is parsed from the JSON ``data`` column."""
|
||||
with patch(
|
||||
"app.services.history_service._get_fail2ban_db_path",
|
||||
"app.services.history_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=f2b_db_path),
|
||||
):
|
||||
result = await history_service.list_history(
|
||||
@@ -226,7 +239,7 @@ class TestListHistory:
|
||||
) -> None:
|
||||
"""Records with ``data=NULL`` produce failures=0 and matches=[]."""
|
||||
with patch(
|
||||
"app.services.history_service._get_fail2ban_db_path",
|
||||
"app.services.history_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=f2b_db_path),
|
||||
):
|
||||
result = await history_service.list_history(
|
||||
@@ -240,7 +253,7 @@ class TestListHistory:
|
||||
async def test_pagination(self, f2b_db_path: str) -> None:
|
||||
"""Pagination returns the correct slice."""
|
||||
with patch(
|
||||
"app.services.history_service._get_fail2ban_db_path",
|
||||
"app.services.history_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=f2b_db_path),
|
||||
):
|
||||
result = await history_service.list_history(
|
||||
@@ -265,7 +278,7 @@ class TestGetIpDetail:
|
||||
) -> None:
|
||||
"""Returns ``None`` when the IP has no records in the database."""
|
||||
with patch(
|
||||
"app.services.history_service._get_fail2ban_db_path",
|
||||
"app.services.history_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=f2b_db_path),
|
||||
):
|
||||
result = await history_service.get_ip_detail("fake_socket", "99.99.99.99")
|
||||
@@ -276,7 +289,7 @@ class TestGetIpDetail:
|
||||
) -> None:
|
||||
"""Returns an IpDetailResponse with correct totals for a known IP."""
|
||||
with patch(
|
||||
"app.services.history_service._get_fail2ban_db_path",
|
||||
"app.services.history_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=f2b_db_path),
|
||||
):
|
||||
result = await history_service.get_ip_detail("fake_socket", "1.2.3.4")
|
||||
@@ -291,7 +304,7 @@ class TestGetIpDetail:
|
||||
) -> None:
|
||||
"""Timeline events are ordered newest-first."""
|
||||
with patch(
|
||||
"app.services.history_service._get_fail2ban_db_path",
|
||||
"app.services.history_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=f2b_db_path),
|
||||
):
|
||||
result = await history_service.get_ip_detail("fake_socket", "1.2.3.4")
|
||||
@@ -304,7 +317,7 @@ class TestGetIpDetail:
|
||||
async def test_last_ban_at_is_most_recent(self, f2b_db_path: str) -> None:
|
||||
"""``last_ban_at`` matches the banned_at of the first timeline event."""
|
||||
with patch(
|
||||
"app.services.history_service._get_fail2ban_db_path",
|
||||
"app.services.history_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=f2b_db_path),
|
||||
):
|
||||
result = await history_service.get_ip_detail("fake_socket", "1.2.3.4")
|
||||
@@ -316,7 +329,7 @@ class TestGetIpDetail:
|
||||
self, f2b_db_path: str
|
||||
) -> None:
|
||||
"""Geolocation is applied when a geo_enricher is provided."""
|
||||
from app.services.geo_service import GeoInfo
|
||||
from app.models.geo import GeoInfo
|
||||
|
||||
mock_geo = GeoInfo(
|
||||
country_code="US",
|
||||
@@ -327,7 +340,7 @@ class TestGetIpDetail:
|
||||
fake_enricher = AsyncMock(return_value=mock_geo)
|
||||
|
||||
with patch(
|
||||
"app.services.history_service._get_fail2ban_db_path",
|
||||
"app.services.history_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=f2b_db_path),
|
||||
):
|
||||
result = await history_service.get_ip_detail(
|
||||
|
||||
@@ -441,6 +441,33 @@ class TestJailControls:
|
||||
)
|
||||
assert exc_info.value.name == "airsonic-auth"
|
||||
|
||||
async def test_restart_sends_stop_command(self) -> None:
|
||||
"""restart() sends the ['stop'] command to the fail2ban socket."""
|
||||
with _patch_client({"stop": (0, None)}):
|
||||
await jail_service.restart(_SOCKET) # should not raise
|
||||
|
||||
async def test_restart_operation_error_raises(self) -> None:
|
||||
"""restart() raises JailOperationError when fail2ban rejects the stop."""
|
||||
with _patch_client({"stop": (1, Exception("cannot stop"))}), pytest.raises(
|
||||
JailOperationError
|
||||
):
|
||||
await jail_service.restart(_SOCKET)
|
||||
|
||||
async def test_restart_connection_error_propagates(self) -> None:
|
||||
"""restart() propagates Fail2BanConnectionError when socket is unreachable."""
|
||||
|
||||
class _FailClient:
|
||||
def __init__(self, **_kw: Any) -> None:
|
||||
self.send = AsyncMock(
|
||||
side_effect=Fail2BanConnectionError("no socket", _SOCKET)
|
||||
)
|
||||
|
||||
with (
|
||||
patch("app.services.jail_service.Fail2BanClient", _FailClient),
|
||||
pytest.raises(Fail2BanConnectionError),
|
||||
):
|
||||
await jail_service.restart(_SOCKET)
|
||||
|
||||
async def test_start_not_found_raises(self) -> None:
|
||||
"""start_jail raises JailNotFoundError for unknown jail."""
|
||||
with _patch_client({"start|ghost": (1, Exception("Unknown jail: 'ghost'"))}), pytest.raises(JailNotFoundError):
|
||||
@@ -608,7 +635,7 @@ class TestGetActiveBans:
|
||||
|
||||
async def test_http_session_triggers_lookup_batch(self) -> None:
|
||||
"""When http_session is provided, geo_service.lookup_batch is used."""
|
||||
from app.services.geo_service import GeoInfo
|
||||
from app.models.geo import GeoInfo
|
||||
|
||||
responses = {
|
||||
"status": _make_global_status("sshd"),
|
||||
@@ -618,17 +645,14 @@ class TestGetActiveBans:
|
||||
),
|
||||
}
|
||||
mock_geo = {"1.2.3.4": GeoInfo(country_code="DE", country_name="Germany", asn="AS1", org="ISP")}
|
||||
mock_batch = AsyncMock(return_value=mock_geo)
|
||||
|
||||
with (
|
||||
_patch_client(responses),
|
||||
patch(
|
||||
"app.services.geo_service.lookup_batch",
|
||||
new=AsyncMock(return_value=mock_geo),
|
||||
) as mock_batch,
|
||||
):
|
||||
with _patch_client(responses):
|
||||
mock_session = AsyncMock()
|
||||
result = await jail_service.get_active_bans(
|
||||
_SOCKET, http_session=mock_session
|
||||
_SOCKET,
|
||||
http_session=mock_session,
|
||||
geo_batch_lookup=mock_batch,
|
||||
)
|
||||
|
||||
mock_batch.assert_awaited_once()
|
||||
@@ -645,16 +669,14 @@ class TestGetActiveBans:
|
||||
),
|
||||
}
|
||||
|
||||
with (
|
||||
_patch_client(responses),
|
||||
patch(
|
||||
"app.services.geo_service.lookup_batch",
|
||||
new=AsyncMock(side_effect=RuntimeError("geo down")),
|
||||
),
|
||||
):
|
||||
failing_batch = AsyncMock(side_effect=RuntimeError("geo down"))
|
||||
|
||||
with _patch_client(responses):
|
||||
mock_session = AsyncMock()
|
||||
result = await jail_service.get_active_bans(
|
||||
_SOCKET, http_session=mock_session
|
||||
_SOCKET,
|
||||
http_session=mock_session,
|
||||
geo_batch_lookup=failing_batch,
|
||||
)
|
||||
|
||||
assert result.total == 1
|
||||
@@ -662,7 +684,7 @@ class TestGetActiveBans:
|
||||
|
||||
async def test_geo_enricher_still_used_without_http_session(self) -> None:
|
||||
"""Legacy geo_enricher is still called when http_session is not provided."""
|
||||
from app.services.geo_service import GeoInfo
|
||||
from app.models.geo import GeoInfo
|
||||
|
||||
responses = {
|
||||
"status": _make_global_status("sshd"),
|
||||
@@ -960,6 +982,7 @@ class TestGetJailBannedIps:
|
||||
page=1,
|
||||
page_size=2,
|
||||
http_session=http_session,
|
||||
geo_batch_lookup=geo_service.lookup_batch,
|
||||
)
|
||||
|
||||
# Only the 2-IP page slice should be passed to geo enrichment.
|
||||
@@ -969,9 +992,6 @@ class TestGetJailBannedIps:
|
||||
|
||||
async def test_unknown_jail_raises_jail_not_found_error(self) -> None:
|
||||
"""get_jail_banned_ips raises JailNotFoundError for unknown jail."""
|
||||
responses = {
|
||||
"status|ghost|short": (0, pytest.raises), # will be overridden
|
||||
}
|
||||
# Simulate fail2ban returning an "unknown jail" error.
|
||||
class _FakeClient:
|
||||
def __init__(self, **_kw: Any) -> None:
|
||||
|
||||
@@ -13,7 +13,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.geo_service import GeoInfo
|
||||
from app.models.geo import GeoInfo
|
||||
from app.tasks.geo_re_resolve import _run_re_resolve
|
||||
|
||||
|
||||
@@ -79,6 +79,8 @@ async def test_run_re_resolve_no_unresolved_ips_skips() -> None:
|
||||
app = _make_app(unresolved_ips=[])
|
||||
|
||||
with patch("app.tasks.geo_re_resolve.geo_service") as mock_geo:
|
||||
mock_geo.get_unresolved_ips = AsyncMock(return_value=[])
|
||||
|
||||
await _run_re_resolve(app)
|
||||
|
||||
mock_geo.clear_neg_cache.assert_not_called()
|
||||
@@ -96,6 +98,7 @@ async def test_run_re_resolve_clears_neg_cache() -> None:
|
||||
app = _make_app(unresolved_ips=ips, lookup_result=result)
|
||||
|
||||
with patch("app.tasks.geo_re_resolve.geo_service") as mock_geo:
|
||||
mock_geo.get_unresolved_ips = AsyncMock(return_value=ips)
|
||||
mock_geo.lookup_batch = AsyncMock(return_value=result)
|
||||
|
||||
await _run_re_resolve(app)
|
||||
@@ -114,6 +117,7 @@ async def test_run_re_resolve_calls_lookup_batch_with_db() -> None:
|
||||
app = _make_app(unresolved_ips=ips, lookup_result=result)
|
||||
|
||||
with patch("app.tasks.geo_re_resolve.geo_service") as mock_geo:
|
||||
mock_geo.get_unresolved_ips = AsyncMock(return_value=ips)
|
||||
mock_geo.lookup_batch = AsyncMock(return_value=result)
|
||||
|
||||
await _run_re_resolve(app)
|
||||
@@ -137,6 +141,7 @@ async def test_run_re_resolve_logs_correct_counts(caplog: Any) -> None:
|
||||
app = _make_app(unresolved_ips=ips, lookup_result=result)
|
||||
|
||||
with patch("app.tasks.geo_re_resolve.geo_service") as mock_geo:
|
||||
mock_geo.get_unresolved_ips = AsyncMock(return_value=ips)
|
||||
mock_geo.lookup_batch = AsyncMock(return_value=result)
|
||||
|
||||
await _run_re_resolve(app)
|
||||
@@ -159,6 +164,7 @@ async def test_run_re_resolve_handles_all_resolved() -> None:
|
||||
app = _make_app(unresolved_ips=ips, lookup_result=result)
|
||||
|
||||
with patch("app.tasks.geo_re_resolve.geo_service") as mock_geo:
|
||||
mock_geo.get_unresolved_ips = AsyncMock(return_value=ips)
|
||||
mock_geo.lookup_batch = AsyncMock(return_value=result)
|
||||
|
||||
await _run_re_resolve(app)
|
||||
|
||||
@@ -270,7 +270,7 @@ class TestCrashDetection:
|
||||
async def test_crash_within_window_creates_pending_recovery(self) -> None:
|
||||
"""An online→offline transition within 60 s of activation must set pending_recovery."""
|
||||
app = _make_app(prev_online=True)
|
||||
now = datetime.datetime.now(tz=datetime.timezone.utc)
|
||||
now = datetime.datetime.now(tz=datetime.UTC)
|
||||
app.state.last_activation = {
|
||||
"jail_name": "sshd",
|
||||
"at": now - datetime.timedelta(seconds=10),
|
||||
@@ -297,7 +297,7 @@ class TestCrashDetection:
|
||||
app = _make_app(prev_online=True)
|
||||
app.state.last_activation = {
|
||||
"jail_name": "sshd",
|
||||
"at": datetime.datetime.now(tz=datetime.timezone.utc)
|
||||
"at": datetime.datetime.now(tz=datetime.UTC)
|
||||
- datetime.timedelta(seconds=120),
|
||||
}
|
||||
app.state.pending_recovery = None
|
||||
@@ -315,8 +315,8 @@ class TestCrashDetection:
|
||||
async def test_came_online_marks_pending_recovery_resolved(self) -> None:
|
||||
"""An offline→online transition must mark an existing pending_recovery as recovered."""
|
||||
app = _make_app(prev_online=False)
|
||||
activated_at = datetime.datetime.now(tz=datetime.timezone.utc) - datetime.timedelta(seconds=30)
|
||||
detected_at = datetime.datetime.now(tz=datetime.timezone.utc)
|
||||
activated_at = datetime.datetime.now(tz=datetime.UTC) - datetime.timedelta(seconds=30)
|
||||
detected_at = datetime.datetime.now(tz=datetime.UTC)
|
||||
app.state.pending_recovery = PendingRecovery(
|
||||
jail_name="sshd",
|
||||
activated_at=activated_at,
|
||||
|
||||
138
backend/tests/test_utils/test_jail_config.py
Normal file
138
backend/tests/test_utils/test_jail_config.py
Normal file
@@ -0,0 +1,138 @@
|
||||
"""Tests for app.utils.jail_config.ensure_jail_configs."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from app.utils.jail_config import (
|
||||
_BLOCKLIST_IMPORT_CONF,
|
||||
_BLOCKLIST_IMPORT_LOCAL,
|
||||
_MANUAL_JAIL_CONF,
|
||||
_MANUAL_JAIL_LOCAL,
|
||||
ensure_jail_configs,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Expected filenames
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_MANUAL_CONF = "manual-Jail.conf"
|
||||
_MANUAL_LOCAL = "manual-Jail.local"
|
||||
_BLOCKLIST_CONF = "blocklist-import.conf"
|
||||
_BLOCKLIST_LOCAL = "blocklist-import.local"
|
||||
|
||||
_ALL_FILES = [_MANUAL_CONF, _MANUAL_LOCAL, _BLOCKLIST_CONF, _BLOCKLIST_LOCAL]
|
||||
|
||||
_CONTENT_MAP: dict[str, str] = {
|
||||
_MANUAL_CONF: _MANUAL_JAIL_CONF,
|
||||
_MANUAL_LOCAL: _MANUAL_JAIL_LOCAL,
|
||||
_BLOCKLIST_CONF: _BLOCKLIST_IMPORT_CONF,
|
||||
_BLOCKLIST_LOCAL: _BLOCKLIST_IMPORT_LOCAL,
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _read(jail_d: Path, filename: str) -> str:
|
||||
return (jail_d / filename).read_text(encoding="utf-8")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: ensure_jail_configs
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestEnsureJailConfigs:
|
||||
def test_all_missing_creates_all_four(self, tmp_path: Path) -> None:
|
||||
"""All four files are created when the directory is empty."""
|
||||
jail_d = tmp_path / "jail.d"
|
||||
ensure_jail_configs(jail_d)
|
||||
|
||||
for name in _ALL_FILES:
|
||||
assert (jail_d / name).exists(), f"{name} should have been created"
|
||||
assert _read(jail_d, name) == _CONTENT_MAP[name]
|
||||
|
||||
def test_all_missing_creates_correct_content(self, tmp_path: Path) -> None:
|
||||
"""Each created file has exactly the expected default content."""
|
||||
jail_d = tmp_path / "jail.d"
|
||||
ensure_jail_configs(jail_d)
|
||||
|
||||
# .conf files must set enabled = false
|
||||
for conf_file in (_MANUAL_CONF, _BLOCKLIST_CONF):
|
||||
content = _read(jail_d, conf_file)
|
||||
assert "enabled = false" in content
|
||||
|
||||
# Blocklist-import jail must have a 24-hour ban time
|
||||
blocklist_conf = _read(jail_d, _BLOCKLIST_CONF)
|
||||
assert "bantime = 86400" in blocklist_conf
|
||||
|
||||
# .local files must set enabled = true and nothing else
|
||||
for local_file in (_MANUAL_LOCAL, _BLOCKLIST_LOCAL):
|
||||
content = _read(jail_d, local_file)
|
||||
assert "enabled = true" in content
|
||||
|
||||
def test_all_present_overwrites_nothing(self, tmp_path: Path) -> None:
|
||||
"""Existing files are never overwritten."""
|
||||
jail_d = tmp_path / "jail.d"
|
||||
jail_d.mkdir()
|
||||
|
||||
sentinel = "# EXISTING CONTENT — must not be replaced\n"
|
||||
for name in _ALL_FILES:
|
||||
(jail_d / name).write_text(sentinel, encoding="utf-8")
|
||||
|
||||
ensure_jail_configs(jail_d)
|
||||
|
||||
for name in _ALL_FILES:
|
||||
assert _read(jail_d, name) == sentinel, (
|
||||
f"{name} should not have been overwritten"
|
||||
)
|
||||
|
||||
def test_only_local_files_missing_creates_only_locals(
|
||||
self, tmp_path: Path
|
||||
) -> None:
|
||||
"""Only the .local files are created when the .conf files already exist."""
|
||||
jail_d = tmp_path / "jail.d"
|
||||
jail_d.mkdir()
|
||||
|
||||
sentinel = "# pre-existing conf\n"
|
||||
for conf_file in (_MANUAL_CONF, _BLOCKLIST_CONF):
|
||||
(jail_d / conf_file).write_text(sentinel, encoding="utf-8")
|
||||
|
||||
ensure_jail_configs(jail_d)
|
||||
|
||||
# .conf files must remain unchanged
|
||||
for conf_file in (_MANUAL_CONF, _BLOCKLIST_CONF):
|
||||
assert _read(jail_d, conf_file) == sentinel
|
||||
|
||||
# .local files must have been created with correct content
|
||||
for local_file, expected in (
|
||||
(_MANUAL_LOCAL, _MANUAL_JAIL_LOCAL),
|
||||
(_BLOCKLIST_LOCAL, _BLOCKLIST_IMPORT_LOCAL),
|
||||
):
|
||||
assert (jail_d / local_file).exists(), f"{local_file} should have been created"
|
||||
assert _read(jail_d, local_file) == expected
|
||||
|
||||
def test_creates_jail_d_directory_if_missing(self, tmp_path: Path) -> None:
|
||||
"""The jail.d directory is created automatically when absent."""
|
||||
jail_d = tmp_path / "nested" / "jail.d"
|
||||
assert not jail_d.exists()
|
||||
ensure_jail_configs(jail_d)
|
||||
assert jail_d.is_dir()
|
||||
|
||||
def test_idempotent_on_repeated_calls(self, tmp_path: Path) -> None:
|
||||
"""Calling ensure_jail_configs twice does not alter any file."""
|
||||
jail_d = tmp_path / "jail.d"
|
||||
ensure_jail_configs(jail_d)
|
||||
|
||||
# Record content after first call
|
||||
first_pass = {name: _read(jail_d, name) for name in _ALL_FILES}
|
||||
|
||||
ensure_jail_configs(jail_d)
|
||||
|
||||
for name in _ALL_FILES:
|
||||
assert _read(jail_d, name) == first_pass[name], (
|
||||
f"{name} changed on second call"
|
||||
)
|
||||
15
backend/tests/test_version.py
Normal file
15
backend/tests/test_version.py
Normal file
@@ -0,0 +1,15 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import app
|
||||
|
||||
|
||||
def test_app_version_matches_docker_version() -> None:
|
||||
"""The backend version should match the signed off Docker release version."""
|
||||
|
||||
repo_root = Path(__file__).resolve().parents[2]
|
||||
version_file = repo_root / "Docker" / "VERSION"
|
||||
expected = version_file.read_text(encoding="utf-8").strip().lstrip("v")
|
||||
|
||||
assert app.__version__ == expected
|
||||
4
frontend/package-lock.json
generated
4
frontend/package-lock.json
generated
@@ -1,12 +1,12 @@
|
||||
{
|
||||
"name": "bangui-frontend",
|
||||
"version": "0.1.0",
|
||||
"version": "0.9.10",
|
||||
"lockfileVersion": 3,
|
||||
"requires": true,
|
||||
"packages": {
|
||||
"": {
|
||||
"name": "bangui-frontend",
|
||||
"version": "0.1.0",
|
||||
"version": "0.9.10",
|
||||
"dependencies": {
|
||||
"@fluentui/react-components": "^9.55.0",
|
||||
"@fluentui/react-icons": "^2.0.257",
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
{
|
||||
"name": "bangui-frontend",
|
||||
"private": true,
|
||||
"version": "0.1.0",
|
||||
"version": "0.9.12",
|
||||
"description": "BanGUI frontend — fail2ban web management interface",
|
||||
"type": "module",
|
||||
"scripts": {
|
||||
|
||||
@@ -26,6 +26,7 @@ import { AuthProvider } from "./providers/AuthProvider";
|
||||
import { TimezoneProvider } from "./providers/TimezoneProvider";
|
||||
import { RequireAuth } from "./components/RequireAuth";
|
||||
import { SetupGuard } from "./components/SetupGuard";
|
||||
import { ErrorBoundary } from "./components/ErrorBoundary";
|
||||
import { MainLayout } from "./layouts/MainLayout";
|
||||
import { SetupPage } from "./pages/SetupPage";
|
||||
import { LoginPage } from "./pages/LoginPage";
|
||||
@@ -43,9 +44,10 @@ import { BlocklistsPage } from "./pages/BlocklistsPage";
|
||||
function App(): React.JSX.Element {
|
||||
return (
|
||||
<FluentProvider theme={lightTheme}>
|
||||
<BrowserRouter future={{ v7_startTransition: true, v7_relativeSplatPath: true }}>
|
||||
<AuthProvider>
|
||||
<Routes>
|
||||
<ErrorBoundary>
|
||||
<BrowserRouter future={{ v7_startTransition: true, v7_relativeSplatPath: true }}>
|
||||
<AuthProvider>
|
||||
<Routes>
|
||||
{/* Setup wizard — always accessible; redirects to /login if already done */}
|
||||
<Route path="/setup" element={<SetupPage />} />
|
||||
|
||||
@@ -85,6 +87,7 @@ function App(): React.JSX.Element {
|
||||
</Routes>
|
||||
</AuthProvider>
|
||||
</BrowserRouter>
|
||||
</ErrorBoundary>
|
||||
</FluentProvider>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -7,22 +7,16 @@
|
||||
|
||||
import { api } from "./client";
|
||||
import { ENDPOINTS } from "./endpoints";
|
||||
import type { LoginRequest, LoginResponse, LogoutResponse } from "../types/auth";
|
||||
import { sha256Hex } from "../utils/crypto";
|
||||
import type { LoginResponse, LogoutResponse } from "../types/auth";
|
||||
|
||||
/**
|
||||
* Authenticate with the master password.
|
||||
*
|
||||
* The password is SHA-256 hashed client-side before transmission so that
|
||||
* the plaintext never leaves the browser. The backend bcrypt-verifies the
|
||||
* received hash against the stored bcrypt(sha256) digest.
|
||||
*
|
||||
* @param password - The master password entered by the user.
|
||||
* @returns The login response containing the session token.
|
||||
*/
|
||||
export async function login(password: string): Promise<LoginResponse> {
|
||||
const body: LoginRequest = { password: await sha256Hex(password) };
|
||||
return api.post<LoginResponse>(ENDPOINTS.authLogin, body);
|
||||
return api.post<LoginResponse>(ENDPOINTS.authLogin, { password });
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -39,10 +39,8 @@ import type {
|
||||
LogPreviewResponse,
|
||||
MapColorThresholdsResponse,
|
||||
MapColorThresholdsUpdate,
|
||||
PendingRecovery,
|
||||
RegexTestRequest,
|
||||
RegexTestResponse,
|
||||
RollbackResponse,
|
||||
ServerSettingsResponse,
|
||||
ServerSettingsUpdate,
|
||||
JailFileConfig,
|
||||
@@ -265,14 +263,14 @@ export async function fetchActionFiles(): Promise<ConfFilesResponse> {
|
||||
}
|
||||
|
||||
export async function fetchActionFile(name: string): Promise<ConfFileContent> {
|
||||
return get<ConfFileContent>(ENDPOINTS.configAction(name));
|
||||
return get<ConfFileContent>(ENDPOINTS.configActionRaw(name));
|
||||
}
|
||||
|
||||
export async function updateActionFile(
|
||||
name: string,
|
||||
req: ConfFileUpdateRequest
|
||||
): Promise<void> {
|
||||
await put<undefined>(ENDPOINTS.configAction(name), req);
|
||||
await put<undefined>(ENDPOINTS.configActionRaw(name), req);
|
||||
}
|
||||
|
||||
export async function createActionFile(
|
||||
@@ -552,6 +550,18 @@ export async function deactivateJail(
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Delete the ``jail.d/{name}.local`` override file for an inactive jail.
|
||||
*
|
||||
* Only valid when the jail is **not** currently active. Use this to clean up
|
||||
* leftover ``.local`` files after a jail has been fully deactivated.
|
||||
*
|
||||
* @param name - The jail name.
|
||||
*/
|
||||
export async function deleteJailLocalOverride(name: string): Promise<void> {
|
||||
await del<undefined>(ENDPOINTS.configJailLocalOverride(name));
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// fail2ban log viewer (Task 2)
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -593,21 +603,3 @@ export async function validateJailConfig(
|
||||
): Promise<JailValidationResult> {
|
||||
return post<JailValidationResult>(ENDPOINTS.configJailValidate(name), undefined);
|
||||
}
|
||||
|
||||
/**
|
||||
* Fetch the pending crash-recovery record, if any.
|
||||
*
|
||||
* Returns null when fail2ban is healthy and no recovery is pending.
|
||||
*/
|
||||
export async function fetchPendingRecovery(): Promise<PendingRecovery | null> {
|
||||
return get<PendingRecovery | null>(ENDPOINTS.configPendingRecovery);
|
||||
}
|
||||
|
||||
/**
|
||||
* Rollback a bad jail — disables it and attempts to restart fail2ban.
|
||||
*
|
||||
* @param name - Name of the jail to disable.
|
||||
*/
|
||||
export async function rollbackJail(name: string): Promise<RollbackResponse> {
|
||||
return post<RollbackResponse>(ENDPOINTS.configJailRollback(name), undefined);
|
||||
}
|
||||
|
||||
@@ -71,11 +71,10 @@ export const ENDPOINTS = {
|
||||
`/config/jails/${encodeURIComponent(name)}/activate`,
|
||||
configJailDeactivate: (name: string): string =>
|
||||
`/config/jails/${encodeURIComponent(name)}/deactivate`,
|
||||
configJailLocalOverride: (name: string): string =>
|
||||
`/config/jails/${encodeURIComponent(name)}/local`,
|
||||
configJailValidate: (name: string): string =>
|
||||
`/config/jails/${encodeURIComponent(name)}/validate`,
|
||||
configJailRollback: (name: string): string =>
|
||||
`/config/jails/${encodeURIComponent(name)}/rollback`,
|
||||
configPendingRecovery: "/config/pending-recovery" as string,
|
||||
configGlobal: "/config/global",
|
||||
configReload: "/config/reload",
|
||||
configRestart: "/config/restart",
|
||||
@@ -105,6 +104,7 @@ export const ENDPOINTS = {
|
||||
`/config/jails/${encodeURIComponent(jailName)}/action/${encodeURIComponent(actionName)}`,
|
||||
configActions: "/config/actions",
|
||||
configAction: (name: string): string => `/config/actions/${encodeURIComponent(name)}`,
|
||||
configActionRaw: (name: string): string => `/config/actions/${encodeURIComponent(name)}/raw`,
|
||||
configActionParsed: (name: string): string =>
|
||||
`/config/actions/${encodeURIComponent(name)}/parsed`,
|
||||
|
||||
|
||||
@@ -18,6 +18,7 @@ export async function fetchHistory(
|
||||
): Promise<HistoryListResponse> {
|
||||
const params = new URLSearchParams();
|
||||
if (query.range) params.set("range", query.range);
|
||||
if (query.origin) params.set("origin", query.origin);
|
||||
if (query.jail) params.set("jail", query.jail);
|
||||
if (query.ip) params.set("ip", query.ip);
|
||||
if (query.page !== undefined) params.set("page", String(query.page));
|
||||
|
||||
@@ -27,6 +27,7 @@ import {
|
||||
import { PageEmpty, PageError, PageLoading } from "./PageFeedback";
|
||||
import { ChevronLeftRegular, ChevronRightRegular } from "@fluentui/react-icons";
|
||||
import { useBans } from "../hooks/useBans";
|
||||
import { formatTimestamp } from "../utils/formatDate";
|
||||
import type { DashboardBanItem, TimeRange, BanOriginFilter } from "../types/ban";
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -90,31 +91,6 @@ const useStyles = makeStyles({
|
||||
},
|
||||
});
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/**
|
||||
* Format an ISO 8601 timestamp for display.
|
||||
*
|
||||
* @param iso - ISO 8601 UTC string.
|
||||
* @returns Localised date+time string.
|
||||
*/
|
||||
function formatTimestamp(iso: string): string {
|
||||
try {
|
||||
return new Date(iso).toLocaleString(undefined, {
|
||||
year: "numeric",
|
||||
month: "2-digit",
|
||||
day: "2-digit",
|
||||
hour: "2-digit",
|
||||
minute: "2-digit",
|
||||
second: "2-digit",
|
||||
});
|
||||
} catch {
|
||||
return iso;
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Column definitions
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
@@ -14,6 +14,7 @@ import {
|
||||
makeStyles,
|
||||
tokens,
|
||||
} from "@fluentui/react-components";
|
||||
import { useCardStyles } from "../theme/commonStyles";
|
||||
import type { BanOriginFilter, TimeRange } from "../types/ban";
|
||||
import {
|
||||
BAN_ORIGIN_FILTER_LABELS,
|
||||
@@ -57,20 +58,6 @@ const useStyles = makeStyles({
|
||||
alignItems: "center",
|
||||
flexWrap: "wrap",
|
||||
gap: tokens.spacingVerticalS,
|
||||
backgroundColor: tokens.colorNeutralBackground1,
|
||||
borderRadius: tokens.borderRadiusMedium,
|
||||
borderTopWidth: "1px",
|
||||
borderTopStyle: "solid",
|
||||
borderTopColor: tokens.colorNeutralStroke2,
|
||||
borderRightWidth: "1px",
|
||||
borderRightStyle: "solid",
|
||||
borderRightColor: tokens.colorNeutralStroke2,
|
||||
borderBottomWidth: "1px",
|
||||
borderBottomStyle: "solid",
|
||||
borderBottomColor: tokens.colorNeutralStroke2,
|
||||
borderLeftWidth: "1px",
|
||||
borderLeftStyle: "solid",
|
||||
borderLeftColor: tokens.colorNeutralStroke2,
|
||||
paddingTop: tokens.spacingVerticalS,
|
||||
paddingBottom: tokens.spacingVerticalS,
|
||||
paddingLeft: tokens.spacingHorizontalM,
|
||||
@@ -107,9 +94,10 @@ export function DashboardFilterBar({
|
||||
onOriginFilterChange,
|
||||
}: DashboardFilterBarProps): React.JSX.Element {
|
||||
const styles = useStyles();
|
||||
const cardStyles = useCardStyles();
|
||||
|
||||
return (
|
||||
<div className={styles.container}>
|
||||
<div className={`${styles.container} ${cardStyles.card}`}>
|
||||
{/* Time-range group */}
|
||||
<div className={styles.group}>
|
||||
<Text weight="semibold" size={300}>
|
||||
|
||||
62
frontend/src/components/ErrorBoundary.tsx
Normal file
62
frontend/src/components/ErrorBoundary.tsx
Normal file
@@ -0,0 +1,62 @@
|
||||
/**
|
||||
* React error boundary component.
|
||||
*
|
||||
* Catches render-time exceptions in child components and shows a fallback UI.
|
||||
*/
|
||||
import React from "react";
|
||||
|
||||
interface ErrorBoundaryState {
|
||||
hasError: boolean;
|
||||
errorMessage: string | null;
|
||||
}
|
||||
|
||||
interface ErrorBoundaryProps {
|
||||
children: React.ReactNode;
|
||||
}
|
||||
|
||||
export class ErrorBoundary extends React.Component<ErrorBoundaryProps, ErrorBoundaryState> {
|
||||
constructor(props: ErrorBoundaryProps) {
|
||||
super(props);
|
||||
this.state = { hasError: false, errorMessage: null };
|
||||
this.handleReload = this.handleReload.bind(this);
|
||||
}
|
||||
|
||||
static getDerivedStateFromError(error: Error): ErrorBoundaryState {
|
||||
return { hasError: true, errorMessage: error.message || "Unknown error" };
|
||||
}
|
||||
|
||||
componentDidCatch(error: Error, errorInfo: React.ErrorInfo): void {
|
||||
console.error("ErrorBoundary caught an error", { error, errorInfo });
|
||||
}
|
||||
|
||||
handleReload(): void {
|
||||
window.location.reload();
|
||||
}
|
||||
|
||||
render(): React.ReactNode {
|
||||
if (this.state.hasError) {
|
||||
return (
|
||||
<div
|
||||
style={{
|
||||
display: "flex",
|
||||
flexDirection: "column",
|
||||
alignItems: "center",
|
||||
justifyContent: "center",
|
||||
minHeight: "100vh",
|
||||
padding: "24px",
|
||||
textAlign: "center",
|
||||
}}
|
||||
role="alert"
|
||||
>
|
||||
<h1>Something went wrong</h1>
|
||||
<p>{this.state.errorMessage ?? "Please try reloading the page."}</p>
|
||||
<button type="button" onClick={this.handleReload} style={{ marginTop: "16px" }}>
|
||||
Reload
|
||||
</button>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
return this.props.children;
|
||||
}
|
||||
}
|
||||
@@ -18,6 +18,7 @@ import {
|
||||
tokens,
|
||||
Tooltip,
|
||||
} from "@fluentui/react-components";
|
||||
import { useCardStyles } from "../theme/commonStyles";
|
||||
import { ArrowClockwiseRegular, ShieldRegular } from "@fluentui/react-icons";
|
||||
import { useServerStatus } from "../hooks/useServerStatus";
|
||||
|
||||
@@ -31,20 +32,6 @@ const useStyles = makeStyles({
|
||||
alignItems: "center",
|
||||
gap: tokens.spacingHorizontalL,
|
||||
padding: `${tokens.spacingVerticalS} ${tokens.spacingHorizontalL}`,
|
||||
backgroundColor: tokens.colorNeutralBackground1,
|
||||
borderRadius: tokens.borderRadiusMedium,
|
||||
borderTopWidth: "1px",
|
||||
borderTopStyle: "solid",
|
||||
borderTopColor: tokens.colorNeutralStroke2,
|
||||
borderRightWidth: "1px",
|
||||
borderRightStyle: "solid",
|
||||
borderRightColor: tokens.colorNeutralStroke2,
|
||||
borderBottomWidth: "1px",
|
||||
borderBottomStyle: "solid",
|
||||
borderBottomColor: tokens.colorNeutralStroke2,
|
||||
borderLeftWidth: "1px",
|
||||
borderLeftStyle: "solid",
|
||||
borderLeftColor: tokens.colorNeutralStroke2,
|
||||
marginBottom: tokens.spacingVerticalL,
|
||||
flexWrap: "wrap",
|
||||
},
|
||||
@@ -85,8 +72,10 @@ export function ServerStatusBar(): React.JSX.Element {
|
||||
const styles = useStyles();
|
||||
const { status, loading, error, refresh } = useServerStatus();
|
||||
|
||||
const cardStyles = useCardStyles();
|
||||
|
||||
return (
|
||||
<div className={styles.bar} role="status" aria-label="fail2ban server status">
|
||||
<div className={`${cardStyles.card} ${styles.bar}`} role="status" aria-label="fail2ban server status">
|
||||
{/* ---------------------------------------------------------------- */}
|
||||
{/* Online / Offline badge */}
|
||||
{/* ---------------------------------------------------------------- */}
|
||||
@@ -109,7 +98,7 @@ export function ServerStatusBar(): React.JSX.Element {
|
||||
{/* Version */}
|
||||
{/* ---------------------------------------------------------------- */}
|
||||
{status?.version != null && (
|
||||
<Tooltip content="fail2ban version" relationship="description">
|
||||
<Tooltip content="BanGUI version" relationship="description">
|
||||
<Text size={200} className={styles.statValue}>
|
||||
v{status.version}
|
||||
</Text>
|
||||
@@ -139,9 +128,9 @@ export function ServerStatusBar(): React.JSX.Element {
|
||||
</div>
|
||||
</Tooltip>
|
||||
|
||||
<Tooltip content="Currently failing IPs" relationship="description">
|
||||
<Tooltip content="Total failed authentication attempts currently tracked by fail2ban across all active jails" relationship="description">
|
||||
<div className={styles.statGroup}>
|
||||
<Text size={200}>Failures:</Text>
|
||||
<Text size={200}>Failed Attempts:</Text>
|
||||
<Text size={200} className={styles.statValue}>
|
||||
{status.total_failures}
|
||||
</Text>
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user