Compare commits
122 Commits
5e1b8134d9
...
fix/worldm
| Author | SHA1 | Date | |
|---|---|---|---|
| c03a5c1cbc | |||
| eb983799cd | |||
| d3f564d66f | |||
| bbd57c808b | |||
| ffaa14f864 | |||
| 7d09b78437 | |||
| 8e2bb5d3fb | |||
| bfe0daf754 | |||
| 13823b1182 | |||
| 7967191ccd | |||
| 470c29443c | |||
| 6f15e1fa24 | |||
| 487cb171f2 | |||
| 7789353690 | |||
| ccfcbc82c5 | |||
| 7626c9cb60 | |||
| ac4fd967aa | |||
| 9f05da2d4d | |||
| 876af46955 | |||
| 0d4a2a3311 | |||
| f555b1b0a2 | |||
| a30b92471a | |||
| 9e43282bbc | |||
| 2ea4a8304f | |||
| e99920e616 | |||
| 670ff3e8a2 | |||
| f6672d0d16 | |||
| d909f93efc | |||
| 965cdd765b | |||
| 0663740b08 | |||
| 29587f2353 | |||
| 798ed08ddd | |||
| ed184f1c84 | |||
| 8e1b4fa978 | |||
| e604e3aadf | |||
| cf721513e8 | |||
| a32cc82851 | |||
| 26af69e2a3 | |||
| 00e702a2c0 | |||
| ee73373111 | |||
| a1f97bd78f | |||
| 99fbddb0e7 | |||
| b15629a078 | |||
| 136f21ecbe | |||
| bf2abda595 | |||
| 335f89c554 | |||
| 05dc9fa1e3 | |||
| 471eed9664 | |||
| 1f272dc348 | |||
| f9cec2a975 | |||
| cc235b95c6 | |||
| 29415da421 | |||
| 8a6bcc4d94 | |||
| a442836c5c | |||
| 3aba2b6446 | |||
| 28a7610276 | |||
| d30d138146 | |||
| 8c4fe767de | |||
| 52b0936200 | |||
| 1c0bac1353 | |||
| bdcdd5d672 | |||
| 482399c9e2 | |||
| ce59a66973 | |||
| dfbe126368 | |||
| c9e688cc52 | |||
| 1ce5da9e23 | |||
| 93f0feabde | |||
| 376c13370d | |||
| fb6d0e588f | |||
| e44caccb3c | |||
| 15e4a5434e | |||
| 1cc9968d31 | |||
| 80a6bac33e | |||
| 133ab2e82c | |||
| 60f2f35b25 | |||
| 59da34dc3b | |||
| 90f54cf39c | |||
| 93d26e3c60 | |||
| 954dcf7ea6 | |||
| bf8144916a | |||
| 481daa4e1a | |||
| 889976c7ee | |||
| d3d2cb0915 | |||
| bf82e38b6e | |||
| e98fd1de93 | |||
| 8f515893ea | |||
| 81f99d0b50 | |||
| 030bca09b7 | |||
| 5b7d1a4360 | |||
| e7834a888e | |||
| abb224e01b | |||
| 57cf93b1e5 | |||
| c41165c294 | |||
| cdf73e2d65 | |||
| 21753c4f06 | |||
| eb859af371 | |||
| 5a5c619a34 | |||
| 00119ed68d | |||
| b81e0cdbb4 | |||
| 41dcd60225 | |||
| 12f04bd8d6 | |||
| d4d04491d2 | |||
| 93dc699825 | |||
| 61daa8bbc0 | |||
| 57a0bbe36e | |||
| f62785aaf2 | |||
| 1e33220f59 | |||
| 1da38361a9 | |||
| 9630aea877 | |||
| 037c18eb00 | |||
| 2e1a4b3b2b | |||
| 4be2469f92 | |||
| 6bb38dbd8c | |||
| d3b2022ffb | |||
| 4b6e118a88 | |||
| 936946010f | |||
| ee7412442a | |||
| 68d8056d2e | |||
| 528d0bd8ea | |||
| baf45c6c62 | |||
| 0966f347c4 | |||
| ab11ece001 |
1
.gitignore
vendored
1
.gitignore
vendored
@@ -105,6 +105,7 @@ Docker/fail2ban-dev-config/**
|
||||
!Docker/fail2ban-dev-config/fail2ban/jail.d/bangui-sim.conf
|
||||
!Docker/fail2ban-dev-config/fail2ban/jail.d/bangui-access.conf
|
||||
!Docker/fail2ban-dev-config/fail2ban/jail.d/blocklist-import.conf
|
||||
!Docker/fail2ban-dev-config/fail2ban/jail.local
|
||||
|
||||
# ── Misc ──────────────────────────────────────
|
||||
*.log
|
||||
|
||||
@@ -10,7 +10,7 @@
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
|
||||
# ── Stage 1: build dependencies ──────────────────────────────
|
||||
FROM python:3.12-slim AS builder
|
||||
FROM docker.io/library/python:3.12-slim AS builder
|
||||
|
||||
WORKDIR /build
|
||||
|
||||
@@ -28,7 +28,7 @@ RUN pip install --no-cache-dir --upgrade pip \
|
||||
&& pip install --no-cache-dir .
|
||||
|
||||
# ── Stage 2: runtime image ───────────────────────────────────
|
||||
FROM python:3.12-slim AS runtime
|
||||
FROM docker.io/library/python:3.12-slim AS runtime
|
||||
|
||||
LABEL maintainer="BanGUI" \
|
||||
description="BanGUI backend — fail2ban web management API"
|
||||
|
||||
@@ -10,7 +10,7 @@
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
|
||||
# ── Stage 1: install & build ─────────────────────────────────
|
||||
FROM node:22-alpine AS builder
|
||||
FROM docker.io/library/node:22-alpine AS builder
|
||||
|
||||
WORKDIR /build
|
||||
|
||||
@@ -23,7 +23,7 @@ COPY frontend/ /build/
|
||||
RUN npm run build
|
||||
|
||||
# ── Stage 2: serve with nginx ────────────────────────────────
|
||||
FROM nginx:1.27-alpine AS runtime
|
||||
FROM docker.io/library/nginx:1.27-alpine AS runtime
|
||||
|
||||
LABEL maintainer="BanGUI" \
|
||||
description="BanGUI frontend — fail2ban web management UI"
|
||||
|
||||
1
Docker/VERSION
Normal file
1
Docker/VERSION
Normal file
@@ -0,0 +1 @@
|
||||
v0.9.18
|
||||
@@ -2,7 +2,7 @@
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
# check_ban_status.sh
|
||||
#
|
||||
# Queries the bangui-sim jail inside the running fail2ban
|
||||
# Queries the manual-Jail jail inside the running fail2ban
|
||||
# container and optionally unbans a specific IP.
|
||||
#
|
||||
# Usage:
|
||||
@@ -17,7 +17,7 @@
|
||||
set -euo pipefail
|
||||
|
||||
readonly CONTAINER="bangui-fail2ban-dev"
|
||||
readonly JAIL="bangui-sim"
|
||||
readonly JAIL="manual-Jail"
|
||||
|
||||
# ── Helper: run a fail2ban-client command inside the container ─
|
||||
f2b() {
|
||||
|
||||
@@ -37,6 +37,11 @@ services:
|
||||
timeout: 5s
|
||||
start_period: 15s
|
||||
retries: 3
|
||||
# NOTE: The fail2ban-config volume must be pre-populated with the following files:
|
||||
# • fail2ban/jail.conf (or jail.d/*.conf) with the DEFAULT section containing:
|
||||
# banaction = iptables-allports[lockingopt="-w 5"]
|
||||
# This prevents xtables lock contention errors when multiple jails start in parallel.
|
||||
# See https://fail2ban.readthedocs.io/en/latest/development/environment.html
|
||||
|
||||
# ── Backend (FastAPI + uvicorn) ─────────────────────────────
|
||||
backend:
|
||||
|
||||
73
Docker/docker-compose.yml
Normal file
73
Docker/docker-compose.yml
Normal file
@@ -0,0 +1,73 @@
|
||||
version: '3.8'
|
||||
services:
|
||||
fail2ban:
|
||||
image: lscr.io/linuxserver/fail2ban:latest
|
||||
container_name: fail2ban
|
||||
cap_add:
|
||||
- NET_ADMIN
|
||||
- NET_RAW
|
||||
network_mode: host
|
||||
environment:
|
||||
- PUID=1011
|
||||
- PGID=1001
|
||||
- TZ=Etc/UTC
|
||||
- VERBOSITY=-vv #optional
|
||||
|
||||
volumes:
|
||||
- /server/server_fail2ban/config:/config
|
||||
- /server/server_fail2ban/fail2ban-run:/var/run/fail2ban
|
||||
- /var/log:/var/log
|
||||
- /server/server_nextcloud/config/nextcloud.log:/remotelogs/nextcloud/nextcloud.log:ro #optional
|
||||
- /server/server_nginx/data/logs:/remotelogs/nginx:ro #optional
|
||||
- /server/server_gitea/log/gitea.log:/remotelogs/gitea/gitea.log:ro #optional
|
||||
|
||||
|
||||
#- /path/to/homeassistant/log:/remotelogs/homeassistant:ro #optional
|
||||
#- /path/to/unificontroller/log:/remotelogs/unificontroller:ro #optional
|
||||
#- /path/to/vaultwarden/log:/remotelogs/vaultwarden:ro #optional
|
||||
restart: unless-stopped
|
||||
|
||||
backend:
|
||||
image: git.lpl-mind.de/lukas.pupkalipinski/bangui/backend:latest
|
||||
container_name: bangui-backend
|
||||
restart: unless-stopped
|
||||
depends_on:
|
||||
fail2ban:
|
||||
condition: service_started
|
||||
environment:
|
||||
- PUID=1011
|
||||
- PGID=1001
|
||||
- BANGUI_DATABASE_PATH=/data/bangui.db
|
||||
- BANGUI_FAIL2BAN_SOCKET=/var/run/fail2ban/fail2ban.sock
|
||||
- BANGUI_FAIL2BAN_CONFIG_DIR=/config/fail2ban
|
||||
- BANGUI_LOG_LEVEL=info
|
||||
- BANGUI_SESSION_SECRET=${BANGUI_SESSION_SECRET:?Set BANGUI_SESSION_SECRET}
|
||||
- BANGUI_TIMEZONE=${BANGUI_TIMEZONE:-UTC}
|
||||
volumes:
|
||||
- /server/server_fail2ban/bangui-data:/data
|
||||
- /server/server_fail2ban/fail2ban-run:/var/run/fail2ban:ro
|
||||
- /server/server_fail2ban/config:/config:rw
|
||||
expose:
|
||||
- "8000"
|
||||
networks:
|
||||
- bangui-net
|
||||
|
||||
# ── Frontend (nginx serving built SPA + API proxy) ──────────
|
||||
frontend:
|
||||
image: git.lpl-mind.de/lukas.pupkalipinski/bangui/frontend:latest
|
||||
container_name: bangui-frontend
|
||||
restart: unless-stopped
|
||||
environment:
|
||||
- PUID=1011
|
||||
- PGID=1001
|
||||
ports:
|
||||
- "${BANGUI_PORT:-8080}:80"
|
||||
depends_on:
|
||||
backend:
|
||||
condition: service_started
|
||||
networks:
|
||||
- bangui-net
|
||||
|
||||
networks:
|
||||
bangui-net:
|
||||
name: bangui-net
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
This directory contains the fail2ban configuration and supporting scripts for a
|
||||
self-contained development test environment. A simulation script writes fake
|
||||
authentication-failure log lines, fail2ban detects them via the `bangui-sim`
|
||||
authentication-failure log lines, fail2ban detects them via the `manual-Jail`
|
||||
jail, and bans the offending IP — giving a fully reproducible ban/unban cycle
|
||||
without a real service.
|
||||
|
||||
@@ -71,14 +71,19 @@ Chains steps 1–3 automatically with appropriate sleep intervals.
|
||||
|
||||
| File | Purpose |
|
||||
|------|---------|
|
||||
| `fail2ban/filter.d/bangui-sim.conf` | Defines the `failregex` that matches simulation log lines |
|
||||
| `fail2ban/jail.d/bangui-sim.conf` | Jail settings: `maxretry=3`, `bantime=60s`, `findtime=120s` |
|
||||
| `fail2ban/filter.d/manual-Jail.conf` | Defines the `failregex` that matches simulation log lines |
|
||||
| `fail2ban/jail.d/manual-Jail.conf` | Jail settings: `maxretry=3`, `bantime=60s`, `findtime=120s` |
|
||||
| `Docker/logs/auth.log` | Log file written by the simulation script (host path) |
|
||||
|
||||
Inside the container the log file is mounted at `/remotelogs/bangui/auth.log`
|
||||
(see `fail2ban/paths-lsio.conf` — `remote_logs_path = /remotelogs`).
|
||||
|
||||
To change sensitivity, edit `fail2ban/jail.d/bangui-sim.conf`:
|
||||
BanGUI also extends fail2ban history retention for archive backfill. In
|
||||
the development config `fail2ban/fail2ban.conf` the database purge age is
|
||||
set to `648000` seconds (7.5 days) so the first archive sync can recover a
|
||||
full 7-day window before fail2ban purges old rows.
|
||||
|
||||
To change sensitivity, edit `fail2ban/jail.d/manual-Jail.conf`:
|
||||
|
||||
```ini
|
||||
maxretry = 3 # failures before a ban
|
||||
@@ -108,14 +113,14 @@ Test the regex manually:
|
||||
|
||||
```bash
|
||||
docker exec bangui-fail2ban-dev \
|
||||
fail2ban-regex /remotelogs/bangui/auth.log bangui-sim
|
||||
fail2ban-regex /remotelogs/bangui/auth.log manual-Jail
|
||||
```
|
||||
|
||||
The output should show matched lines. If nothing matches, check that the log
|
||||
lines match the corresponding `failregex` pattern:
|
||||
|
||||
```
|
||||
# bangui-sim (auth log):
|
||||
# manual-Jail (auth log):
|
||||
YYYY-MM-DD HH:MM:SS bangui-auth: authentication failure from <IP>
|
||||
```
|
||||
|
||||
@@ -132,7 +137,7 @@ sudo modprobe ip_tables
|
||||
### IP not banned despite enough failures
|
||||
|
||||
Check whether the source IP falls inside the `ignoreip` range defined in
|
||||
`fail2ban/jail.d/bangui-sim.conf`:
|
||||
`fail2ban/jail.d/manual-Jail.conf`:
|
||||
|
||||
```ini
|
||||
ignoreip = 127.0.0.0/8 ::1 172.16.0.0/12
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
#
|
||||
# Matches lines written by Docker/simulate_failed_logins.sh
|
||||
# Format: <timestamp> bangui-auth: authentication failure from <HOST>
|
||||
# Jail: manual-Jail
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
|
||||
[Definition]
|
||||
@@ -18,9 +18,8 @@ logpath = /dev/null
|
||||
backend = auto
|
||||
maxretry = 1
|
||||
findtime = 1d
|
||||
# Block imported IPs for one week.
|
||||
bantime = 1w
|
||||
banaction = iptables-allports
|
||||
# Block imported IPs for 24 hours.
|
||||
bantime = 86400
|
||||
|
||||
# Never ban the Docker bridge network or localhost.
|
||||
ignoreip = 127.0.0.0/8 ::1 172.16.0.0/12
|
||||
|
||||
@@ -5,16 +5,15 @@
|
||||
# for lines produced by Docker/simulate_failed_logins.sh.
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
|
||||
[bangui-sim]
|
||||
[manual-Jail]
|
||||
|
||||
enabled = true
|
||||
filter = bangui-sim
|
||||
filter = manual-Jail
|
||||
logpath = /remotelogs/bangui/auth.log
|
||||
backend = polling
|
||||
maxretry = 3
|
||||
findtime = 120
|
||||
bantime = 60
|
||||
banaction = iptables-allports
|
||||
|
||||
# Never ban localhost, the Docker bridge network, or the host machine.
|
||||
ignoreip = 127.0.0.0/8 ::1 172.16.0.0/12
|
||||
6
Docker/fail2ban-dev-config/fail2ban/jail.local
Normal file
6
Docker/fail2ban-dev-config/fail2ban/jail.local
Normal file
@@ -0,0 +1,6 @@
|
||||
# Local overrides — not overwritten by the container init script.
|
||||
# Provides banaction so all jails can resolve %(action_)s interpolation.
|
||||
|
||||
[DEFAULT]
|
||||
banaction = iptables-multiport
|
||||
banaction_allports = iptables-allports
|
||||
@@ -56,11 +56,8 @@ echo " Registry : ${REGISTRY}"
|
||||
echo " Tag : ${TAG}"
|
||||
echo "============================================"
|
||||
|
||||
if [[ "${ENGINE}" == "podman" ]]; then
|
||||
if ! podman login --get-login "${REGISTRY}" &>/dev/null; then
|
||||
err "Not logged in. Run:\n podman login ${REGISTRY}"
|
||||
fi
|
||||
fi
|
||||
log "Logging in to ${REGISTRY}"
|
||||
"${ENGINE}" login "${REGISTRY}"
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Build
|
||||
|
||||
101
Docker/release.sh
Normal file
101
Docker/release.sh
Normal file
@@ -0,0 +1,101 @@
|
||||
#!/usr/bin/env bash
|
||||
#
|
||||
# Bump the project version and push images to the registry.
|
||||
#
|
||||
# Usage:
|
||||
# ./release.sh
|
||||
#
|
||||
# The current version is stored in VERSION (next to this script).
|
||||
# You will be asked whether to bump major, minor, or patch.
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||
VERSION_FILE="${SCRIPT_DIR}/VERSION"
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Read current version
|
||||
# ---------------------------------------------------------------------------
|
||||
if [[ ! -f "${VERSION_FILE}" ]]; then
|
||||
echo "0.0.0" > "${VERSION_FILE}"
|
||||
fi
|
||||
|
||||
CURRENT="$(cat "${VERSION_FILE}")"
|
||||
# Strip leading 'v' for arithmetic
|
||||
VERSION="${CURRENT#v}"
|
||||
|
||||
IFS='.' read -r MAJOR MINOR PATCH <<< "${VERSION}"
|
||||
|
||||
echo "============================================"
|
||||
echo " BanGUI — Release"
|
||||
echo " Current version: v${MAJOR}.${MINOR}.${PATCH}"
|
||||
echo "============================================"
|
||||
echo ""
|
||||
echo "How would you like to bump the version?"
|
||||
echo " 1) patch (v${MAJOR}.${MINOR}.${PATCH} → v${MAJOR}.${MINOR}.$((PATCH + 1)))"
|
||||
echo " 2) minor (v${MAJOR}.${MINOR}.${PATCH} → v${MAJOR}.$((MINOR + 1)).0)"
|
||||
echo " 3) major (v${MAJOR}.${MINOR}.${PATCH} → v$((MAJOR + 1)).0.0)"
|
||||
echo ""
|
||||
read -rp "Enter choice [1/2/3]: " CHOICE
|
||||
|
||||
case "${CHOICE}" in
|
||||
1) NEW_TAG="v${MAJOR}.${MINOR}.$((PATCH + 1))" ;;
|
||||
2) NEW_TAG="v${MAJOR}.$((MINOR + 1)).0" ;;
|
||||
3) NEW_TAG="v$((MAJOR + 1)).0.0" ;;
|
||||
*)
|
||||
echo "Invalid choice. Aborting." >&2
|
||||
exit 1
|
||||
;;
|
||||
esac
|
||||
|
||||
echo ""
|
||||
echo "New version: ${NEW_TAG}"
|
||||
read -rp "Confirm? [y/N]: " CONFIRM
|
||||
if [[ ! "${CONFIRM}" =~ ^[yY]$ ]]; then
|
||||
echo "Aborted."
|
||||
exit 0
|
||||
fi
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Write new version
|
||||
# ---------------------------------------------------------------------------
|
||||
echo "${NEW_TAG}" > "${VERSION_FILE}"
|
||||
echo "Version file updated → ${VERSION_FILE}"
|
||||
|
||||
# Keep frontend/package.json in sync so __APP_VERSION__ matches Docker/VERSION.
|
||||
FRONT_VERSION="${NEW_TAG#v}"
|
||||
FRONT_PKG="${SCRIPT_DIR}/../frontend/package.json"
|
||||
sed -i "s/\"version\": \"[^\"]*\"/\"version\": \"${FRONT_VERSION}\"/" "${FRONT_PKG}"
|
||||
echo "frontend/package.json version updated → ${FRONT_VERSION}"
|
||||
|
||||
# Keep backend/pyproject.toml in sync so app.__version__ matches Docker/VERSION in the runtime container.
|
||||
BACKEND_PYPROJECT="${SCRIPT_DIR}/../backend/pyproject.toml"
|
||||
if [[ -f "${BACKEND_PYPROJECT}" ]]; then
|
||||
sed -i "s/^version = \".*\"/version = \"${FRONT_VERSION}\"/" "${BACKEND_PYPROJECT}"
|
||||
echo "backend/pyproject.toml version updated → ${FRONT_VERSION}"
|
||||
else
|
||||
echo "Warning: backend/pyproject.toml not found, skipping backend version sync" >&2
|
||||
fi
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Push containers
|
||||
# ---------------------------------------------------------------------------
|
||||
bash "${SCRIPT_DIR}/push.sh" "${NEW_TAG}"
|
||||
bash "${SCRIPT_DIR}/push.sh"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Git tag (local only; push after container build)
|
||||
# ---------------------------------------------------------------------------
|
||||
cd "${SCRIPT_DIR}/.."
|
||||
git add Docker/VERSION frontend/package.json
|
||||
git commit -m "chore: release ${NEW_TAG}"
|
||||
git tag -a "${NEW_TAG}" -m "Release ${NEW_TAG}"
|
||||
echo "Local git commit + tag ${NEW_TAG} created."
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Push git commits & tag
|
||||
# ---------------------------------------------------------------------------
|
||||
git push origin HEAD
|
||||
git push origin "${NEW_TAG}"
|
||||
echo "Git commit and tag ${NEW_TAG} pushed."
|
||||
@@ -3,7 +3,7 @@
|
||||
# simulate_failed_logins.sh
|
||||
#
|
||||
# Writes synthetic authentication-failure log lines to a file
|
||||
# that matches the bangui-sim fail2ban filter.
|
||||
# that matches the manual-Jail fail2ban filter.
|
||||
#
|
||||
# Usage:
|
||||
# bash Docker/simulate_failed_logins.sh [COUNT] [SOURCE_IP] [LOG_FILE]
|
||||
@@ -13,7 +13,7 @@
|
||||
# SOURCE_IP: 192.168.100.99
|
||||
# LOG_FILE : Docker/logs/auth.log (relative to repo root)
|
||||
#
|
||||
# Log line format (must match bangui-sim failregex exactly):
|
||||
# Log line format (must match manual-Jail failregex exactly):
|
||||
# YYYY-MM-DD HH:MM:SS bangui-auth: authentication failure from <IP>
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@@ -82,10 +82,12 @@ The backend follows a **layered architecture** with strict separation of concern
|
||||
backend/
|
||||
├── app/
|
||||
│ ├── __init__.py
|
||||
│ ├── main.py # FastAPI app factory, lifespan, exception handlers
|
||||
│ ├── config.py # Pydantic settings (env vars, .env loading)
|
||||
│ ├── dependencies.py # FastAPI Depends() providers (DB, services, auth)
|
||||
│ ├── models/ # Pydantic schemas
|
||||
│ ├── `main.py` # FastAPI app factory, lifespan, exception handlers
|
||||
│ ├── `config.py` # Pydantic settings (env vars, .env loading)
|
||||
│ ├── `db.py` # Database connection and initialization
|
||||
│ ├── `exceptions.py` # Shared domain exception classes
|
||||
│ ├── `dependencies.py` # FastAPI Depends() providers (DB, services, auth)
|
||||
│ ├── `models/` # Pydantic schemas
|
||||
│ │ ├── auth.py # Login request/response, session models
|
||||
│ │ ├── ban.py # Ban request/response/domain models
|
||||
│ │ ├── jail.py # Jail request/response/domain models
|
||||
@@ -111,6 +113,12 @@ backend/
|
||||
│ │ ├── jail_service.py # Jail listing, start/stop/reload, status aggregation
|
||||
│ │ ├── ban_service.py # Ban/unban execution, currently-banned queries
|
||||
│ │ ├── config_service.py # Read/write fail2ban config, regex validation
|
||||
│ │ ├── config_file_service.py # Shared config parsing and file-level operations
|
||||
│ │ ├── raw_config_io_service.py # Raw config file I/O wrapper
|
||||
│ │ ├── jail_config_service.py # jail config activation/deactivation logic
|
||||
│ │ ├── filter_config_service.py # filter config lifecycle management
|
||||
│ │ ├── action_config_service.py # action config lifecycle management
|
||||
│ │ ├── log_service.py # Log preview and regex test operations
|
||||
│ │ ├── history_service.py # Historical ban queries, per-IP timeline
|
||||
│ │ ├── blocklist_service.py # Download, validate, apply blocklists
|
||||
│ │ ├── geo_service.py # IP-to-country resolution, ASN/RIR lookup
|
||||
@@ -119,17 +127,18 @@ backend/
|
||||
│ ├── repositories/ # Data access layer (raw queries only)
|
||||
│ │ ├── settings_repo.py # App configuration CRUD in SQLite
|
||||
│ │ ├── session_repo.py # Session storage and lookup
|
||||
│ │ ├── blocklist_repo.py # Blocklist sources and import log persistence
|
||||
│ │ └── import_log_repo.py # Import run history records
|
||||
│ │ ├── blocklist_repo.py # Blocklist sources and import log persistence│ │ ├── fail2ban_db_repo.py # fail2ban SQLite ban history read operations
|
||||
│ │ ├── geo_cache_repo.py # IP geolocation cache persistence│ │ └── import_log_repo.py # Import run history records
|
||||
│ ├── tasks/ # APScheduler background jobs
|
||||
│ │ ├── blocklist_import.py# Scheduled blocklist download and application
|
||||
│ │ ├── geo_cache_flush.py # Periodic geo cache persistence (dirty-set flush to SQLite)
|
||||
│ │ └── health_check.py # Periodic fail2ban connectivity probe
|
||||
│ │ ├── geo_cache_flush.py # Periodic geo cache persistence (dirty-set flush to SQLite)│ │ ├── geo_re_resolve.py # Periodic re-resolution of stale geo cache records│ │ └── health_check.py # Periodic fail2ban connectivity probe
|
||||
│ └── utils/ # Helpers, constants, shared types
|
||||
│ ├── fail2ban_client.py # Async wrapper around the fail2ban socket protocol
|
||||
│ ├── ip_utils.py # IP/CIDR validation and normalisation
|
||||
│ ├── time_utils.py # Timezone-aware datetime helpers
|
||||
│ └── constants.py # Shared constants (default paths, limits, etc.)
|
||||
│ ├── time_utils.py # Timezone-aware datetime helpers│ ├── jail_config.py # Jail config parser/serializer helper
|
||||
│ ├── conffile_parser.py # Fail2ban config file parser/serializer
|
||||
│ ├── config_parser.py # Structured config object parser
|
||||
│ ├── config_writer.py # Atomic config file write operations│ └── constants.py # Shared constants (default paths, limits, etc.)
|
||||
├── tests/
|
||||
│ ├── conftest.py # Shared fixtures (test app, client, mock DB)
|
||||
│ ├── test_routers/ # One test file per router
|
||||
@@ -152,14 +161,15 @@ The HTTP interface layer. Each router maps URL paths to handler functions. Route
|
||||
| `dashboard.py` | `/api/dashboard` | Server status bar data, recent bans for the dashboard |
|
||||
| `jails.py` | `/api/jails` | List jails, jail detail, start/stop/reload/idle controls |
|
||||
| `bans.py` | `/api/bans` | Ban an IP, unban an IP, unban all, list currently banned IPs |
|
||||
| `config.py` | `/api/config` | Read and write fail2ban jail/filter/server configuration via the socket |
|
||||
| `config.py` | `/api/config` | Read and write fail2ban jail/filter/server configuration via the socket; also serves the fail2ban log tail and service status for the Log tab |
|
||||
| `file_config.py` | `/api/config` | Read and write fail2ban config files on disk (jail.d/, filter.d/, action.d/) — list, get, and overwrite raw file contents, toggle jail enabled/disabled |
|
||||
| `history.py` | `/api/history` | Query historical bans, per-IP timeline |
|
||||
| `blocklist.py` | `/api/blocklists` | CRUD blocklist sources, trigger import, view import logs |
|
||||
| `geo.py` | `/api/geo` | IP geolocation lookup, ASN and RIR data |
|
||||
| `server.py` | `/api/server` | Log level, log target, DB path, purge age, flush logs |
|
||||
| `health.py` | `/api/health` | fail2ban connectivity health check and status |
|
||||
|
||||
#### Services (`app/services/`)
|
||||
#### Services (`app/services`)
|
||||
|
||||
The business logic layer. Services orchestrate operations, enforce rules, and coordinate between repositories, the fail2ban client, and external APIs. Each service covers a single domain.
|
||||
|
||||
@@ -169,10 +179,14 @@ The business logic layer. Services orchestrate operations, enforce rules, and co
|
||||
| `setup_service.py` | Validates setup input, persists initial configuration, ensures setup runs only once |
|
||||
| `jail_service.py` | Retrieves jail list and details from fail2ban, aggregates metrics (banned count, failure count), sends start/stop/reload/idle commands |
|
||||
| `ban_service.py` | Executes ban and unban commands via the fail2ban socket, queries the currently banned IP list, validates IPs before banning |
|
||||
| `config_service.py` | Reads active jail and filter configuration from fail2ban, writes configuration changes, validates regex patterns, triggers reload |
|
||||
| `config_service.py` | Reads active jail and filter configuration from fail2ban, writes configuration changes, validates regex patterns, triggers reload; reads the fail2ban log file tail and queries service status for the Log tab |
|
||||
| `file_config_service.py` | Reads and writes raw fail2ban config files on disk (jail.d/, filter.d/, action.d/); lists files, reads content, overwrites files, toggles enabled/disabled |
|
||||
| `config_file_service.py` | Parses jail.conf / jail.local / jail.d/* to discover inactive jails; writes .local overrides to activate or deactivate jails; triggers fail2ban reload |
|
||||
| `conffile_parser.py` | Parses fail2ban `.conf` files into structured Python types (jail config, filter config, action config); also serialises back to text |
|
||||
| `jail_config_service.py` | Discovers inactive jails by parsing jail.conf / jail.local / jail.d/*; writes .local overrides to activate/deactivate jails; triggers fail2ban reload; validates jail configurations |
|
||||
| `filter_config_service.py` | Discovers available filters by scanning filter.d/; reads, creates, updates, and deletes filter definitions; assigns filters to jails |
|
||||
| `action_config_service.py` | Discovers available actions by scanning action.d/; reads, creates, updates, and deletes action definitions; assigns actions to jails |
|
||||
| `config_file_service.py` | Shared utilities for configuration parsing and manipulation: parses config files, validates names/IPs, manages atomic file writes, probes fail2ban socket |
|
||||
| `raw_config_io_service.py` | Low-level file I/O for raw fail2ban config files |
|
||||
| `log_service.py` | Log preview and regex test operations (extracted from config_service) |
|
||||
| `history_service.py` | Queries the fail2ban database for historical ban records, builds per-IP timelines, computes ban counts and repeat-offender flags |
|
||||
| `blocklist_service.py` | Downloads blocklists via aiohttp, validates IPs/CIDRs, applies bans through fail2ban or iptables, logs import results |
|
||||
| `geo_service.py` | Resolves IP addresses to country, ASN, and RIR using external APIs or a local database, caches results |
|
||||
@@ -188,15 +202,26 @@ The data access layer. Repositories execute raw SQL queries against the applicat
|
||||
| `settings_repo.py` | CRUD operations for application settings (master password hash, DB path, fail2ban socket path, preferences) |
|
||||
| `session_repo.py` | Store, retrieve, and delete session records for authentication |
|
||||
| `blocklist_repo.py` | Persist blocklist source definitions (name, URL, enabled/disabled) |
|
||||
| `fail2ban_db_repo.py` | Read historical ban records from the fail2ban SQLite database |
|
||||
| `geo_cache_repo.py` | Persist and query IP geo resolution cache |
|
||||
| `import_log_repo.py` | Record import run results (timestamp, source, IPs imported, errors) for the import log view |
|
||||
|
||||
#### Models (`app/models/`)
|
||||
|
||||
Pydantic schemas that define data shapes and validation. Models are split into three categories per domain:
|
||||
Pydantic schemas that define data shapes and validation. Models are split into three categories per domain.
|
||||
|
||||
- **Request models** — validate incoming API data (e.g., `BanRequest`, `LoginRequest`)
|
||||
- **Response models** — shape outgoing API data (e.g., `JailResponse`, `BanListResponse`)
|
||||
- **Domain models** — internal representations used between services and repositories (e.g., `Ban`, `Jail`)
|
||||
| Model file | Purpose |
|
||||
|---|---|
|
||||
| `auth.py` | Login/request and session models |
|
||||
| `ban.py` | Ban creation and lookup models |
|
||||
| `blocklist.py` | Blocklist source and import log models |
|
||||
| `config.py` | Fail2ban config view/edit models |
|
||||
| `file_config.py` | Raw config file read/write models |
|
||||
| `geo.py` | Geo and ASN lookup models |
|
||||
| `history.py` | Historical ban query and timeline models |
|
||||
| `jail.py` | Jail listing and status models |
|
||||
| `server.py` | Server status and settings models |
|
||||
| `setup.py` | First-run setup wizard models |
|
||||
|
||||
#### Tasks (`app/tasks/`)
|
||||
|
||||
@@ -206,6 +231,7 @@ APScheduler background jobs that run on a schedule without user interaction.
|
||||
|---|---|
|
||||
| `blocklist_import.py` | Downloads all enabled blocklist sources, validates entries, applies bans, records results in the import log |
|
||||
| `geo_cache_flush.py` | Periodically flushes newly resolved IPs from the in-memory dirty set to the `geo_cache` SQLite table (default: every 60 seconds). GET requests populate only the in-memory cache; this task persists them without blocking any request. |
|
||||
| `geo_re_resolve.py` | Periodically re-resolves stale entries in `geo_cache` to keep geolocation data fresh |
|
||||
| `health_check.py` | Periodically pings the fail2ban socket and updates the cached server status so the frontend always has fresh data |
|
||||
|
||||
#### Utils (`app/utils/`)
|
||||
@@ -216,7 +242,16 @@ Pure helper modules with no framework dependencies.
|
||||
|---|---|
|
||||
| `fail2ban_client.py` | Async client that communicates with fail2ban via its Unix domain socket — sends commands and parses responses using the fail2ban protocol. Modelled after [`./fail2ban-master/fail2ban/client/csocket.py`](../fail2ban-master/fail2ban/client/csocket.py) and [`./fail2ban-master/fail2ban/client/fail2banclient.py`](../fail2ban-master/fail2ban/client/fail2banclient.py). |
|
||||
| `ip_utils.py` | Validates IPv4/IPv6 addresses and CIDR ranges using the `ipaddress` stdlib module, normalises formats |
|
||||
| `jail_utils.py` | Jail helper functions for configuration and status inference |
|
||||
| `jail_config.py` | Jail config parser and serializer for fail2ban config manipulation |
|
||||
| `time_utils.py` | Timezone-aware datetime construction, formatting helpers, time-range calculations |
|
||||
| `log_utils.py` | Structured log formatting and enrichment helpers |
|
||||
| `conffile_parser.py` | Parses Fail2ban `.conf` files into structured objects and serialises back to text |
|
||||
| `config_parser.py` | Builds structured config objects from file content tokens |
|
||||
| `config_writer.py` | Atomic config file writes, backups, and safe replace semantics |
|
||||
| `config_file_utils.py` | Common file-level config utility helpers |
|
||||
| `fail2ban_db_utils.py` | Fail2ban DB path discovery and ban-history parsing helpers |
|
||||
| `setup_utils.py` | Setup wizard helper utilities |
|
||||
| `constants.py` | Shared constants: default socket path, default database path, time-range presets, limits |
|
||||
|
||||
#### Configuration (`app/config.py`)
|
||||
|
||||
@@ -52,6 +52,8 @@ The main landing page after login. Shows recent ban activity at a glance.
|
||||
- Last 7 days (week)
|
||||
- Last 30 days (month)
|
||||
- Last 365 days (year)
|
||||
- **Data source selection:** The "Last 24 hours" preset queries fail2ban's live database directly for real-time accuracy. All longer presets (7 days, 30 days, 365 days) query the BanGUI long-term archive, because fail2ban's own database only retains the last 24 hours by default.
|
||||
- A **data-source badge** next to the time-range selector indicates whether the current view is showing **Live (fail2ban DB)** or **Archive (BanGUI DB)** data.
|
||||
|
||||
---
|
||||
|
||||
@@ -70,14 +72,16 @@ A geographical overview of ban activity.
|
||||
- Colors are smoothly interpolated between the thresholds (e.g., 35 bans shows a yellow-green blend)
|
||||
- The color threshold values are configurable through the application settings
|
||||
- **Interactive zoom and pan:** Users can zoom in/out using mouse wheel or touch gestures, and pan by clicking and dragging. This allows detailed inspection of densely-affected regions. Zoom controls (zoom in, zoom out, reset view) are provided as overlay buttons in the top-right corner.
|
||||
- For every country that has bans, the total count is displayed centred inside that country's borders in the selected time range.
|
||||
- Countries with zero banned IPs show no number and no label — they remain blank and transparent.
|
||||
- For every country that has bans, the total count is shown only in the country tooltip, not rendered on the map itself.
|
||||
- Countries with zero banned IPs show no tooltip and remain blank and transparent.
|
||||
- Clicking a country filters the companion table below to show only bans from that country.
|
||||
- Time-range selector with the same quick presets:
|
||||
- Last 24 hours
|
||||
- Last 7 days
|
||||
- Last 30 days
|
||||
- Last 365 days
|
||||
- **Data source selection:** Same rule as the Dashboard — "Last 24 hours" uses the live fail2ban database; all other ranges use the BanGUI archive.
|
||||
- A **data-source badge** is displayed alongside the time-range selector indicating **Live (fail2ban DB)** or **Archive (BanGUI DB)**.
|
||||
|
||||
---
|
||||
|
||||
@@ -220,17 +224,40 @@ A page to inspect and modify the fail2ban configuration without leaving the web
|
||||
- Countries with zero bans remain transparent (no fill).
|
||||
- Changes take effect immediately on the World Map view without requiring a page reload.
|
||||
|
||||
### Log
|
||||
|
||||
- A dedicated **Log** tab on the Configuration page shows fail2ban service health and a live log viewer in one place.
|
||||
- **Service Health panel** (always visible):
|
||||
- Online/offline **badge** (Running / Offline).
|
||||
- When online: version, active jail count, currently banned IPs, and currently failed attempts as stat cards.
|
||||
- Log level and log target displayed as meta labels.
|
||||
- Warning banner when fail2ban is offline, prompting the user to check the server and socket configuration.
|
||||
- **Log Viewer** (shown when fail2ban logs to a file):
|
||||
- Displays the tail of the fail2ban log file in a scrollable monospace container.
|
||||
- Log lines are **color-coded by severity**: errors and critical messages in red, warnings in yellow, debug lines in grey, and informational lines in the default color.
|
||||
- Toolbar controls:
|
||||
- **Filter** — substring input with 300 ms debounce; only lines containing the filter text are shown.
|
||||
- **Lines** — selector for how many tail lines to fetch (100 / 200 / 500 / 1000).
|
||||
- **Refresh** button for an on-demand reload.
|
||||
- **Auto-refresh** toggle with interval selector (5 s / 10 s / 30 s) for live monitoring.
|
||||
- Truncation notice when the total log file line count exceeds the requested tail limit.
|
||||
- Container automatically scrolls to the bottom after each data update.
|
||||
- When fail2ban is configured to log to a non-file target (STDOUT, STDERR, SYSLOG, SYSTEMD-JOURNAL), an informational banner explains that file-based log viewing is unavailable.
|
||||
- The log file path is validated against a safe prefix allowlist on the backend to prevent path-traversal reads.
|
||||
|
||||
---
|
||||
|
||||
## 7. Ban History
|
||||
|
||||
A view for exploring historical ban data stored in the fail2ban database.
|
||||
A view for exploring historical ban data stored in the BanGUI long-term archive.
|
||||
|
||||
### History Table
|
||||
|
||||
- Browse all past bans across all jails, not just the currently active ones.
|
||||
- **Columns:** Time of ban, IP address, jail, ban duration, ban count (how many times this IP was banned), country.
|
||||
- Filter by jail, by IP address, or by time range.
|
||||
- The default time range on first load is **Last 7 days** and the data source is always the **BanGUI archive**, ensuring the full retention window is visible regardless of fail2ban's `dbpurgeage` setting.
|
||||
- A **data-source badge** is displayed indicating **Archive (BanGUI DB)**.
|
||||
- See at a glance which IPs are repeat offenders (high ban count).
|
||||
|
||||
### Per-IP History
|
||||
@@ -238,6 +265,17 @@ A view for exploring historical ban data stored in the fail2ban database.
|
||||
- Select any IP to see its full ban timeline: every ban event, which jail triggered it, when it started, and how long it lasted.
|
||||
- Merged view showing total failures and matched log lines aggregated across all bans for that IP.
|
||||
|
||||
### Persistent Historical Archive
|
||||
|
||||
- BanGUI stores a separate long-term historical ban archive in its own application database, independent from fail2ban's database retention settings.
|
||||
- On each configured sync cycle (default every 5 minutes), BanGUI reads latest entries from fail2ban `bans` table and appends any new events to BanGUI history storage.
|
||||
- Supports both `ban` and `unban` events; audit record includes: `timestamp`, `ip`, `jail`, `action`, `duration`, `origin` (manual, auto, blocklist, etc.), `failures`, `matches`, and optional `country` / `ASN` enrichment.
|
||||
- Includes incremental import logic with dedupe: using unique constraint on (ip, jail, action, timeofban) to prevent duplication across sync cycles.
|
||||
- Provides backfill mode for initial startup: import the last 7.5 days of existing fail2ban history into BanGUI to avoid dark gaps after restart. Requires fail2ban's `dbpurgeage` to be set to at least `648000` (7.5 days) — BanGUI ships with this value pre-configured in its Docker setup.
|
||||
- Includes configurable archive purge policy in BanGUI (default 365 days), separate from fail2ban `dbpurgeage`, to keep app storage bounded while preserving audit data.
|
||||
- Expose API endpoints for querying persistent history, with filters for timeframe, jail, origin, IP, and current ban status.
|
||||
- On fail2ban connectivity failure, BanGUI continues serving historical data; next successful sync resumes ingestion without data loss.
|
||||
|
||||
---
|
||||
|
||||
## 8. External Blocklist Importer
|
||||
|
||||
5
Docs/Refactoring.md
Normal file
5
Docs/Refactoring.md
Normal file
@@ -0,0 +1,5 @@
|
||||
# BanGUI — Architecture Issues & Refactoring Plan
|
||||
|
||||
This document catalogues architecture violations, code smells, and structural issues found during a full project review. Issues are grouped by category and prioritised.
|
||||
|
||||
---
|
||||
287
Docs/Tasks.md
287
Docs/Tasks.md
@@ -2,256 +2,77 @@
|
||||
|
||||
This document breaks the entire BanGUI project into development stages, ordered so that each stage builds on the previous one. Every task is described in prose with enough detail for a developer to begin work. References point to the relevant documentation.
|
||||
|
||||
---
|
||||
|
||||
## Task 1 — Jail Page: Show Only Active Jails (No Inactive Configs)
|
||||
|
||||
**Status:** done
|
||||
|
||||
**Summary:** Backend `GET /api/jails` already only returned active jails (queries fail2ban socket `status` command). Frontend `JailsPage.tsx` updated: removed the "Inactive Jails" section, the "Show inactive" toggle, the `fetchInactiveJails()` call, the `ActivateJailDialog` import/usage, and the `InactiveJail` type import. The Config page (`JailsTab.tsx`) retains full inactive-jail management. All backend tests pass (96/96). TypeScript and ESLint report zero errors. (`JailsPage.tsx`) currently displays inactive jail configurations alongside active jails. Inactive jails — those defined in config files but not running — belong on the **Configuration** page (`ConfigPage.tsx`, Jails tab), not on the operational Jail management page. The Jail page should be a pure operational view: only jails that fail2ban reports as active/running appear here.
|
||||
|
||||
### Goal
|
||||
|
||||
Remove all inactive-jail display and activation UI from the Jail management page. The Jail page shows only jails that are currently loaded in the running fail2ban instance. Users who want to discover and activate inactive jails do so exclusively through the Configuration page's Jails tab.
|
||||
|
||||
### Backend Changes
|
||||
|
||||
1. **Review `GET /api/jails`** in `backend/app/routers/jails.py` and `jail_service.py`. Confirm this endpoint only returns jails that are reported as active by fail2ban via the socket (`status` command). If it already does, no change needed. If it includes inactive/config-only jails in its response, strip them out.
|
||||
2. **No new endpoints needed.** The inactive-jail listing and activation endpoints already live under `/api/config/jails` and `/api/config/jails/{name}/activate` in `config.py` / `config_file_service.py` — those stay as-is for the Config page.
|
||||
|
||||
### Frontend Changes
|
||||
|
||||
3. **`JailsPage.tsx`** — Remove the "Inactive Jails" section, the toggle that reveals inactive jails, and the `fetchInactiveJails()` call. The page should only call `fetchJails()` (which queries `/api/jails`) and render that list. Remove the `ActivateJailDialog` import and usage from this page if present.
|
||||
4. **`JailsPage.tsx`** — Remove any "Activate" buttons or affordances that reference inactive jails. The jail overview table should show: jail name, status (running / stopped / idle), backend type, currently banned count, total bans, currently failed, total failed, find time, ban time, max retries. No "Inactive" badge or "Activate" button.
|
||||
5. **Verify the Config page** (`ConfigPage.tsx` → Jails tab / `JailsTab.tsx`) still shows the full list including inactive jails with Active/Inactive badges and the Activate button. This is the only place where inactive jails are managed. No changes expected here — just verify nothing broke.
|
||||
|
||||
### Tests
|
||||
|
||||
6. **Backend:** If there are existing tests for `GET /api/jails` that assert inactive jails are included, update them so they assert inactive jails are excluded.
|
||||
7. **Frontend:** Update or remove any component tests for the inactive-jail section on `JailsPage`. Ensure Config-page tests for inactive jail activation still pass.
|
||||
|
||||
### Acceptance Criteria
|
||||
|
||||
- The Jail page shows zero inactive jails under any circumstance.
|
||||
- All Jail page data comes only from the fail2ban socket's active jail list.
|
||||
- Inactive-jail discovery and activation remain fully functional on the Configuration page, Jails tab.
|
||||
- No regressions in existing jail control actions (start, stop, reload, idle, ignore-list) on the Jail page.
|
||||
Reference: `Docs/Refactoring.md` for full analysis of each issue.
|
||||
|
||||
---
|
||||
|
||||
## Task 2 — Configuration Subpage: fail2ban Log Viewer & Service Health
|
||||
## Open Issues
|
||||
|
||||
**Status:** not started
|
||||
**References:** [Features.md § 6 — Configuration View](Features.md), [Architekture.md § 2](Architekture.md)
|
||||
### Backend Architecture
|
||||
|
||||
### Problem
|
||||
- **Replace the single shared SQLite connection.**
|
||||
- Current startup code opens one `aiosqlite.Connection` and reuses it for every request.
|
||||
- This should be replaced with either a connection pool or request-scoped connections to avoid concurrency and locking issues.
|
||||
- Update request dependencies, application lifecycle, and tests to use the new pattern.
|
||||
|
||||
There is currently no way to view the fail2ban daemon log (`/var/log/fail2ban.log` or wherever the log target is configured) through the web interface. There is also no dedicated place in the Configuration section that shows at a glance whether fail2ban is running correctly. The existing health probe (`health_service.py`) and dashboard status bar give connectivity info, but the Configuration page should have its own panel showing service health alongside the raw log output.
|
||||
- **Refactor dependency wiring and shared resource management.**
|
||||
- Remove hidden module-level import coupling between routers, services, and shared utilities.
|
||||
- Introduce explicit factories or providers for shared resources such as DB, HTTP client session, scheduler, and settings.
|
||||
- Ensure routers depend on injected providers rather than global state or dynamic imports.
|
||||
|
||||
### Goal
|
||||
- **Harden fail2ban integration.**
|
||||
- Remove the `sys.path` hack that locates `fail2ban-master` at runtime.
|
||||
- Replace it with a deterministic packaging or configuration model so the backend does not depend on repository layout.
|
||||
- Refactor `Fail2BanClient` so concurrency control is instance-based and not backed by hidden module globals.
|
||||
|
||||
Add a new **Log** tab to the Configuration page. This tab shows two things:
|
||||
1. A **Service Health panel** — a compact summary showing whether fail2ban is running, its version, active jail count, total bans, total failures, and the current log level/target. This reuses data from the existing health probe.
|
||||
2. A **Log viewer** — displays the tail of the fail2ban daemon log file with newest entries at the bottom. Supports manual refresh and optional auto-refresh on an interval.
|
||||
- **Improve startup / setup guard behavior.**
|
||||
- Convert `SetupRedirectMiddleware` from an on-demand DB check into a startup/initialisation guard where possible.
|
||||
- Cache setup completion in a safe way and provide an explicit invalidation path if the application state changes.
|
||||
- Reduce middleware responsibility and avoid DB access during normal request dispatch.
|
||||
|
||||
### Backend Changes
|
||||
- **Make deployment configuration explicit.**
|
||||
- Move hard-coded environment assumptions such as CORS origins into settings.
|
||||
- Ensure `fail2ban_socket`, `fail2ban_config_dir`, and startup commands are fully configurable via `Settings`.
|
||||
- Document production-ready defaults separately from development defaults.
|
||||
|
||||
#### New Endpoint: Read fail2ban Log
|
||||
### Reliability and Resilience
|
||||
|
||||
1. **Create `GET /api/config/fail2ban-log`** in `backend/app/routers/config.py` (or a new router file `backend/app/routers/log.py` if `config.py` is getting large).
|
||||
- **Query parameters:**
|
||||
- `lines` (int, default 200, max 2000) — number of lines to return from the tail of the log file.
|
||||
- `filter` (optional string) — a plain-text substring filter; only return lines containing this string (for searching).
|
||||
- **Response model:** `Fail2BanLogResponse` with fields:
|
||||
- `log_path: str` — the resolved path of the log file being read.
|
||||
- `lines: list[str]` — the log lines.
|
||||
- `total_lines: int` — total number of lines in the file (so the UI can indicate if it's truncated).
|
||||
- `log_level: str` — the current fail2ban log level.
|
||||
- `log_target: str` — the current fail2ban log target.
|
||||
- **Behaviour:** Query the fail2ban socket for `get logtarget` to find the current log file path. Read the last N lines from that file using an efficient tail implementation (read from end of file, do not load the entire file into memory). If the log target is not a file (stdout, syslog, systemd-journal), return an informative error explaining that log viewing is only available when fail2ban logs to a file.
|
||||
- **Security:** Validate that the resolved log path is under an expected directory (e.g. `/var/log/`). Do not allow path traversal. Never expose arbitrary file contents.
|
||||
- **Add backend lifecycle tests for resource cleanup.**
|
||||
- Verify startup opens and initialises DB, HTTP session, scheduler, and geo cache correctly.
|
||||
- Verify shutdown closes those resources cleanly.
|
||||
|
||||
2. **Create the service method** `read_fail2ban_log()` in `backend/app/services/config_service.py` (or a new `log_service.py`).
|
||||
- Use `fail2ban_client.py` to query `get logtarget` and `get loglevel`.
|
||||
- Implement an async file tail: open the file, seek to end, read backwards until N newlines are found OR the beginning of the file is reached.
|
||||
- Apply the optional substring filter on the server side before returning.
|
||||
- **Add concurrency/regression coverage for DB and fail2ban socket use.**
|
||||
- Add tests that simulate multiple concurrent requests using the same DB dependency.
|
||||
- Add tests around fail2ban socket retries, protocol errors, and rate limiting.
|
||||
|
||||
3. **Create Pydantic models** in `backend/app/models/config.py`:
|
||||
- `Fail2BanLogResponse(log_path: str, lines: list[str], total_lines: int, log_level: str, log_target: str)`
|
||||
- **Improve state caching and invalidation.**
|
||||
- Add tests for session cache invalidation on logout.
|
||||
- Add tests for setup completion caching so stale state is never served.
|
||||
|
||||
#### Extend Health Data for Config Page
|
||||
### Backend Feature Work
|
||||
|
||||
4. **Create `GET /api/config/service-status`** (or reuse/extend `GET /api/dashboard/status` if appropriate).
|
||||
- Returns: `online` (bool), `version` (str), `jail_count` (int), `total_bans` (int), `total_failures` (int), `log_level` (str), `log_target` (str), `db_path` (str), `uptime` or `start_time` if available.
|
||||
- This can delegate to the existing `health_service.probe()` and augment with the log-level/target info from the socket.
|
||||
- **Document and implement backend-safe environment-driven CORS.**
|
||||
- Add support for production and local development origins through configuration.
|
||||
- Avoid a hardcoded Vite origin in the core app factory.
|
||||
|
||||
### Frontend Changes
|
||||
- **Centralise scheduler job registration.**
|
||||
- Refactor APScheduler registration so background tasks are registered through a common lifecycle helper.
|
||||
- Ensure jobs can be discovered, replaced, and tested without requiring implicit `app.state` side effects.
|
||||
|
||||
#### New Tab: Log
|
||||
- **Strengthen fail2ban error handling and reporting.**
|
||||
- Standardise `502` responses for connection/protocol failures across all endpoints.
|
||||
- Add structured logging for retries and fatal socket failures.
|
||||
- Ensure the UI can distinguish offline fail2ban from internal backend failures.
|
||||
|
||||
5. **Create `frontend/src/components/config/LogTab.tsx`.**
|
||||
- **Service Health panel** at the top:
|
||||
- A status badge: green "Running" or red "Offline".
|
||||
- Version, active jails count, total bans, total failures displayed in a compact row of stat cards.
|
||||
- Current log level and log target shown as labels.
|
||||
- If fail2ban is offline, show a prominent warning banner with the text: "fail2ban is not running or unreachable. Check the server and socket configuration."
|
||||
- **Log viewer** below:
|
||||
- A monospace-font scrollable container showing the log lines.
|
||||
- A toolbar above the log area with:
|
||||
- A **Refresh** button to re-fetch the log.
|
||||
- An **Auto-refresh** toggle (off by default) with a selectable interval (5s, 10s, 30s).
|
||||
- A **Lines** dropdown to choose how many lines to load (100, 200, 500, 1000).
|
||||
- A **Filter** text input to search within the log (sends the filter param to the backend).
|
||||
- Log lines should be syntax-highlighted or at minimum color-coded by log level (ERROR = red, WARNING = yellow, INFO = default, DEBUG = muted).
|
||||
- The container auto-scrolls to the bottom on load and on refresh (since newest entries are at the end).
|
||||
- If the log target is not a file, show an info banner: "fail2ban is logging to [target]. File-based log viewing is not available."
|
||||
- **Improve documentation of backend responsibilities.**
|
||||
- Keep `Docs/Tasks.md` aligned with the backend architecture review.
|
||||
- Add references to the backend modules, resource lifecycle, and dependency model in the documentation.
|
||||
|
||||
6. **Register the tab** in `ConfigPage.tsx`. Add a "Log" tab after the existing tabs (Jails, Filters, Actions, Global, Server, Map, Regex Tester). Use a log-file icon.
|
||||
### Priority Execution Plan
|
||||
|
||||
7. **Create API functions** in `frontend/src/api/config.ts`:
|
||||
- `fetchFail2BanLog(lines?: number, filter?: string): Promise<Fail2BanLogResponse>`
|
||||
- `fetchServiceStatus(): Promise<ServiceStatusResponse>`
|
||||
|
||||
8. **Create TypeScript types** in `frontend/src/types/config.ts` (or wherever config types live):
|
||||
- `Fail2BanLogResponse { log_path: string; lines: string[]; total_lines: number; log_level: string; log_target: string; }`
|
||||
- `ServiceStatusResponse { online: boolean; version: string; jail_count: number; total_bans: number; total_failures: number; log_level: string; log_target: string; }`
|
||||
|
||||
### Tests
|
||||
|
||||
9. **Backend:** Write tests for the new log endpoint — mock the file read, test line-count limiting, test the substring filter, test the error case when log target is not a file, test path-traversal prevention.
|
||||
10. **Backend:** Write tests for the service-status endpoint.
|
||||
11. **Frontend:** Write component tests for `LogTab.tsx` — renders health panel, renders log lines, filter input works, handles offline state.
|
||||
|
||||
### Acceptance Criteria
|
||||
|
||||
- The Configuration page has a new "Log" tab.
|
||||
- The Log tab shows a clear health summary with running/offline state and key metrics.
|
||||
- The Log tab displays the tail of the fail2ban daemon log file.
|
||||
- Users can choose how many lines to display, can refresh manually, and can optionally enable auto-refresh.
|
||||
- Users can filter log lines by substring.
|
||||
- Log lines are visually differentiated by severity level.
|
||||
- If fail2ban logs to a non-file target, a clear message is shown instead of the log viewer.
|
||||
- The log endpoint does not allow reading arbitrary files — only the actual fail2ban log target.
|
||||
|
||||
---
|
||||
|
||||
## Task 3 — Invalid Jail Config Recovery: Detect Broken fail2ban & Auto-Disable Bad Jails
|
||||
|
||||
**Status:** not started
|
||||
**References:** [Features.md § 5 — Jail Management](Features.md), [Features.md § 6 — Configuration View](Features.md), [Architekture.md § 2](Architekture.md)
|
||||
|
||||
### Problem
|
||||
|
||||
When a user activates a jail from the Configuration page, the system writes `enabled = true` to a `.local` override file and triggers a fail2ban reload. If the jail's configuration is invalid (bad regex, missing log file, broken filter reference, syntax error in an action), fail2ban may **refuse to start entirely** — not just skip the one bad jail but stop the whole daemon. At that point every jail is down, all monitoring stops, and the user is locked out of all fail2ban operations in BanGUI.
|
||||
|
||||
The current `activate_jail()` flow in `config_file_service.py` does a post-reload check (queries fail2ban for the jail's status and returns `active=false` if it didn't start), but this only works when fail2ban is still running. If the entire daemon crashes after the reload, the socket is gone and BanGUI cannot query anything. The user sees generic "offline" errors but has no clear path to fix the problem.
|
||||
|
||||
### Goal
|
||||
|
||||
Build a multi-layered safety net that:
|
||||
1. **Pre-validates** the jail config before activating it (catch obvious errors before the reload).
|
||||
2. **Detects** when fail2ban goes down after a jail activation (detect the crash quickly).
|
||||
3. **Alerts** the user with a clear, actionable message explaining which jail was just activated and that it likely caused the failure.
|
||||
4. **Offers a one-click rollback** that disables the bad jail config and restarts fail2ban.
|
||||
|
||||
### Plan
|
||||
|
||||
#### Layer 1: Pre-Activation Validation
|
||||
|
||||
1. **Extend `activate_jail()` in `config_file_service.py`** (or add a new `validate_jail_config()` method) to perform dry-run checks before writing the `.local` file and reloading:
|
||||
- **Filter existence:** Verify the jail's `filter` setting references a filter file that actually exists in `filter.d/`.
|
||||
- **Action existence:** Verify every action referenced by the jail exists in `action.d/`.
|
||||
- **Regex compilation:** Attempt to compile all `failregex` and `ignoreregex` patterns with Python's `re` module. Report which pattern is broken.
|
||||
- **Log path check:** Verify that the log file paths declared in the jail config actually exist on disk and are readable.
|
||||
- **Syntax check:** Parse the full merged config (base + overrides) and check for obvious syntax issues (malformed interpolation, missing required keys).
|
||||
2. **Return validation errors as a structured response** before proceeding with activation. The response should list every issue found so the user can fix them before trying again.
|
||||
3. **Create a new endpoint `POST /api/config/jails/{name}/validate`** that runs only the validation step without actually activating. The frontend can call this for a "Check Config" button.
|
||||
|
||||
#### Layer 2: Post-Activation Health Check
|
||||
|
||||
4. **After each `activate_jail()` reload**, perform a health-check sequence with retries:
|
||||
- Wait 2 seconds after sending the reload command.
|
||||
- Probe the fail2ban socket with `ping`.
|
||||
- If the probe succeeds, check if the specific jail is active.
|
||||
- If the probe fails (socket gone / connection refused), retry up to 3 times with 2-second intervals.
|
||||
- Return the probe result as part of the activation response.
|
||||
5. **Extend the `JailActivationResponse` model** to include:
|
||||
- `fail2ban_running: bool` — whether the fail2ban daemon is still running after reload.
|
||||
- `validation_warnings: list[str]` — any non-fatal warnings from the pre-validation step.
|
||||
- `error: str | None` — a human-readable error message if something went wrong.
|
||||
|
||||
#### Layer 3: Automatic Crash Detection via Background Task
|
||||
|
||||
6. **Extend `tasks/health_check.py`** (the periodic health probe that runs every 30 seconds):
|
||||
- Track the **last known activation event**: when a jail was activated, store its name and timestamp in an in-memory variable (or a lightweight DB record).
|
||||
- If the health check detects that fail2ban transitioned from `online` to `offline`, and a jail was activated within the last 60 seconds, flag this as a **probable activation failure**.
|
||||
- Store a `PendingRecovery` record: `{ jail_name: str, activated_at: datetime, detected_at: datetime, recovered: bool }`.
|
||||
7. **Create a new endpoint `GET /api/config/pending-recovery`** that returns the current `PendingRecovery` record (or `null` if none).
|
||||
- The frontend polls this endpoint (or it is included in the dashboard status response) to detect when a recovery state is active.
|
||||
|
||||
#### Layer 4: User Alert & One-Click Rollback
|
||||
|
||||
8. **Frontend — Global alert banner.** When the health status transitions to offline and a `PendingRecovery` record exists:
|
||||
- Show a **full-width warning banner** at the top of every page (not just the Config page). The banner is dismissible only after the issue is resolved.
|
||||
- Banner text: "fail2ban stopped after activating jail **{name}**. The jail's configuration may be invalid. Disable this jail and restart fail2ban?"
|
||||
- Two buttons:
|
||||
- **"Disable & Restart"** — calls the rollback endpoint (see below).
|
||||
- **"View Details"** — navigates to the Config page Log tab so the user can inspect the fail2ban log for the exact error message.
|
||||
9. **Create a rollback endpoint `POST /api/config/jails/{name}/rollback`** in the backend:
|
||||
- Writes `enabled = false` to the jail's `.local` override (same as `deactivate_jail()` but works even when fail2ban is down since it only writes a file).
|
||||
- Attempts to start (not reload) the fail2ban daemon via the configured start command (e.g. `systemctl start fail2ban` or `fail2ban-client start`). Make the start command configurable in the app settings.
|
||||
- Waits up to 10 seconds for the socket to come back, probing every 2 seconds.
|
||||
- Returns a response indicating whether fail2ban is back online and how many jails are now active.
|
||||
- Clears the `PendingRecovery` record on success.
|
||||
10. **Frontend — Rollback result.** After the rollback call returns:
|
||||
- If successful: show a success toast "fail2ban restarted with {n} active jails. The jail **{name}** has been disabled." and dismiss the banner.
|
||||
- If fail2ban still doesn't start: show an error dialog explaining that the problem may not be limited to the last activated jail. Suggest the user check the fail2ban log (link to the Log tab) or SSH into the server. Keep the banner visible.
|
||||
|
||||
#### Layer 5: Config Page Enhancements
|
||||
|
||||
11. **On the Config page Jails tab**, when activating a jail:
|
||||
- Before activation, show a confirmation dialog that includes any validation warnings from the pre-check.
|
||||
- During activation, show a spinner with the text "Activating jail and verifying fail2ban…" (acknowledge the post-activation health check takes a few seconds).
|
||||
- After activation, if `fail2ban_running` is false in the response, immediately show the recovery banner and rollback option without waiting for the background health check.
|
||||
12. **Add a "Validate" button** next to the "Activate" button on inactive jails. Clicking it calls `POST /api/config/jails/{name}/validate` and shows the validation results in a panel (green for pass, red for each issue found).
|
||||
|
||||
### Backend File Map
|
||||
|
||||
| File | Changes |
|
||||
|---|---|
|
||||
| `services/config_file_service.py` | Add `validate_jail_config()`, extend `activate_jail()` with pre-validation and post-reload health check. |
|
||||
| `routers/config.py` | Add `POST /api/config/jails/{name}/validate`, `GET /api/config/pending-recovery`, `POST /api/config/jails/{name}/rollback`. |
|
||||
| `models/config.py` | Add `JailValidationResult`, `PendingRecovery`, extend `JailActivationResponse`. |
|
||||
| `tasks/health_check.py` | Track last activation event, detect crash-after-activation, write `PendingRecovery` record. |
|
||||
| `services/health_service.py` | Add helper to attempt daemon start (not just probe). |
|
||||
|
||||
### Frontend File Map
|
||||
|
||||
| File | Changes |
|
||||
|---|---|
|
||||
| `components/config/ActivateJailDialog.tsx` | Add pre-validation call, show warnings, show extended activation feedback. |
|
||||
| `components/config/JailsTab.tsx` | Add "Validate" button next to "Activate" for inactive jails. |
|
||||
| `components/common/RecoveryBanner.tsx` (new) | Global warning banner for activation failures with rollback button. |
|
||||
| `pages/AppLayout.tsx` (or root layout) | Mount the `RecoveryBanner` component so it appears on all pages. |
|
||||
| `api/config.ts` | Add `validateJailConfig()`, `fetchPendingRecovery()`, `rollbackJail()`. |
|
||||
| `types/config.ts` | Add `JailValidationResult`, `PendingRecovery`, extend `JailActivationResponse`. |
|
||||
|
||||
### Tests
|
||||
|
||||
13. **Backend:** Test `validate_jail_config()` — valid config passes, missing filter fails, bad regex fails, missing log path fails.
|
||||
14. **Backend:** Test the rollback endpoint — mock file write, mock daemon start, verify response for success and failure cases.
|
||||
15. **Backend:** Test the health-check crash detection — simulate online→offline transition with a recent activation, verify `PendingRecovery` is set.
|
||||
16. **Frontend:** Test `RecoveryBanner` — renders when `PendingRecovery` is present, disappears after successful rollback, shows error on failed rollback.
|
||||
17. **Frontend:** Test the "Validate" button on the Jails tab — shows green on valid, shows errors on invalid.
|
||||
|
||||
### Acceptance Criteria
|
||||
|
||||
- Obvious config errors (missing filter, bad regex, missing log file) are caught **before** the jail is activated.
|
||||
- If fail2ban crashes after a jail activation, BanGUI detects it within 30 seconds and shows a prominent alert.
|
||||
- The user can disable the problematic jail and restart fail2ban with a single click from the alert banner.
|
||||
- If the automatic rollback succeeds, BanGUI confirms fail2ban is back and shows the number of recovered jails.
|
||||
- If the automatic rollback fails, the user is guided to check the log or intervene manually.
|
||||
- A standalone "Validate" button lets users check a jail's config without activating it.
|
||||
- All new endpoints have tests covering success, failure, and edge cases.
|
||||
|
||||
---
|
||||
1. Fix the global SQLite connection pattern and tests.
|
||||
2. Refactor dependency injection / explicit shared resources.
|
||||
3. Harden fail2ban client concurrency and packaging.
|
||||
4. Convert setup guard to a safer startup-driven model.
|
||||
5. Add deployment-safe configuration and production-ready CORS.
|
||||
6. Add lifecycle and concurrency regression tests.
|
||||
|
||||
@@ -210,7 +210,7 @@ Use Fluent UI React components as the building blocks. The following mapping sho
|
||||
|
||||
| Element | Fluent component | Notes |
|
||||
|---|---|---|
|
||||
| Data tables | `DetailsList` | All ban tables, jail overviews, history tables. Enable column sorting, selection, and shimmer loading. |
|
||||
| Data tables | `DetailsList` | All ban tables, jail overviews, history tables. Enable column sorting, selection, and shimmer loading. Use clear pagination controls (page number + prev/next) and a page-size selector (25/50/100) for large result sets. |
|
||||
| Stat cards | `DocumentCard` or custom `Stack` card | Dashboard status bar — server status, total bans, active jails. Use `Depth 4`. |
|
||||
| Status indicators | `Badge` / `Icon` + colour | Server online/offline, jail running/stopped/idle. |
|
||||
| Country labels | Monospaced text + flag emoji or icon | Geo data next to IP addresses. |
|
||||
|
||||
@@ -1 +1,68 @@
|
||||
"""BanGUI backend application package."""
|
||||
"""BanGUI backend application package.
|
||||
|
||||
This package exposes the application version based on the project metadata.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Final
|
||||
|
||||
import importlib.metadata
|
||||
import tomllib
|
||||
|
||||
PACKAGE_NAME: Final[str] = "bangui-backend"
|
||||
|
||||
|
||||
def _read_pyproject_version() -> str:
|
||||
"""Read the project version from ``pyproject.toml``.
|
||||
|
||||
This is used as a fallback when the package metadata is not available (e.g.
|
||||
when running directly from a source checkout without installing the package).
|
||||
"""
|
||||
|
||||
project_root = Path(__file__).resolve().parents[1]
|
||||
pyproject_path = project_root / "pyproject.toml"
|
||||
if not pyproject_path.exists():
|
||||
raise FileNotFoundError(f"pyproject.toml not found at {pyproject_path}")
|
||||
|
||||
data = tomllib.loads(pyproject_path.read_text(encoding="utf-8"))
|
||||
return str(data["project"]["version"])
|
||||
|
||||
|
||||
def _read_docker_version() -> str:
|
||||
"""Read the project version from ``Docker/VERSION``.
|
||||
|
||||
This file is the single source of truth for release scripts and must not be
|
||||
out of sync with the frontend and backend versions.
|
||||
"""
|
||||
|
||||
repo_root = Path(__file__).resolve().parents[2]
|
||||
version_path = repo_root / "Docker" / "VERSION"
|
||||
if not version_path.exists():
|
||||
raise FileNotFoundError(f"Docker/VERSION not found at {version_path}")
|
||||
|
||||
version = version_path.read_text(encoding="utf-8").strip()
|
||||
return version.lstrip("v")
|
||||
|
||||
|
||||
def _read_version() -> str:
|
||||
"""Return the current package version.
|
||||
|
||||
Prefer the release artifact in ``Docker/VERSION`` when available so the
|
||||
backend version always matches what the release tooling publishes.
|
||||
|
||||
If that file is missing (e.g. in a production wheel or a local checkout),
|
||||
fall back to ``pyproject.toml`` and finally installed package metadata.
|
||||
"""
|
||||
|
||||
try:
|
||||
return _read_docker_version()
|
||||
except FileNotFoundError:
|
||||
try:
|
||||
return _read_pyproject_version()
|
||||
except FileNotFoundError:
|
||||
return importlib.metadata.version(PACKAGE_NAME)
|
||||
|
||||
|
||||
__version__ = _read_version()
|
||||
|
||||
@@ -60,6 +60,15 @@ class Settings(BaseSettings):
|
||||
"Used for listing, viewing, and editing configuration files through the web UI."
|
||||
),
|
||||
)
|
||||
fail2ban_start_command: str = Field(
|
||||
default="fail2ban-client start",
|
||||
description=(
|
||||
"Shell command used to start (not reload) the fail2ban daemon during "
|
||||
"recovery rollback. Split by whitespace to build the argument list — "
|
||||
"no shell interpretation is performed. "
|
||||
"Example: 'systemctl start fail2ban' or 'fail2ban-client start'."
|
||||
),
|
||||
)
|
||||
|
||||
model_config = SettingsConfigDict(
|
||||
env_prefix="BANGUI_",
|
||||
@@ -76,4 +85,4 @@ def get_settings() -> Settings:
|
||||
A validated :class:`Settings` object. Raises :class:`pydantic.ValidationError`
|
||||
if required keys are absent or values fail validation.
|
||||
"""
|
||||
return Settings()
|
||||
return Settings() # type: ignore[call-arg] # pydantic-settings populates required fields from env vars
|
||||
|
||||
@@ -75,6 +75,20 @@ CREATE TABLE IF NOT EXISTS geo_cache (
|
||||
);
|
||||
"""
|
||||
|
||||
_CREATE_HISTORY_ARCHIVE: str = """
|
||||
CREATE TABLE IF NOT EXISTS history_archive (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
jail TEXT NOT NULL,
|
||||
ip TEXT NOT NULL,
|
||||
timeofban INTEGER NOT NULL,
|
||||
bancount INTEGER NOT NULL,
|
||||
data TEXT NOT NULL,
|
||||
action TEXT NOT NULL CHECK(action IN ('ban', 'unban')),
|
||||
created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%fZ', 'now')),
|
||||
UNIQUE(ip, jail, action, timeofban)
|
||||
);
|
||||
"""
|
||||
|
||||
# Ordered list of DDL statements to execute on initialisation.
|
||||
_SCHEMA_STATEMENTS: list[str] = [
|
||||
_CREATE_SETTINGS,
|
||||
@@ -83,6 +97,7 @@ _SCHEMA_STATEMENTS: list[str] = [
|
||||
_CREATE_BLOCKLIST_SOURCES,
|
||||
_CREATE_IMPORT_LOG,
|
||||
_CREATE_GEO_CACHE,
|
||||
_CREATE_HISTORY_ARCHIVE,
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ directly — to keep coupling explicit and testable.
|
||||
"""
|
||||
|
||||
import time
|
||||
from typing import Annotated
|
||||
from typing import Annotated, Protocol, cast
|
||||
|
||||
import aiosqlite
|
||||
import structlog
|
||||
@@ -19,6 +19,13 @@ from app.utils.time_utils import utc_now
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
|
||||
|
||||
class AppState(Protocol):
|
||||
"""Partial view of the FastAPI application state used by dependencies."""
|
||||
|
||||
settings: Settings
|
||||
|
||||
|
||||
_COOKIE_NAME = "bangui_session"
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -85,7 +92,8 @@ async def get_settings(request: Request) -> Settings:
|
||||
Returns:
|
||||
The application settings loaded at startup.
|
||||
"""
|
||||
return request.app.state.settings # type: ignore[no-any-return]
|
||||
state = cast("AppState", request.app.state)
|
||||
return state.settings
|
||||
|
||||
|
||||
async def require_auth(
|
||||
|
||||
53
backend/app/exceptions.py
Normal file
53
backend/app/exceptions.py
Normal file
@@ -0,0 +1,53 @@
|
||||
"""Shared domain exception classes used across routers and services."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
class JailNotFoundError(Exception):
|
||||
"""Raised when a requested jail name does not exist."""
|
||||
|
||||
def __init__(self, name: str) -> None:
|
||||
self.name = name
|
||||
super().__init__(f"Jail not found: {name!r}")
|
||||
|
||||
|
||||
class JailOperationError(Exception):
|
||||
"""Raised when a fail2ban jail operation fails."""
|
||||
|
||||
|
||||
class ConfigValidationError(Exception):
|
||||
"""Raised when config values fail validation before applying."""
|
||||
|
||||
|
||||
class ConfigOperationError(Exception):
|
||||
"""Raised when a config payload update or command fails."""
|
||||
|
||||
|
||||
class ServerOperationError(Exception):
|
||||
"""Raised when a server control command (e.g. refresh) fails."""
|
||||
|
||||
|
||||
class FilterInvalidRegexError(Exception):
|
||||
"""Raised when a regex pattern fails to compile."""
|
||||
|
||||
def __init__(self, pattern: str, error: str) -> None:
|
||||
"""Initialize with the invalid pattern and compile error."""
|
||||
self.pattern = pattern
|
||||
self.error = error
|
||||
super().__init__(f"Invalid regex {pattern!r}: {error}")
|
||||
|
||||
|
||||
class JailNotFoundInConfigError(Exception):
|
||||
"""Raised when the requested jail name is not defined in any config file."""
|
||||
|
||||
def __init__(self, name: str) -> None:
|
||||
self.name = name
|
||||
super().__init__(f"Jail not found in config: {name!r}")
|
||||
|
||||
|
||||
class ConfigWriteError(Exception):
|
||||
"""Raised when writing a configuration file modification fails."""
|
||||
|
||||
def __init__(self, message: str) -> None:
|
||||
self.message = message
|
||||
super().__init__(message)
|
||||
@@ -31,6 +31,7 @@ from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse, RedirectResponse
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
|
||||
from app import __version__
|
||||
from app.config import Settings, get_settings
|
||||
from app.db import init_db
|
||||
from app.routers import (
|
||||
@@ -47,8 +48,9 @@ from app.routers import (
|
||||
server,
|
||||
setup,
|
||||
)
|
||||
from app.tasks import blocklist_import, geo_cache_flush, geo_re_resolve, health_check
|
||||
from app.tasks import blocklist_import, geo_cache_flush, geo_re_resolve, health_check, history_sync
|
||||
from app.utils.fail2ban_client import Fail2BanConnectionError, Fail2BanProtocolError
|
||||
from app.utils.jail_config import ensure_jail_configs
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Ensure the bundled fail2ban package is importable from fail2ban-master/
|
||||
@@ -137,7 +139,13 @@ async def _lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||
|
||||
log.info("bangui_starting_up", database_path=settings.database_path)
|
||||
|
||||
# --- Ensure required jail config files are present ---
|
||||
ensure_jail_configs(Path(settings.fail2ban_config_dir) / "jail.d")
|
||||
|
||||
# --- Application database ---
|
||||
db_path: Path = Path(settings.database_path)
|
||||
db_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
log.debug("database_directory_ensured", directory=str(db_path.parent))
|
||||
db: aiosqlite.Connection = await aiosqlite.connect(settings.database_path)
|
||||
db.row_factory = aiosqlite.Row
|
||||
await init_db(db)
|
||||
@@ -154,11 +162,7 @@ async def _lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||
await geo_service.load_cache_from_db(db)
|
||||
|
||||
# Log unresolved geo entries so the operator can see the scope of the issue.
|
||||
async with db.execute(
|
||||
"SELECT COUNT(*) FROM geo_cache WHERE country_code IS NULL"
|
||||
) as cur:
|
||||
row = await cur.fetchone()
|
||||
unresolved_count: int = int(row[0]) if row else 0
|
||||
unresolved_count = await geo_service.count_unresolved(db)
|
||||
if unresolved_count > 0:
|
||||
log.warning("geo_cache_unresolved_ips", unresolved=unresolved_count)
|
||||
|
||||
@@ -179,6 +183,9 @@ async def _lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||
# --- Periodic re-resolve of NULL-country geo entries ---
|
||||
geo_re_resolve.register(app)
|
||||
|
||||
# --- Periodic history sync from fail2ban into BanGUI archive ---
|
||||
history_sync.register(app)
|
||||
|
||||
log.info("bangui_started")
|
||||
|
||||
try:
|
||||
@@ -320,17 +327,15 @@ class SetupRedirectMiddleware(BaseHTTPMiddleware):
|
||||
if path.startswith("/api") and not getattr(
|
||||
request.app.state, "_setup_complete_cached", False
|
||||
):
|
||||
db: aiosqlite.Connection | None = getattr(request.app.state, "db", None)
|
||||
if db is not None:
|
||||
from app.services import setup_service # noqa: PLC0415
|
||||
from app.services import setup_service # noqa: PLC0415
|
||||
|
||||
if await setup_service.is_setup_complete(db):
|
||||
request.app.state._setup_complete_cached = True
|
||||
else:
|
||||
return RedirectResponse(
|
||||
url="/api/setup",
|
||||
status_code=status.HTTP_307_TEMPORARY_REDIRECT,
|
||||
)
|
||||
db: aiosqlite.Connection | None = getattr(request.app.state, "db", None)
|
||||
if db is None or not await setup_service.is_setup_complete(db):
|
||||
return RedirectResponse(
|
||||
url="/api/setup",
|
||||
status_code=status.HTTP_307_TEMPORARY_REDIRECT,
|
||||
)
|
||||
request.app.state._setup_complete_cached = True
|
||||
|
||||
return await call_next(request)
|
||||
|
||||
@@ -360,7 +365,7 @@ def create_app(settings: Settings | None = None) -> FastAPI:
|
||||
app: FastAPI = FastAPI(
|
||||
title="BanGUI",
|
||||
description="Web interface for monitoring, managing, and configuring fail2ban.",
|
||||
version="0.1.0",
|
||||
version=__version__,
|
||||
lifespan=_lifespan,
|
||||
)
|
||||
|
||||
|
||||
@@ -306,3 +306,30 @@ class BansByJailResponse(BaseModel):
|
||||
description="Jails ordered by ban count descending.",
|
||||
)
|
||||
total: int = Field(..., ge=0, description="Total ban count in the selected window.")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Jail-specific paginated bans
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class JailBannedIpsResponse(BaseModel):
|
||||
"""Paginated response for ``GET /api/jails/{name}/banned``.
|
||||
|
||||
Contains only the current page of active ban entries for a single jail,
|
||||
geo-enriched exclusively for the page slice to avoid rate-limit issues.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(strict=True)
|
||||
|
||||
items: list[ActiveBan] = Field(
|
||||
default_factory=list,
|
||||
description="Active ban entries for the current page.",
|
||||
)
|
||||
total: int = Field(
|
||||
...,
|
||||
ge=0,
|
||||
description="Total matching entries (after applying the search filter).",
|
||||
)
|
||||
page: int = Field(..., ge=1, description="Current page number (1-based).")
|
||||
page_size: int = Field(..., ge=1, description="Number of items per page.")
|
||||
|
||||
@@ -3,6 +3,8 @@
|
||||
Request, response, and domain models for the config router and service.
|
||||
"""
|
||||
|
||||
import datetime
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -805,6 +807,14 @@ class InactiveJail(BaseModel):
|
||||
"inactive jails that appear in this list."
|
||||
),
|
||||
)
|
||||
has_local_override: bool = Field(
|
||||
default=False,
|
||||
description=(
|
||||
"``True`` when a ``jail.d/{name}.local`` file exists for this jail. "
|
||||
"Only meaningful for inactive jails; indicates that a cleanup action "
|
||||
"is available."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class InactiveJailListResponse(BaseModel):
|
||||
@@ -860,3 +870,140 @@ class JailActivationResponse(BaseModel):
|
||||
description="New activation state: ``True`` after activate, ``False`` after deactivate.",
|
||||
)
|
||||
message: str = Field(..., description="Human-readable result message.")
|
||||
fail2ban_running: bool = Field(
|
||||
default=True,
|
||||
description=(
|
||||
"Whether the fail2ban daemon is still running after the activation "
|
||||
"and reload. ``False`` signals that the daemon may have crashed."
|
||||
),
|
||||
)
|
||||
validation_warnings: list[str] = Field(
|
||||
default_factory=list,
|
||||
description="Non-fatal warnings from the pre-activation validation step.",
|
||||
)
|
||||
recovered: bool | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Set when activation failed after writing the config file. "
|
||||
"``True`` means the system automatically rolled back the change and "
|
||||
"restarted fail2ban. ``False`` means the rollback itself also "
|
||||
"failed and manual intervention is required. ``None`` when "
|
||||
"activation succeeded or failed before the file was written."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Jail validation models (Task 3)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class JailValidationIssue(BaseModel):
|
||||
"""A single issue found during pre-activation validation of a jail config."""
|
||||
|
||||
model_config = ConfigDict(strict=True)
|
||||
|
||||
field: str = Field(
|
||||
...,
|
||||
description="Config field associated with this issue, e.g. 'filter', 'failregex', 'logpath'.",
|
||||
)
|
||||
message: str = Field(..., description="Human-readable description of the issue.")
|
||||
|
||||
|
||||
class JailValidationResult(BaseModel):
|
||||
"""Result of pre-activation validation of a single jail configuration."""
|
||||
|
||||
model_config = ConfigDict(strict=True)
|
||||
|
||||
jail_name: str = Field(..., description="Name of the validated jail.")
|
||||
valid: bool = Field(..., description="True when no issues were found.")
|
||||
issues: list[JailValidationIssue] = Field(
|
||||
default_factory=list,
|
||||
description="Validation issues found. Empty when valid=True.",
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Rollback response model (Task 3)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class RollbackResponse(BaseModel):
|
||||
"""Response for ``POST /api/config/jails/{name}/rollback``."""
|
||||
|
||||
model_config = ConfigDict(strict=True)
|
||||
|
||||
jail_name: str = Field(..., description="Name of the jail that was disabled.")
|
||||
disabled: bool = Field(
|
||||
...,
|
||||
description="Whether the jail's .local override was successfully written with enabled=false.",
|
||||
)
|
||||
fail2ban_running: bool = Field(
|
||||
...,
|
||||
description="Whether fail2ban is online after the rollback attempt.",
|
||||
)
|
||||
active_jails: int = Field(
|
||||
default=0,
|
||||
ge=0,
|
||||
description="Number of currently active jails after a successful restart.",
|
||||
)
|
||||
message: str = Field(..., description="Human-readable result message.")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Pending recovery model (Task 3)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class PendingRecovery(BaseModel):
|
||||
"""Records a probable activation-caused fail2ban crash pending user action."""
|
||||
|
||||
model_config = ConfigDict(strict=True)
|
||||
|
||||
jail_name: str = Field(
|
||||
...,
|
||||
description="Name of the jail whose activation likely caused the crash.",
|
||||
)
|
||||
activated_at: datetime.datetime = Field(
|
||||
...,
|
||||
description="ISO-8601 UTC timestamp of when the jail was activated.",
|
||||
)
|
||||
detected_at: datetime.datetime = Field(
|
||||
...,
|
||||
description="ISO-8601 UTC timestamp of when the crash was detected.",
|
||||
)
|
||||
recovered: bool = Field(
|
||||
default=False,
|
||||
description="Whether fail2ban has been successfully restarted.",
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# fail2ban log viewer models
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class Fail2BanLogResponse(BaseModel):
|
||||
"""Response for ``GET /api/config/fail2ban-log``."""
|
||||
|
||||
model_config = ConfigDict(strict=True)
|
||||
|
||||
log_path: str = Field(..., description="Resolved absolute path of the log file being read.")
|
||||
lines: list[str] = Field(default_factory=list, description="Log lines returned (tail, optionally filtered).")
|
||||
total_lines: int = Field(..., ge=0, description="Total number of lines in the file before filtering.")
|
||||
log_level: str = Field(..., description="Current fail2ban log level.")
|
||||
log_target: str = Field(..., description="Current fail2ban log target (file path or special value).")
|
||||
|
||||
|
||||
class ServiceStatusResponse(BaseModel):
|
||||
"""Response for ``GET /api/config/service-status``."""
|
||||
|
||||
model_config = ConfigDict(strict=True)
|
||||
|
||||
online: bool = Field(..., description="Whether fail2ban is reachable via its socket.")
|
||||
version: str | None = Field(default=None, description="BanGUI application version (or None when offline).")
|
||||
jail_count: int = Field(default=0, ge=0, description="Number of currently active jails.")
|
||||
total_bans: int = Field(default=0, ge=0, description="Aggregated current ban count across all jails.")
|
||||
total_failures: int = Field(default=0, ge=0, description="Aggregated current failure count across all jails.")
|
||||
log_level: str = Field(default="UNKNOWN", description="Current fail2ban log level.")
|
||||
log_target: str = Field(default="UNKNOWN", description="Current fail2ban log target.")
|
||||
|
||||
@@ -3,8 +3,18 @@
|
||||
Response models for the ``GET /api/geo/lookup/{ip}`` endpoint.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Awaitable, Callable
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import aiohttp
|
||||
import aiosqlite
|
||||
|
||||
|
||||
class GeoDetail(BaseModel):
|
||||
"""Enriched geolocation data for an IP address.
|
||||
@@ -64,3 +74,26 @@ class IpLookupResponse(BaseModel):
|
||||
default=None,
|
||||
description="Enriched geographical and network information.",
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# shared service types
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
|
||||
class GeoInfo:
|
||||
"""Geo resolution result used throughout backend services."""
|
||||
|
||||
country_code: str | None
|
||||
country_name: str | None
|
||||
asn: str | None
|
||||
org: str | None
|
||||
|
||||
|
||||
GeoEnricher = Callable[[str], Awaitable[GeoInfo | None]]
|
||||
GeoBatchLookup = Callable[
|
||||
[list[str], "aiohttp.ClientSession", "aiosqlite.Connection | None"],
|
||||
Awaitable[dict[str, GeoInfo]],
|
||||
]
|
||||
GeoCacheLookup = Callable[[list[str]], tuple[dict[str, GeoInfo], list[str]]]
|
||||
|
||||
@@ -56,3 +56,7 @@ class ServerSettingsResponse(BaseModel):
|
||||
model_config = ConfigDict(strict=True)
|
||||
|
||||
settings: ServerSettings
|
||||
warnings: dict[str, bool] = Field(
|
||||
default_factory=dict,
|
||||
description="Warnings highlighting potentially unsafe settings.",
|
||||
)
|
||||
|
||||
365
backend/app/repositories/fail2ban_db_repo.py
Normal file
365
backend/app/repositories/fail2ban_db_repo.py
Normal file
@@ -0,0 +1,365 @@
|
||||
"""Fail2Ban SQLite database repository.
|
||||
|
||||
This module contains helper functions that query the read-only fail2ban
|
||||
SQLite database file. All functions accept a *db_path* and manage their own
|
||||
connection using aiosqlite in read-only mode.
|
||||
|
||||
The functions intentionally return plain Python data structures (dataclasses) so
|
||||
service layers can focus on business logic and formatting.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import aiosqlite
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Iterable
|
||||
|
||||
from app.models.ban import BanOrigin
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class BanRecord:
|
||||
"""A single row from the fail2ban ``bans`` table."""
|
||||
|
||||
jail: str
|
||||
ip: str
|
||||
timeofban: int
|
||||
bancount: int
|
||||
data: str
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class BanIpCount:
|
||||
"""Aggregated ban count for a single IP."""
|
||||
|
||||
ip: str
|
||||
event_count: int
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class JailBanCount:
|
||||
"""Aggregated ban count for a single jail."""
|
||||
|
||||
jail: str
|
||||
count: int
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class HistoryRecord:
|
||||
"""A single row from the fail2ban ``bans`` table for history queries."""
|
||||
|
||||
jail: str
|
||||
ip: str
|
||||
timeofban: int
|
||||
bancount: int
|
||||
data: str
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Internal helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_db_uri(db_path: str) -> str:
|
||||
"""Return a read-only sqlite URI for the given file path."""
|
||||
|
||||
return f"file:{db_path}?mode=ro"
|
||||
|
||||
|
||||
def _origin_sql_filter(origin: BanOrigin | None) -> tuple[str, tuple[str, ...]]:
|
||||
"""Return a SQL fragment and parameters for the origin filter."""
|
||||
|
||||
if origin == "blocklist":
|
||||
return " AND jail = ?", ("blocklist-import",)
|
||||
if origin == "selfblock":
|
||||
return " AND jail != ?", ("blocklist-import",)
|
||||
return "", ()
|
||||
|
||||
|
||||
def _rows_to_ban_records(rows: Iterable[aiosqlite.Row]) -> list[BanRecord]:
|
||||
return [
|
||||
BanRecord(
|
||||
jail=str(r["jail"]),
|
||||
ip=str(r["ip"]),
|
||||
timeofban=int(r["timeofban"]),
|
||||
bancount=int(r["bancount"]),
|
||||
data=str(r["data"]),
|
||||
)
|
||||
for r in rows
|
||||
]
|
||||
|
||||
|
||||
def _rows_to_history_records(rows: Iterable[aiosqlite.Row]) -> list[HistoryRecord]:
|
||||
return [
|
||||
HistoryRecord(
|
||||
jail=str(r["jail"]),
|
||||
ip=str(r["ip"]),
|
||||
timeofban=int(r["timeofban"]),
|
||||
bancount=int(r["bancount"]),
|
||||
data=str(r["data"]),
|
||||
)
|
||||
for r in rows
|
||||
]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public API
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def check_db_nonempty(db_path: str) -> bool:
|
||||
"""Return True if the fail2ban database contains at least one ban row."""
|
||||
|
||||
async with aiosqlite.connect(_make_db_uri(db_path), uri=True) as db, db.execute(
|
||||
"SELECT 1 FROM bans LIMIT 1"
|
||||
) as cur:
|
||||
row = await cur.fetchone()
|
||||
return row is not None
|
||||
|
||||
|
||||
async def get_currently_banned(
|
||||
db_path: str,
|
||||
since: int,
|
||||
origin: BanOrigin | None = None,
|
||||
*,
|
||||
limit: int | None = None,
|
||||
offset: int | None = None,
|
||||
) -> tuple[list[BanRecord], int]:
|
||||
"""Return a page of currently banned IPs and the total matching count.
|
||||
|
||||
Args:
|
||||
db_path: File path to the fail2ban SQLite database.
|
||||
since: Unix timestamp to filter bans newer than or equal to.
|
||||
origin: Optional origin filter.
|
||||
limit: Optional maximum number of rows to return.
|
||||
offset: Optional offset for pagination.
|
||||
|
||||
Returns:
|
||||
A ``(records, total)`` tuple.
|
||||
"""
|
||||
|
||||
origin_clause, origin_params = _origin_sql_filter(origin)
|
||||
|
||||
async with aiosqlite.connect(_make_db_uri(db_path), uri=True) as db:
|
||||
db.row_factory = aiosqlite.Row
|
||||
|
||||
async with db.execute(
|
||||
"SELECT COUNT(*) FROM bans WHERE timeofban >= ?" + origin_clause,
|
||||
(since, *origin_params),
|
||||
) as cur:
|
||||
count_row = await cur.fetchone()
|
||||
total: int = int(count_row[0]) if count_row else 0
|
||||
|
||||
query = (
|
||||
"SELECT jail, ip, timeofban, bancount, data "
|
||||
"FROM bans "
|
||||
"WHERE timeofban >= ?" + origin_clause + " ORDER BY timeofban DESC"
|
||||
)
|
||||
params: list[object] = [since, *origin_params]
|
||||
if limit is not None:
|
||||
query += " LIMIT ?"
|
||||
params.append(limit)
|
||||
if offset is not None:
|
||||
query += " OFFSET ?"
|
||||
params.append(offset)
|
||||
|
||||
async with db.execute(query, params) as cur:
|
||||
rows = await cur.fetchall()
|
||||
|
||||
return _rows_to_ban_records(rows), total
|
||||
|
||||
|
||||
async def get_ban_counts_by_bucket(
|
||||
db_path: str,
|
||||
since: int,
|
||||
bucket_secs: int,
|
||||
num_buckets: int,
|
||||
origin: BanOrigin | None = None,
|
||||
) -> list[int]:
|
||||
"""Return ban counts aggregated into equal-width time buckets."""
|
||||
|
||||
origin_clause, origin_params = _origin_sql_filter(origin)
|
||||
|
||||
async with aiosqlite.connect(_make_db_uri(db_path), uri=True) as db:
|
||||
db.row_factory = aiosqlite.Row
|
||||
async with db.execute(
|
||||
"SELECT CAST((timeofban - ?) / ? AS INTEGER) AS bucket_idx, "
|
||||
"COUNT(*) AS cnt "
|
||||
"FROM bans "
|
||||
"WHERE timeofban >= ?" + origin_clause + " GROUP BY bucket_idx "
|
||||
"ORDER BY bucket_idx",
|
||||
(since, bucket_secs, since, *origin_params),
|
||||
) as cur:
|
||||
rows = await cur.fetchall()
|
||||
|
||||
counts: list[int] = [0] * num_buckets
|
||||
for row in rows:
|
||||
idx: int = int(row["bucket_idx"])
|
||||
if 0 <= idx < num_buckets:
|
||||
counts[idx] = int(row["cnt"])
|
||||
|
||||
return counts
|
||||
|
||||
|
||||
async def get_ban_event_counts(
|
||||
db_path: str,
|
||||
since: int,
|
||||
origin: BanOrigin | None = None,
|
||||
) -> list[BanIpCount]:
|
||||
"""Return total ban events per unique IP in the window."""
|
||||
|
||||
origin_clause, origin_params = _origin_sql_filter(origin)
|
||||
|
||||
async with aiosqlite.connect(_make_db_uri(db_path), uri=True) as db:
|
||||
db.row_factory = aiosqlite.Row
|
||||
async with db.execute(
|
||||
"SELECT ip, COUNT(*) AS event_count "
|
||||
"FROM bans "
|
||||
"WHERE timeofban >= ?" + origin_clause + " GROUP BY ip",
|
||||
(since, *origin_params),
|
||||
) as cur:
|
||||
rows = await cur.fetchall()
|
||||
|
||||
return [
|
||||
BanIpCount(ip=str(r["ip"]), event_count=int(r["event_count"]))
|
||||
for r in rows
|
||||
]
|
||||
|
||||
|
||||
async def get_bans_by_jail(
|
||||
db_path: str,
|
||||
since: int,
|
||||
origin: BanOrigin | None = None,
|
||||
) -> tuple[int, list[JailBanCount]]:
|
||||
"""Return per-jail ban counts and the total ban count."""
|
||||
|
||||
origin_clause, origin_params = _origin_sql_filter(origin)
|
||||
|
||||
async with aiosqlite.connect(_make_db_uri(db_path), uri=True) as db:
|
||||
db.row_factory = aiosqlite.Row
|
||||
|
||||
async with db.execute(
|
||||
"SELECT COUNT(*) FROM bans WHERE timeofban >= ?" + origin_clause,
|
||||
(since, *origin_params),
|
||||
) as cur:
|
||||
count_row = await cur.fetchone()
|
||||
total: int = int(count_row[0]) if count_row else 0
|
||||
|
||||
async with db.execute(
|
||||
"SELECT jail, COUNT(*) AS cnt "
|
||||
"FROM bans "
|
||||
"WHERE timeofban >= ?" + origin_clause + " GROUP BY jail ORDER BY cnt DESC",
|
||||
(since, *origin_params),
|
||||
) as cur:
|
||||
rows = await cur.fetchall()
|
||||
|
||||
return total, [
|
||||
JailBanCount(jail=str(r["jail"]), count=int(r["cnt"])) for r in rows
|
||||
]
|
||||
|
||||
|
||||
async def get_bans_table_summary(
|
||||
db_path: str,
|
||||
) -> tuple[int, int | None, int | None]:
|
||||
"""Return basic summary stats for the ``bans`` table.
|
||||
|
||||
Returns:
|
||||
A tuple ``(row_count, min_timeofban, max_timeofban)``. If the table is
|
||||
empty the min/max values will be ``None``.
|
||||
"""
|
||||
|
||||
async with aiosqlite.connect(_make_db_uri(db_path), uri=True) as db:
|
||||
db.row_factory = aiosqlite.Row
|
||||
async with db.execute(
|
||||
"SELECT COUNT(*), MIN(timeofban), MAX(timeofban) FROM bans"
|
||||
) as cur:
|
||||
row = await cur.fetchone()
|
||||
|
||||
if row is None:
|
||||
return 0, None, None
|
||||
|
||||
return (
|
||||
int(row[0]),
|
||||
int(row[1]) if row[1] is not None else None,
|
||||
int(row[2]) if row[2] is not None else None,
|
||||
)
|
||||
|
||||
|
||||
async def get_history_page(
|
||||
db_path: str,
|
||||
since: int | None = None,
|
||||
jail: str | None = None,
|
||||
ip_filter: str | None = None,
|
||||
origin: BanOrigin | None = None,
|
||||
page: int = 1,
|
||||
page_size: int = 100,
|
||||
) -> tuple[list[HistoryRecord], int]:
|
||||
"""Return a paginated list of history records with total count."""
|
||||
|
||||
wheres: list[str] = []
|
||||
params: list[object] = []
|
||||
|
||||
if since is not None:
|
||||
wheres.append("timeofban >= ?")
|
||||
params.append(since)
|
||||
|
||||
if jail is not None:
|
||||
wheres.append("jail = ?")
|
||||
params.append(jail)
|
||||
|
||||
if ip_filter is not None:
|
||||
wheres.append("ip LIKE ?")
|
||||
params.append(f"{ip_filter}%")
|
||||
|
||||
origin_clause, origin_params = _origin_sql_filter(origin)
|
||||
if origin_clause:
|
||||
origin_clause_clean = origin_clause.removeprefix(" AND ")
|
||||
wheres.append(origin_clause_clean)
|
||||
params.extend(origin_params)
|
||||
|
||||
where_sql: str = ("WHERE " + " AND ".join(wheres)) if wheres else ""
|
||||
|
||||
effective_page_size: int = page_size
|
||||
offset: int = (page - 1) * effective_page_size
|
||||
|
||||
async with aiosqlite.connect(_make_db_uri(db_path), uri=True) as db:
|
||||
db.row_factory = aiosqlite.Row
|
||||
|
||||
async with db.execute(
|
||||
f"SELECT COUNT(*) FROM bans {where_sql}", # noqa: S608
|
||||
params,
|
||||
) as cur:
|
||||
count_row = await cur.fetchone()
|
||||
total: int = int(count_row[0]) if count_row else 0
|
||||
|
||||
async with db.execute(
|
||||
f"SELECT jail, ip, timeofban, bancount, data "
|
||||
f"FROM bans {where_sql} "
|
||||
"ORDER BY timeofban DESC "
|
||||
"LIMIT ? OFFSET ?",
|
||||
[*params, effective_page_size, offset],
|
||||
) as cur:
|
||||
rows = await cur.fetchall()
|
||||
|
||||
return _rows_to_history_records(rows), total
|
||||
|
||||
|
||||
async def get_history_for_ip(db_path: str, ip: str) -> list[HistoryRecord]:
|
||||
"""Return the full ban timeline for a specific IP."""
|
||||
|
||||
async with aiosqlite.connect(_make_db_uri(db_path), uri=True) as db:
|
||||
db.row_factory = aiosqlite.Row
|
||||
async with db.execute(
|
||||
"SELECT jail, ip, timeofban, bancount, data "
|
||||
"FROM bans "
|
||||
"WHERE ip = ? "
|
||||
"ORDER BY timeofban DESC",
|
||||
(ip,),
|
||||
) as cur:
|
||||
rows = await cur.fetchall()
|
||||
|
||||
return _rows_to_history_records(rows)
|
||||
148
backend/app/repositories/geo_cache_repo.py
Normal file
148
backend/app/repositories/geo_cache_repo.py
Normal file
@@ -0,0 +1,148 @@
|
||||
"""Repository for the geo cache persistent store.
|
||||
|
||||
This module provides typed, async helpers for querying and mutating the
|
||||
``geo_cache`` table in the BanGUI application database.
|
||||
|
||||
All functions accept an open :class:`aiosqlite.Connection` and do not manage
|
||||
connection lifetimes.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, TypedDict
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
|
||||
import aiosqlite
|
||||
|
||||
|
||||
class GeoCacheRow(TypedDict):
|
||||
"""A single row from the ``geo_cache`` table."""
|
||||
|
||||
ip: str
|
||||
country_code: str | None
|
||||
country_name: str | None
|
||||
asn: str | None
|
||||
org: str | None
|
||||
|
||||
|
||||
async def load_all(db: aiosqlite.Connection) -> list[GeoCacheRow]:
|
||||
"""Load all geo cache rows from the database.
|
||||
|
||||
Args:
|
||||
db: Open BanGUI application database connection.
|
||||
|
||||
Returns:
|
||||
List of rows from the ``geo_cache`` table.
|
||||
"""
|
||||
rows: list[GeoCacheRow] = []
|
||||
async with db.execute(
|
||||
"SELECT ip, country_code, country_name, asn, org FROM geo_cache"
|
||||
) as cur:
|
||||
async for row in cur:
|
||||
rows.append(
|
||||
GeoCacheRow(
|
||||
ip=str(row[0]),
|
||||
country_code=row[1],
|
||||
country_name=row[2],
|
||||
asn=row[3],
|
||||
org=row[4],
|
||||
)
|
||||
)
|
||||
return rows
|
||||
|
||||
|
||||
async def get_unresolved_ips(db: aiosqlite.Connection) -> list[str]:
|
||||
"""Return all IPs in ``geo_cache`` where ``country_code`` is NULL.
|
||||
|
||||
Args:
|
||||
db: Open BanGUI application database connection.
|
||||
|
||||
Returns:
|
||||
List of IPv4/IPv6 strings that need geo resolution.
|
||||
"""
|
||||
ips: list[str] = []
|
||||
async with db.execute(
|
||||
"SELECT ip FROM geo_cache WHERE country_code IS NULL"
|
||||
) as cur:
|
||||
async for row in cur:
|
||||
ips.append(str(row[0]))
|
||||
return ips
|
||||
|
||||
|
||||
async def count_unresolved(db: aiosqlite.Connection) -> int:
|
||||
"""Return the number of unresolved rows (country_code IS NULL)."""
|
||||
async with db.execute(
|
||||
"SELECT COUNT(*) FROM geo_cache WHERE country_code IS NULL"
|
||||
) as cur:
|
||||
row = await cur.fetchone()
|
||||
return int(row[0]) if row else 0
|
||||
|
||||
|
||||
async def upsert_entry(
|
||||
db: aiosqlite.Connection,
|
||||
ip: str,
|
||||
country_code: str | None,
|
||||
country_name: str | None,
|
||||
asn: str | None,
|
||||
org: str | None,
|
||||
) -> None:
|
||||
"""Insert or update a resolved geo cache entry."""
|
||||
await db.execute(
|
||||
"""
|
||||
INSERT INTO geo_cache (ip, country_code, country_name, asn, org)
|
||||
VALUES (?, ?, ?, ?, ?)
|
||||
ON CONFLICT(ip) DO UPDATE SET
|
||||
country_code = excluded.country_code,
|
||||
country_name = excluded.country_name,
|
||||
asn = excluded.asn,
|
||||
org = excluded.org,
|
||||
cached_at = strftime('%Y-%m-%dT%H:%M:%fZ', 'now')
|
||||
""",
|
||||
(ip, country_code, country_name, asn, org),
|
||||
)
|
||||
|
||||
|
||||
async def upsert_neg_entry(db: aiosqlite.Connection, ip: str) -> None:
|
||||
"""Record a failed lookup attempt as a negative entry."""
|
||||
await db.execute(
|
||||
"INSERT OR IGNORE INTO geo_cache (ip) VALUES (?)",
|
||||
(ip,),
|
||||
)
|
||||
|
||||
|
||||
async def bulk_upsert_entries(
|
||||
db: aiosqlite.Connection,
|
||||
rows: Sequence[tuple[str, str | None, str | None, str | None, str | None]],
|
||||
) -> int:
|
||||
"""Bulk insert or update multiple geo cache entries."""
|
||||
if not rows:
|
||||
return 0
|
||||
|
||||
await db.executemany(
|
||||
"""
|
||||
INSERT INTO geo_cache (ip, country_code, country_name, asn, org)
|
||||
VALUES (?, ?, ?, ?, ?)
|
||||
ON CONFLICT(ip) DO UPDATE SET
|
||||
country_code = excluded.country_code,
|
||||
country_name = excluded.country_name,
|
||||
asn = excluded.asn,
|
||||
org = excluded.org,
|
||||
cached_at = strftime('%Y-%m-%dT%H:%M:%fZ', 'now')
|
||||
""",
|
||||
rows,
|
||||
)
|
||||
return len(rows)
|
||||
|
||||
|
||||
async def bulk_upsert_neg_entries(db: aiosqlite.Connection, ips: list[str]) -> int:
|
||||
"""Bulk insert negative lookup entries."""
|
||||
if not ips:
|
||||
return 0
|
||||
|
||||
await db.executemany(
|
||||
"INSERT OR IGNORE INTO geo_cache (ip) VALUES (?)",
|
||||
[(ip,) for ip in ips],
|
||||
)
|
||||
return len(ips)
|
||||
148
backend/app/repositories/history_archive_repo.py
Normal file
148
backend/app/repositories/history_archive_repo.py
Normal file
@@ -0,0 +1,148 @@
|
||||
"""Ban history archive repository.
|
||||
|
||||
Provides persistence APIs for the BanGUI archival history table in the
|
||||
application database.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from app.models.ban import BLOCKLIST_JAIL, BanOrigin
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import aiosqlite
|
||||
|
||||
|
||||
async def archive_ban_event(
|
||||
db: aiosqlite.Connection,
|
||||
jail: str,
|
||||
ip: str,
|
||||
timeofban: int,
|
||||
bancount: int,
|
||||
data: str,
|
||||
action: str = "ban",
|
||||
) -> bool:
|
||||
"""Insert a new archived ban/unban event, ignoring duplicates."""
|
||||
async with db.execute(
|
||||
"""INSERT OR IGNORE INTO history_archive
|
||||
(jail, ip, timeofban, bancount, data, action)
|
||||
VALUES (?, ?, ?, ?, ?, ?)""",
|
||||
(jail, ip, timeofban, bancount, data, action),
|
||||
) as cursor:
|
||||
inserted = cursor.rowcount == 1
|
||||
await db.commit()
|
||||
return inserted
|
||||
|
||||
|
||||
async def get_archived_history(
|
||||
db: aiosqlite.Connection,
|
||||
since: int | None = None,
|
||||
jail: str | None = None,
|
||||
ip_filter: str | None = None,
|
||||
origin: BanOrigin | None = None,
|
||||
action: str | None = None,
|
||||
page: int = 1,
|
||||
page_size: int = 100,
|
||||
) -> tuple[list[dict], int]:
|
||||
"""Return a paginated archived history result set."""
|
||||
wheres: list[str] = []
|
||||
params: list[object] = []
|
||||
|
||||
if since is not None:
|
||||
wheres.append("timeofban >= ?")
|
||||
params.append(since)
|
||||
|
||||
if jail is not None:
|
||||
wheres.append("jail = ?")
|
||||
params.append(jail)
|
||||
|
||||
if ip_filter is not None:
|
||||
wheres.append("ip LIKE ?")
|
||||
params.append(f"{ip_filter}%")
|
||||
|
||||
if origin == "blocklist":
|
||||
wheres.append("jail = ?")
|
||||
params.append(BLOCKLIST_JAIL)
|
||||
elif origin == "selfblock":
|
||||
wheres.append("jail != ?")
|
||||
params.append(BLOCKLIST_JAIL)
|
||||
|
||||
if action is not None:
|
||||
wheres.append("action = ?")
|
||||
params.append(action)
|
||||
|
||||
where_sql = "WHERE " + " AND ".join(wheres) if wheres else ""
|
||||
offset = (page - 1) * page_size
|
||||
|
||||
async with db.execute(f"SELECT COUNT(*) FROM history_archive {where_sql}", params) as cur:
|
||||
row = await cur.fetchone()
|
||||
total = int(row[0]) if row is not None and row[0] is not None else 0
|
||||
|
||||
async with db.execute(
|
||||
"SELECT jail, ip, timeofban, bancount, data, action "
|
||||
"FROM history_archive "
|
||||
f"{where_sql} "
|
||||
"ORDER BY timeofban DESC LIMIT ? OFFSET ?",
|
||||
[*params, page_size, offset],
|
||||
) as cur:
|
||||
rows = await cur.fetchall()
|
||||
|
||||
records = [
|
||||
{
|
||||
"jail": str(r[0]),
|
||||
"ip": str(r[1]),
|
||||
"timeofban": int(r[2]),
|
||||
"bancount": int(r[3]),
|
||||
"data": str(r[4]),
|
||||
"action": str(r[5]),
|
||||
}
|
||||
for r in rows
|
||||
]
|
||||
|
||||
return records, total
|
||||
|
||||
|
||||
async def get_all_archived_history(
|
||||
db: aiosqlite.Connection,
|
||||
since: int | None = None,
|
||||
jail: str | None = None,
|
||||
ip_filter: str | None = None,
|
||||
origin: BanOrigin | None = None,
|
||||
action: str | None = None,
|
||||
) -> list[dict]:
|
||||
"""Return all archived history rows for the given filters."""
|
||||
page: int = 1
|
||||
page_size: int = 500
|
||||
all_rows: list[dict] = []
|
||||
|
||||
while True:
|
||||
rows, total = await get_archived_history(
|
||||
db=db,
|
||||
since=since,
|
||||
jail=jail,
|
||||
ip_filter=ip_filter,
|
||||
origin=origin,
|
||||
action=action,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
all_rows.extend(rows)
|
||||
if len(rows) < page_size:
|
||||
break
|
||||
page += 1
|
||||
|
||||
return all_rows
|
||||
|
||||
|
||||
async def purge_archived_history(db: aiosqlite.Connection, age_seconds: int) -> int:
|
||||
"""Purge archived entries older than *age_seconds*; return rows deleted."""
|
||||
threshold = int(datetime.datetime.now(datetime.UTC).timestamp()) - age_seconds
|
||||
async with db.execute(
|
||||
"DELETE FROM history_archive WHERE timeofban < ?",
|
||||
(threshold,),
|
||||
) as cursor:
|
||||
deleted = cursor.rowcount
|
||||
await db.commit()
|
||||
return deleted
|
||||
@@ -8,12 +8,26 @@ table. All methods are plain async functions that accept a
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from typing import TYPE_CHECKING, TypedDict, cast
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Mapping
|
||||
|
||||
import aiosqlite
|
||||
|
||||
|
||||
class ImportLogRow(TypedDict):
|
||||
"""Row shape returned by queries on the import_log table."""
|
||||
|
||||
id: int
|
||||
source_id: int | None
|
||||
source_url: str
|
||||
timestamp: str
|
||||
ips_imported: int
|
||||
ips_skipped: int
|
||||
errors: str | None
|
||||
|
||||
|
||||
async def add_log(
|
||||
db: aiosqlite.Connection,
|
||||
*,
|
||||
@@ -54,7 +68,7 @@ async def list_logs(
|
||||
source_id: int | None = None,
|
||||
page: int = 1,
|
||||
page_size: int = 50,
|
||||
) -> tuple[list[dict[str, Any]], int]:
|
||||
) -> tuple[list[ImportLogRow], int]:
|
||||
"""Return a paginated list of import log entries.
|
||||
|
||||
Args:
|
||||
@@ -68,8 +82,8 @@ async def list_logs(
|
||||
*total* is the count of all matching rows (ignoring pagination).
|
||||
"""
|
||||
where = ""
|
||||
params_count: list[Any] = []
|
||||
params_rows: list[Any] = []
|
||||
params_count: list[object] = []
|
||||
params_rows: list[object] = []
|
||||
|
||||
if source_id is not None:
|
||||
where = " WHERE source_id = ?"
|
||||
@@ -102,7 +116,7 @@ async def list_logs(
|
||||
return items, total
|
||||
|
||||
|
||||
async def get_last_log(db: aiosqlite.Connection) -> dict[str, Any] | None:
|
||||
async def get_last_log(db: aiosqlite.Connection) -> ImportLogRow | None:
|
||||
"""Return the most recent import log entry across all sources.
|
||||
|
||||
Args:
|
||||
@@ -143,13 +157,14 @@ def compute_total_pages(total: int, page_size: int) -> int:
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _row_to_dict(row: Any) -> dict[str, Any]:
|
||||
def _row_to_dict(row: object) -> ImportLogRow:
|
||||
"""Convert an aiosqlite row to a plain Python dict.
|
||||
|
||||
Args:
|
||||
row: An :class:`aiosqlite.Row` or sequence returned by a cursor.
|
||||
row: An :class:`aiosqlite.Row` or similar mapping returned by a cursor.
|
||||
|
||||
Returns:
|
||||
Dict mapping column names to Python values.
|
||||
"""
|
||||
return dict(row)
|
||||
mapping = cast("Mapping[str, object]", row)
|
||||
return cast("ImportLogRow", dict(mapping))
|
||||
|
||||
@@ -20,8 +20,8 @@ from fastapi import APIRouter, HTTPException, Request, status
|
||||
from app.dependencies import AuthDep
|
||||
from app.models.ban import ActiveBanListResponse, BanRequest, UnbanAllResponse, UnbanRequest
|
||||
from app.models.jail import JailCommandResponse
|
||||
from app.services import jail_service
|
||||
from app.services.jail_service import JailNotFoundError, JailOperationError
|
||||
from app.services import geo_service, jail_service
|
||||
from app.exceptions import JailNotFoundError, JailOperationError
|
||||
from app.utils.fail2ban_client import Fail2BanConnectionError
|
||||
|
||||
router: APIRouter = APIRouter(prefix="/api/bans", tags=["Bans"])
|
||||
@@ -73,6 +73,7 @@ async def get_active_bans(
|
||||
try:
|
||||
return await jail_service.get_active_bans(
|
||||
socket_path,
|
||||
geo_batch_lookup=geo_service.lookup_batch,
|
||||
http_session=http_session,
|
||||
app_db=app_db,
|
||||
)
|
||||
|
||||
@@ -42,8 +42,7 @@ from app.models.blocklist import (
|
||||
ScheduleConfig,
|
||||
ScheduleInfo,
|
||||
)
|
||||
from app.repositories import import_log_repo
|
||||
from app.services import blocklist_service
|
||||
from app.services import blocklist_service, geo_service
|
||||
from app.tasks import blocklist_import as blocklist_import_task
|
||||
|
||||
router: APIRouter = APIRouter(prefix="/api/blocklists", tags=["Blocklists"])
|
||||
@@ -132,7 +131,15 @@ async def run_import_now(
|
||||
"""
|
||||
http_session: aiohttp.ClientSession = request.app.state.http_session
|
||||
socket_path: str = request.app.state.settings.fail2ban_socket
|
||||
return await blocklist_service.import_all(db, http_session, socket_path)
|
||||
from app.services import jail_service
|
||||
|
||||
return await blocklist_service.import_all(
|
||||
db,
|
||||
http_session,
|
||||
socket_path,
|
||||
geo_is_cached=geo_service.is_cached,
|
||||
geo_batch_lookup=geo_service.lookup_batch,
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
@@ -225,19 +232,9 @@ async def get_import_log(
|
||||
Returns:
|
||||
:class:`~app.models.blocklist.ImportLogListResponse`.
|
||||
"""
|
||||
items, total = await import_log_repo.list_logs(
|
||||
return await blocklist_service.list_import_logs(
|
||||
db, source_id=source_id, page=page, page_size=page_size
|
||||
)
|
||||
total_pages = import_log_repo.compute_total_pages(total, page_size)
|
||||
from app.models.blocklist import ImportLogEntry # noqa: PLC0415
|
||||
|
||||
return ImportLogListResponse(
|
||||
items=[ImportLogEntry.model_validate(i) for i in items],
|
||||
total=total,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
total_pages=total_pages,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -9,6 +9,9 @@ global settings, test regex patterns, add log paths, and preview log files.
|
||||
* ``GET /api/config/jails/inactive`` — list all inactive jails
|
||||
* ``POST /api/config/jails/{name}/activate`` — activate an inactive jail
|
||||
* ``POST /api/config/jails/{name}/deactivate`` — deactivate an active jail
|
||||
* ``POST /api/config/jails/{name}/validate`` — validate jail config pre-activation (Task 3)
|
||||
* ``POST /api/config/jails/{name}/rollback`` — disable bad jail and restart fail2ban (Task 3)
|
||||
* ``GET /api/config/pending-recovery`` — active crash-recovery record (Task 3)
|
||||
* ``POST /api/config/jails/{name}/filter`` — assign a filter to a jail
|
||||
* ``POST /api/config/jails/{name}/action`` — add an action to a jail
|
||||
* ``DELETE /api/config/jails/{name}/action/{action_name}`` — remove an action from a jail
|
||||
@@ -28,12 +31,16 @@ global settings, test regex patterns, add log paths, and preview log files.
|
||||
* ``PUT /api/config/actions/{name}`` — update an action's .local override
|
||||
* ``POST /api/config/actions`` — create a new user-defined action
|
||||
* ``DELETE /api/config/actions/{name}`` — delete an action's .local file
|
||||
* ``GET /api/config/fail2ban-log`` — read the tail of the fail2ban log file
|
||||
* ``GET /api/config/service-status`` — fail2ban health + log configuration
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime
|
||||
from typing import Annotated
|
||||
|
||||
import structlog
|
||||
from fastapi import APIRouter, HTTPException, Path, Query, Request, status
|
||||
|
||||
from app.dependencies import AuthDep
|
||||
@@ -46,6 +53,7 @@ from app.models.config import (
|
||||
AddLogPathRequest,
|
||||
AssignActionRequest,
|
||||
AssignFilterRequest,
|
||||
Fail2BanLogResponse,
|
||||
FilterConfig,
|
||||
FilterCreateRequest,
|
||||
FilterListResponse,
|
||||
@@ -57,37 +65,50 @@ from app.models.config import (
|
||||
JailConfigListResponse,
|
||||
JailConfigResponse,
|
||||
JailConfigUpdate,
|
||||
JailValidationResult,
|
||||
LogPreviewRequest,
|
||||
LogPreviewResponse,
|
||||
MapColorThresholdsResponse,
|
||||
MapColorThresholdsUpdate,
|
||||
PendingRecovery,
|
||||
RegexTestRequest,
|
||||
RegexTestResponse,
|
||||
RollbackResponse,
|
||||
ServiceStatusResponse,
|
||||
)
|
||||
from app.services import config_file_service, config_service, jail_service
|
||||
from app.services.config_file_service import (
|
||||
from app.services import config_service, jail_service, log_service
|
||||
from app.services import (
|
||||
action_config_service,
|
||||
config_file_service,
|
||||
filter_config_service,
|
||||
jail_config_service,
|
||||
)
|
||||
from app.services.action_config_service import (
|
||||
ActionAlreadyExistsError,
|
||||
ActionNameError,
|
||||
ActionNotFoundError,
|
||||
ActionReadonlyError,
|
||||
ConfigWriteError,
|
||||
)
|
||||
from app.services.filter_config_service import (
|
||||
FilterAlreadyExistsError,
|
||||
FilterInvalidRegexError,
|
||||
FilterNameError,
|
||||
FilterNotFoundError,
|
||||
FilterReadonlyError,
|
||||
)
|
||||
from app.services.jail_config_service import (
|
||||
JailAlreadyActiveError,
|
||||
JailAlreadyInactiveError,
|
||||
JailNameError,
|
||||
JailNotFoundInConfigError,
|
||||
)
|
||||
from app.services.config_service import (
|
||||
ConfigOperationError,
|
||||
ConfigValidationError,
|
||||
JailNotFoundError,
|
||||
)
|
||||
from app.exceptions import ConfigOperationError, ConfigValidationError, JailNotFoundError, JailOperationError
|
||||
from app.tasks.health_check import _run_probe
|
||||
from app.utils.fail2ban_client import Fail2BanConnectionError
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
|
||||
router: APIRouter = APIRouter(prefix="/api/config", tags=["Config"])
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -182,7 +203,7 @@ async def get_inactive_jails(
|
||||
"""
|
||||
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
||||
socket_path: str = request.app.state.settings.fail2ban_socket
|
||||
return await config_file_service.list_inactive_jails(config_dir, socket_path)
|
||||
return await jail_config_service.list_inactive_jails(config_dir, socket_path)
|
||||
|
||||
|
||||
@router.get(
|
||||
@@ -345,15 +366,86 @@ async def reload_fail2ban(
|
||||
_auth: Validated session.
|
||||
|
||||
Raises:
|
||||
HTTPException: 409 when fail2ban reports the reload failed.
|
||||
HTTPException: 502 when fail2ban is unreachable.
|
||||
"""
|
||||
socket_path: str = request.app.state.settings.fail2ban_socket
|
||||
try:
|
||||
await jail_service.reload_all(socket_path)
|
||||
except JailOperationError as exc:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail=f"fail2ban reload failed: {exc}",
|
||||
) from exc
|
||||
except Fail2BanConnectionError as exc:
|
||||
raise _bad_gateway(exc) from exc
|
||||
|
||||
|
||||
# Restart endpoint
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.post(
|
||||
"/restart",
|
||||
status_code=status.HTTP_204_NO_CONTENT,
|
||||
summary="Restart the fail2ban service",
|
||||
)
|
||||
async def restart_fail2ban(
|
||||
request: Request,
|
||||
_auth: AuthDep,
|
||||
) -> None:
|
||||
"""Trigger a full fail2ban service restart.
|
||||
|
||||
Stops the fail2ban daemon via the Unix domain socket, then starts it
|
||||
again using the configured ``fail2ban_start_command``. After starting,
|
||||
probes the socket for up to 10 seconds to confirm the daemon came back
|
||||
online.
|
||||
|
||||
Args:
|
||||
request: Incoming request.
|
||||
_auth: Validated session.
|
||||
|
||||
Raises:
|
||||
HTTPException: 409 when fail2ban reports the stop command failed.
|
||||
HTTPException: 502 when fail2ban is unreachable for the stop command.
|
||||
HTTPException: 503 when fail2ban does not come back online within
|
||||
10 seconds after being started. Check the fail2ban log for
|
||||
initialisation errors. Use
|
||||
``POST /api/config/jails/{name}/rollback`` if a specific jail
|
||||
is suspect.
|
||||
"""
|
||||
socket_path: str = request.app.state.settings.fail2ban_socket
|
||||
start_cmd: str = request.app.state.settings.fail2ban_start_command
|
||||
start_cmd_parts: list[str] = start_cmd.split()
|
||||
|
||||
# Step 1: stop the daemon via socket.
|
||||
try:
|
||||
await jail_service.restart(socket_path)
|
||||
except JailOperationError as exc:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail=f"fail2ban stop command failed: {exc}",
|
||||
) from exc
|
||||
except Fail2BanConnectionError as exc:
|
||||
raise _bad_gateway(exc) from exc
|
||||
|
||||
# Step 2: start the daemon via subprocess.
|
||||
await config_file_service.start_daemon(start_cmd_parts)
|
||||
|
||||
# Step 3: probe the socket until fail2ban is responsive or the budget expires.
|
||||
fail2ban_running: bool = await config_file_service.wait_for_fail2ban(socket_path, max_wait_seconds=10.0)
|
||||
if not fail2ban_running:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail=(
|
||||
"fail2ban was stopped but did not come back online within 10 seconds. "
|
||||
"Check the fail2ban log for initialisation errors. "
|
||||
"Use POST /api/config/jails/{name}/rollback if a specific jail is suspect."
|
||||
),
|
||||
)
|
||||
log.info("fail2ban_restarted")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Regex tester (stateless)
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -380,7 +472,7 @@ async def regex_test(
|
||||
Returns:
|
||||
:class:`~app.models.config.RegexTestResponse` with match result and groups.
|
||||
"""
|
||||
return config_service.test_regex(body)
|
||||
return log_service.test_regex(body)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -486,7 +578,7 @@ async def preview_log(
|
||||
Returns:
|
||||
:class:`~app.models.config.LogPreviewResponse` with per-line results.
|
||||
"""
|
||||
return await config_service.preview_log(body)
|
||||
return await log_service.preview_log(body)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -515,9 +607,7 @@ async def get_map_color_thresholds(
|
||||
"""
|
||||
from app.services import setup_service
|
||||
|
||||
high, medium, low = await setup_service.get_map_color_thresholds(
|
||||
request.app.state.db
|
||||
)
|
||||
high, medium, low = await setup_service.get_map_color_thresholds(request.app.state.db)
|
||||
return MapColorThresholdsResponse(
|
||||
threshold_high=high,
|
||||
threshold_medium=medium,
|
||||
@@ -607,9 +697,7 @@ async def activate_jail(
|
||||
req = body if body is not None else ActivateJailRequest()
|
||||
|
||||
try:
|
||||
return await config_file_service.activate_jail(
|
||||
config_dir, socket_path, name, req
|
||||
)
|
||||
result = await jail_config_service.activate_jail(config_dir, socket_path, name, req)
|
||||
except JailNameError as exc:
|
||||
raise _bad_request(str(exc)) from exc
|
||||
except JailNotFoundInConfigError:
|
||||
@@ -627,6 +715,28 @@ async def activate_jail(
|
||||
except Fail2BanConnectionError as exc:
|
||||
raise _bad_gateway(exc) from exc
|
||||
|
||||
# Record this activation so the health-check task can attribute a
|
||||
# subsequent fail2ban crash to it.
|
||||
request.app.state.last_activation = {
|
||||
"jail_name": name,
|
||||
"at": datetime.datetime.now(tz=datetime.UTC),
|
||||
}
|
||||
|
||||
# If fail2ban stopped responding after the reload, create a pending-recovery
|
||||
# record immediately (before the background health task notices).
|
||||
if not result.fail2ban_running:
|
||||
request.app.state.pending_recovery = PendingRecovery(
|
||||
jail_name=name,
|
||||
activated_at=request.app.state.last_activation["at"],
|
||||
detected_at=datetime.datetime.now(tz=datetime.UTC),
|
||||
)
|
||||
|
||||
# Force an immediate health probe so the cached status reflects the current
|
||||
# fail2ban state without waiting for the next scheduled check.
|
||||
await _run_probe(request.app)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@router.post(
|
||||
"/jails/{name}/deactivate",
|
||||
@@ -661,7 +771,7 @@ async def deactivate_jail(
|
||||
socket_path: str = request.app.state.settings.fail2ban_socket
|
||||
|
||||
try:
|
||||
return await config_file_service.deactivate_jail(config_dir, socket_path, name)
|
||||
result = await jail_config_service.deactivate_jail(config_dir, socket_path, name)
|
||||
except JailNameError as exc:
|
||||
raise _bad_request(str(exc)) from exc
|
||||
except JailNotFoundInConfigError:
|
||||
@@ -679,6 +789,182 @@ async def deactivate_jail(
|
||||
except Fail2BanConnectionError as exc:
|
||||
raise _bad_gateway(exc) from exc
|
||||
|
||||
# Force an immediate health probe so the cached status reflects the current
|
||||
# fail2ban state (reload changes the active-jail count) without waiting for
|
||||
# the next scheduled background check (up to 30 seconds).
|
||||
await _run_probe(request.app)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/jails/{name}/local",
|
||||
status_code=status.HTTP_204_NO_CONTENT,
|
||||
summary="Delete the jail.d override file for an inactive jail",
|
||||
)
|
||||
async def delete_jail_local_override(
|
||||
request: Request,
|
||||
_auth: AuthDep,
|
||||
name: _NamePath,
|
||||
) -> None:
|
||||
"""Remove the ``jail.d/{name}.local`` override file for an inactive jail.
|
||||
|
||||
This endpoint is the clean-up action for inactive jails that still carry
|
||||
a ``.local`` override file (e.g. one written with ``enabled = false`` by a
|
||||
previous deactivation). The file is deleted without modifying fail2ban's
|
||||
running state, since the jail is already inactive.
|
||||
|
||||
Args:
|
||||
request: FastAPI request object.
|
||||
_auth: Validated session.
|
||||
name: Name of the jail whose ``.local`` file should be removed.
|
||||
|
||||
Raises:
|
||||
HTTPException: 400 if *name* contains invalid characters.
|
||||
HTTPException: 404 if *name* is not found in any config file.
|
||||
HTTPException: 409 if the jail is currently active.
|
||||
HTTPException: 500 if the file cannot be deleted.
|
||||
HTTPException: 502 if fail2ban is unreachable.
|
||||
"""
|
||||
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
||||
socket_path: str = request.app.state.settings.fail2ban_socket
|
||||
|
||||
try:
|
||||
await jail_config_service.delete_jail_local_override(config_dir, socket_path, name)
|
||||
except JailNameError as exc:
|
||||
raise _bad_request(str(exc)) from exc
|
||||
except JailNotFoundInConfigError:
|
||||
raise _not_found(name) from None
|
||||
except JailAlreadyActiveError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail=f"Jail {name!r} is currently active; deactivate it first.",
|
||||
) from None
|
||||
except ConfigWriteError as exc:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to delete config override: {exc}",
|
||||
) from exc
|
||||
except Fail2BanConnectionError as exc:
|
||||
raise _bad_gateway(exc) from exc
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Jail validation & rollback endpoints (Task 3)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.post(
|
||||
"/jails/{name}/validate",
|
||||
response_model=JailValidationResult,
|
||||
summary="Validate jail configuration before activation",
|
||||
)
|
||||
async def validate_jail(
|
||||
request: Request,
|
||||
_auth: AuthDep,
|
||||
name: _NamePath,
|
||||
) -> JailValidationResult:
|
||||
"""Run pre-activation validation checks on a jail configuration.
|
||||
|
||||
Validates filter and action file existence, regex pattern compilation, and
|
||||
log path existence without modifying any files or reloading fail2ban.
|
||||
|
||||
Args:
|
||||
request: FastAPI request object.
|
||||
_auth: Validated session.
|
||||
name: Jail name to validate.
|
||||
|
||||
Returns:
|
||||
:class:`~app.models.config.JailValidationResult` with any issues found.
|
||||
|
||||
Raises:
|
||||
HTTPException: 400 if *name* contains invalid characters.
|
||||
HTTPException: 404 if *name* is not found in any config file.
|
||||
"""
|
||||
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
||||
try:
|
||||
return await jail_config_service.validate_jail_config(config_dir, name)
|
||||
except JailNameError as exc:
|
||||
raise _bad_request(str(exc)) from exc
|
||||
|
||||
|
||||
@router.get(
|
||||
"/pending-recovery",
|
||||
response_model=PendingRecovery | None,
|
||||
summary="Return active crash-recovery record if one exists",
|
||||
)
|
||||
async def get_pending_recovery(
|
||||
request: Request,
|
||||
_auth: AuthDep,
|
||||
) -> PendingRecovery | None:
|
||||
"""Return the current :class:`~app.models.config.PendingRecovery` record.
|
||||
|
||||
A non-null response means fail2ban crashed shortly after a jail activation
|
||||
and the user should be offered a rollback option. Returns ``null`` (HTTP
|
||||
200 with ``null`` body) when no recovery is pending.
|
||||
|
||||
Args:
|
||||
request: FastAPI request object.
|
||||
_auth: Validated session.
|
||||
|
||||
Returns:
|
||||
:class:`~app.models.config.PendingRecovery` or ``None``.
|
||||
"""
|
||||
return getattr(request.app.state, "pending_recovery", None)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/jails/{name}/rollback",
|
||||
response_model=RollbackResponse,
|
||||
summary="Disable a bad jail config and restart fail2ban",
|
||||
)
|
||||
async def rollback_jail(
|
||||
request: Request,
|
||||
_auth: AuthDep,
|
||||
name: _NamePath,
|
||||
) -> RollbackResponse:
|
||||
"""Disable the specified jail and attempt to restart fail2ban.
|
||||
|
||||
Writes ``enabled = false`` to ``jail.d/{name}.local`` (works even when
|
||||
fail2ban is down — no socket is needed), then runs the configured start
|
||||
command and waits up to ten seconds for the daemon to come back online.
|
||||
|
||||
On success, clears the :class:`~app.models.config.PendingRecovery` record.
|
||||
|
||||
Args:
|
||||
request: FastAPI request object.
|
||||
_auth: Validated session.
|
||||
name: Jail name to disable and roll back.
|
||||
|
||||
Returns:
|
||||
:class:`~app.models.config.RollbackResponse`.
|
||||
|
||||
Raises:
|
||||
HTTPException: 400 if *name* contains invalid characters.
|
||||
HTTPException: 500 if writing the .local override file fails.
|
||||
"""
|
||||
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
||||
socket_path: str = request.app.state.settings.fail2ban_socket
|
||||
start_cmd: str = request.app.state.settings.fail2ban_start_command
|
||||
start_cmd_parts: list[str] = start_cmd.split()
|
||||
|
||||
try:
|
||||
result = await jail_config_service.rollback_jail(config_dir, socket_path, name, start_cmd_parts)
|
||||
except JailNameError as exc:
|
||||
raise _bad_request(str(exc)) from exc
|
||||
except ConfigWriteError as exc:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to write config override: {exc}",
|
||||
) from exc
|
||||
|
||||
# Clear pending recovery if fail2ban came back online.
|
||||
if result.fail2ban_running:
|
||||
request.app.state.pending_recovery = None
|
||||
request.app.state.last_activation = None
|
||||
|
||||
return result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Filter discovery endpoints (Task 2.1)
|
||||
@@ -715,7 +1001,7 @@ async def list_filters(
|
||||
"""
|
||||
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
||||
socket_path: str = request.app.state.settings.fail2ban_socket
|
||||
result = await config_file_service.list_filters(config_dir, socket_path)
|
||||
result = await filter_config_service.list_filters(config_dir, socket_path)
|
||||
# Sort: active first (by name), then inactive (by name).
|
||||
result.filters.sort(key=lambda f: (not f.active, f.name.lower()))
|
||||
return result
|
||||
@@ -752,7 +1038,7 @@ async def get_filter(
|
||||
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
||||
socket_path: str = request.app.state.settings.fail2ban_socket
|
||||
try:
|
||||
return await config_file_service.get_filter(config_dir, socket_path, name)
|
||||
return await filter_config_service.get_filter(config_dir, socket_path, name)
|
||||
except FilterNotFoundError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
@@ -816,9 +1102,7 @@ async def update_filter(
|
||||
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
||||
socket_path: str = request.app.state.settings.fail2ban_socket
|
||||
try:
|
||||
return await config_file_service.update_filter(
|
||||
config_dir, socket_path, name, body, do_reload=reload
|
||||
)
|
||||
return await filter_config_service.update_filter(config_dir, socket_path, name, body, do_reload=reload)
|
||||
except FilterNameError as exc:
|
||||
raise _bad_request(str(exc)) from exc
|
||||
except FilterNotFoundError:
|
||||
@@ -868,9 +1152,7 @@ async def create_filter(
|
||||
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
||||
socket_path: str = request.app.state.settings.fail2ban_socket
|
||||
try:
|
||||
return await config_file_service.create_filter(
|
||||
config_dir, socket_path, body, do_reload=reload
|
||||
)
|
||||
return await filter_config_service.create_filter(config_dir, socket_path, body, do_reload=reload)
|
||||
except FilterNameError as exc:
|
||||
raise _bad_request(str(exc)) from exc
|
||||
except FilterAlreadyExistsError as exc:
|
||||
@@ -917,7 +1199,7 @@ async def delete_filter(
|
||||
"""
|
||||
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
||||
try:
|
||||
await config_file_service.delete_filter(config_dir, name)
|
||||
await filter_config_service.delete_filter(config_dir, name)
|
||||
except FilterNameError as exc:
|
||||
raise _bad_request(str(exc)) from exc
|
||||
except FilterNotFoundError:
|
||||
@@ -966,9 +1248,7 @@ async def assign_filter_to_jail(
|
||||
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
||||
socket_path: str = request.app.state.settings.fail2ban_socket
|
||||
try:
|
||||
await config_file_service.assign_filter_to_jail(
|
||||
config_dir, socket_path, name, body, do_reload=reload
|
||||
)
|
||||
await filter_config_service.assign_filter_to_jail(config_dir, socket_path, name, body, do_reload=reload)
|
||||
except (JailNameError, FilterNameError) as exc:
|
||||
raise _bad_request(str(exc)) from exc
|
||||
except JailNotFoundInConfigError:
|
||||
@@ -1032,7 +1312,7 @@ async def list_actions(
|
||||
"""
|
||||
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
||||
socket_path: str = request.app.state.settings.fail2ban_socket
|
||||
result = await config_file_service.list_actions(config_dir, socket_path)
|
||||
result = await action_config_service.list_actions(config_dir, socket_path)
|
||||
result.actions.sort(key=lambda a: (not a.active, a.name.lower()))
|
||||
return result
|
||||
|
||||
@@ -1067,7 +1347,7 @@ async def get_action(
|
||||
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
||||
socket_path: str = request.app.state.settings.fail2ban_socket
|
||||
try:
|
||||
return await config_file_service.get_action(config_dir, socket_path, name)
|
||||
return await action_config_service.get_action(config_dir, socket_path, name)
|
||||
except ActionNotFoundError:
|
||||
raise _action_not_found(name) from None
|
||||
|
||||
@@ -1112,9 +1392,7 @@ async def update_action(
|
||||
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
||||
socket_path: str = request.app.state.settings.fail2ban_socket
|
||||
try:
|
||||
return await config_file_service.update_action(
|
||||
config_dir, socket_path, name, body, do_reload=reload
|
||||
)
|
||||
return await action_config_service.update_action(config_dir, socket_path, name, body, do_reload=reload)
|
||||
except ActionNameError as exc:
|
||||
raise _bad_request(str(exc)) from exc
|
||||
except ActionNotFoundError:
|
||||
@@ -1160,9 +1438,7 @@ async def create_action(
|
||||
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
||||
socket_path: str = request.app.state.settings.fail2ban_socket
|
||||
try:
|
||||
return await config_file_service.create_action(
|
||||
config_dir, socket_path, body, do_reload=reload
|
||||
)
|
||||
return await action_config_service.create_action(config_dir, socket_path, body, do_reload=reload)
|
||||
except ActionNameError as exc:
|
||||
raise _bad_request(str(exc)) from exc
|
||||
except ActionAlreadyExistsError as exc:
|
||||
@@ -1205,7 +1481,7 @@ async def delete_action(
|
||||
"""
|
||||
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
||||
try:
|
||||
await config_file_service.delete_action(config_dir, name)
|
||||
await action_config_service.delete_action(config_dir, name)
|
||||
except ActionNameError as exc:
|
||||
raise _bad_request(str(exc)) from exc
|
||||
except ActionNotFoundError:
|
||||
@@ -1255,9 +1531,7 @@ async def assign_action_to_jail(
|
||||
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
||||
socket_path: str = request.app.state.settings.fail2ban_socket
|
||||
try:
|
||||
await config_file_service.assign_action_to_jail(
|
||||
config_dir, socket_path, name, body, do_reload=reload
|
||||
)
|
||||
await action_config_service.assign_action_to_jail(config_dir, socket_path, name, body, do_reload=reload)
|
||||
except (JailNameError, ActionNameError) as exc:
|
||||
raise _bad_request(str(exc)) from exc
|
||||
except JailNotFoundInConfigError:
|
||||
@@ -1306,9 +1580,7 @@ async def remove_action_from_jail(
|
||||
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
||||
socket_path: str = request.app.state.settings.fail2ban_socket
|
||||
try:
|
||||
await config_file_service.remove_action_from_jail(
|
||||
config_dir, socket_path, name, action_name, do_reload=reload
|
||||
)
|
||||
await action_config_service.remove_action_from_jail(config_dir, socket_path, name, action_name, do_reload=reload)
|
||||
except (JailNameError, ActionNameError) as exc:
|
||||
raise _bad_request(str(exc)) from exc
|
||||
except JailNotFoundInConfigError:
|
||||
@@ -1319,3 +1591,87 @@ async def remove_action_from_jail(
|
||||
detail=f"Failed to write jail override: {exc}",
|
||||
) from exc
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# fail2ban log viewer endpoints
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.get(
|
||||
"/fail2ban-log",
|
||||
response_model=Fail2BanLogResponse,
|
||||
summary="Read the tail of the fail2ban daemon log file",
|
||||
)
|
||||
async def get_fail2ban_log(
|
||||
request: Request,
|
||||
_auth: AuthDep,
|
||||
lines: Annotated[int, Query(ge=1, le=2000, description="Number of lines to return from the tail.")] = 200,
|
||||
filter: Annotated[ # noqa: A002
|
||||
str | None,
|
||||
Query(description="Plain-text substring filter; only matching lines are returned."),
|
||||
] = None,
|
||||
) -> Fail2BanLogResponse:
|
||||
"""Return the tail of the fail2ban daemon log file.
|
||||
|
||||
Queries the fail2ban socket for the current log target and log level,
|
||||
reads the last *lines* entries from the file, and optionally filters
|
||||
them by *filter*. Only file-based log targets are supported.
|
||||
|
||||
Args:
|
||||
request: Incoming request.
|
||||
_auth: Validated session — enforces authentication.
|
||||
lines: Number of tail lines to return (1–2000, default 200).
|
||||
filter: Optional plain-text substring — only matching lines returned.
|
||||
|
||||
Returns:
|
||||
:class:`~app.models.config.Fail2BanLogResponse`.
|
||||
|
||||
Raises:
|
||||
HTTPException: 400 when the log target is not a file or path is outside
|
||||
the allowed directory.
|
||||
HTTPException: 502 when fail2ban is unreachable.
|
||||
"""
|
||||
socket_path: str = request.app.state.settings.fail2ban_socket
|
||||
try:
|
||||
return await config_service.read_fail2ban_log(socket_path, lines, filter)
|
||||
except config_service.ConfigOperationError as exc:
|
||||
raise _bad_request(str(exc)) from exc
|
||||
except Fail2BanConnectionError as exc:
|
||||
raise _bad_gateway(exc) from exc
|
||||
|
||||
|
||||
@router.get(
|
||||
"/service-status",
|
||||
response_model=ServiceStatusResponse,
|
||||
summary="Return fail2ban service health status with log configuration",
|
||||
)
|
||||
async def get_service_status(
|
||||
request: Request,
|
||||
_auth: AuthDep,
|
||||
) -> ServiceStatusResponse:
|
||||
"""Return fail2ban service health and current log configuration.
|
||||
|
||||
Probes the fail2ban daemon to determine online/offline state, then
|
||||
augments the result with the current log level and log target values.
|
||||
|
||||
Args:
|
||||
request: Incoming request.
|
||||
_auth: Validated session — enforces authentication.
|
||||
|
||||
Returns:
|
||||
:class:`~app.models.config.ServiceStatusResponse`.
|
||||
|
||||
Raises:
|
||||
HTTPException: 502 when fail2ban is unreachable (the service itself
|
||||
handles this gracefully and returns ``online=False``).
|
||||
"""
|
||||
socket_path: str = request.app.state.settings.fail2ban_socket
|
||||
from app.services import health_service
|
||||
|
||||
try:
|
||||
return await config_service.get_service_status(
|
||||
socket_path,
|
||||
probe_fn=health_service.probe,
|
||||
)
|
||||
except Fail2BanConnectionError as exc:
|
||||
raise _bad_gateway(exc) from exc
|
||||
|
||||
@@ -12,13 +12,14 @@ Also provides ``GET /api/dashboard/bans`` for the dashboard ban-list table,
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Literal
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import aiohttp
|
||||
|
||||
from fastapi import APIRouter, Query, Request
|
||||
|
||||
from app import __version__
|
||||
from app.dependencies import AuthDep
|
||||
from app.models.ban import (
|
||||
BanOrigin,
|
||||
@@ -29,7 +30,7 @@ from app.models.ban import (
|
||||
TimeRange,
|
||||
)
|
||||
from app.models.server import ServerStatus, ServerStatusResponse
|
||||
from app.services import ban_service
|
||||
from app.services import ban_service, geo_service
|
||||
|
||||
router: APIRouter = APIRouter(prefix="/api/dashboard", tags=["Dashboard"])
|
||||
|
||||
@@ -69,6 +70,7 @@ async def get_server_status(
|
||||
"server_status",
|
||||
ServerStatus(online=False),
|
||||
)
|
||||
cached.version = __version__
|
||||
return ServerStatusResponse(status=cached)
|
||||
|
||||
|
||||
@@ -81,6 +83,7 @@ async def get_dashboard_bans(
|
||||
request: Request,
|
||||
_auth: AuthDep,
|
||||
range: TimeRange = Query(default=_DEFAULT_RANGE, description="Time-range preset."),
|
||||
source: Literal["fail2ban", "archive"] = Query(default="fail2ban", description="Data source: 'fail2ban' or 'archive'."),
|
||||
page: int = Query(default=1, ge=1, description="1-based page number."),
|
||||
page_size: int = Query(default=_DEFAULT_PAGE_SIZE, ge=1, le=500, description="Items per page."),
|
||||
origin: BanOrigin | None = Query(
|
||||
@@ -115,10 +118,12 @@ async def get_dashboard_bans(
|
||||
return await ban_service.list_bans(
|
||||
socket_path,
|
||||
range,
|
||||
source=source,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
http_session=http_session,
|
||||
app_db=None,
|
||||
app_db=request.app.state.db,
|
||||
geo_batch_lookup=geo_service.lookup_batch,
|
||||
origin=origin,
|
||||
)
|
||||
|
||||
@@ -132,6 +137,7 @@ async def get_bans_by_country(
|
||||
request: Request,
|
||||
_auth: AuthDep,
|
||||
range: TimeRange = Query(default=_DEFAULT_RANGE, description="Time-range preset."),
|
||||
source: Literal["fail2ban", "archive"] = Query(default="fail2ban", description="Data source: 'fail2ban' or 'archive'."),
|
||||
origin: BanOrigin | None = Query(
|
||||
default=None,
|
||||
description="Filter by ban origin: 'blocklist' or 'selfblock'. Omit for all.",
|
||||
@@ -161,8 +167,11 @@ async def get_bans_by_country(
|
||||
return await ban_service.bans_by_country(
|
||||
socket_path,
|
||||
range,
|
||||
source=source,
|
||||
http_session=http_session,
|
||||
app_db=None,
|
||||
geo_cache_lookup=geo_service.lookup_cached_only,
|
||||
geo_batch_lookup=geo_service.lookup_batch,
|
||||
app_db=request.app.state.db,
|
||||
origin=origin,
|
||||
)
|
||||
|
||||
@@ -176,6 +185,7 @@ async def get_ban_trend(
|
||||
request: Request,
|
||||
_auth: AuthDep,
|
||||
range: TimeRange = Query(default=_DEFAULT_RANGE, description="Time-range preset."),
|
||||
source: Literal["fail2ban", "archive"] = Query(default="fail2ban", description="Data source: 'fail2ban' or 'archive'."),
|
||||
origin: BanOrigin | None = Query(
|
||||
default=None,
|
||||
description="Filter by ban origin: 'blocklist' or 'selfblock'. Omit for all.",
|
||||
@@ -207,7 +217,13 @@ async def get_ban_trend(
|
||||
"""
|
||||
socket_path: str = request.app.state.settings.fail2ban_socket
|
||||
|
||||
return await ban_service.ban_trend(socket_path, range, origin=origin)
|
||||
return await ban_service.ban_trend(
|
||||
socket_path,
|
||||
range,
|
||||
source=source,
|
||||
app_db=request.app.state.db,
|
||||
origin=origin,
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
@@ -219,6 +235,7 @@ async def get_bans_by_jail(
|
||||
request: Request,
|
||||
_auth: AuthDep,
|
||||
range: TimeRange = Query(default=_DEFAULT_RANGE, description="Time-range preset."),
|
||||
source: Literal["fail2ban", "archive"] = Query(default="fail2ban", description="Data source: 'fail2ban' or 'archive'."),
|
||||
origin: BanOrigin | None = Query(
|
||||
default=None,
|
||||
description="Filter by ban origin: 'blocklist' or 'selfblock'. Omit for all.",
|
||||
@@ -243,4 +260,10 @@ async def get_bans_by_jail(
|
||||
"""
|
||||
socket_path: str = request.app.state.settings.fail2ban_socket
|
||||
|
||||
return await ban_service.bans_by_jail(socket_path, range, origin=origin)
|
||||
return await ban_service.bans_by_jail(
|
||||
socket_path,
|
||||
range,
|
||||
source=source,
|
||||
app_db=request.app.state.db,
|
||||
origin=origin,
|
||||
)
|
||||
|
||||
@@ -14,8 +14,8 @@ Endpoints:
|
||||
* ``GET /api/config/filters/{name}/parsed`` — parse a filter file into a structured model
|
||||
* ``PUT /api/config/filters/{name}/parsed`` — update a filter file from a structured model
|
||||
* ``GET /api/config/actions`` — list all action files
|
||||
* ``GET /api/config/actions/{name}`` — get one action file (with content)
|
||||
* ``PUT /api/config/actions/{name}`` — update an action file
|
||||
* ``GET /api/config/actions/{name}/raw`` — get one action file (raw content)
|
||||
* ``PUT /api/config/actions/{name}/raw`` — update an action file (raw content)
|
||||
* ``POST /api/config/actions`` — create a new action file
|
||||
* ``GET /api/config/actions/{name}/parsed`` — parse an action file into a structured model
|
||||
* ``PUT /api/config/actions/{name}/parsed`` — update an action file from a structured model
|
||||
@@ -51,8 +51,8 @@ from app.models.file_config import (
|
||||
JailConfigFileEnabledUpdate,
|
||||
JailConfigFilesResponse,
|
||||
)
|
||||
from app.services import file_config_service
|
||||
from app.services.file_config_service import (
|
||||
from app.services import raw_config_io_service
|
||||
from app.services.raw_config_io_service import (
|
||||
ConfigDirError,
|
||||
ConfigFileExistsError,
|
||||
ConfigFileNameError,
|
||||
@@ -134,7 +134,7 @@ async def list_jail_config_files(
|
||||
"""
|
||||
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
||||
try:
|
||||
return await file_config_service.list_jail_config_files(config_dir)
|
||||
return await raw_config_io_service.list_jail_config_files(config_dir)
|
||||
except ConfigDirError as exc:
|
||||
raise _service_unavailable(str(exc)) from exc
|
||||
|
||||
@@ -166,7 +166,7 @@ async def get_jail_config_file(
|
||||
"""
|
||||
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
||||
try:
|
||||
return await file_config_service.get_jail_config_file(config_dir, filename)
|
||||
return await raw_config_io_service.get_jail_config_file(config_dir, filename)
|
||||
except ConfigFileNameError as exc:
|
||||
raise _bad_request(str(exc)) from exc
|
||||
except ConfigFileNotFoundError:
|
||||
@@ -204,7 +204,7 @@ async def write_jail_config_file(
|
||||
"""
|
||||
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
||||
try:
|
||||
await file_config_service.write_jail_config_file(config_dir, filename, body)
|
||||
await raw_config_io_service.write_jail_config_file(config_dir, filename, body)
|
||||
except ConfigFileNameError as exc:
|
||||
raise _bad_request(str(exc)) from exc
|
||||
except ConfigFileNotFoundError:
|
||||
@@ -244,7 +244,7 @@ async def set_jail_config_file_enabled(
|
||||
"""
|
||||
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
||||
try:
|
||||
await file_config_service.set_jail_config_enabled(
|
||||
await raw_config_io_service.set_jail_config_enabled(
|
||||
config_dir, filename, body.enabled
|
||||
)
|
||||
except ConfigFileNameError as exc:
|
||||
@@ -285,7 +285,7 @@ async def create_jail_config_file(
|
||||
"""
|
||||
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
||||
try:
|
||||
filename = await file_config_service.create_jail_config_file(config_dir, body)
|
||||
filename = await raw_config_io_service.create_jail_config_file(config_dir, body)
|
||||
except ConfigFileNameError as exc:
|
||||
raise _bad_request(str(exc)) from exc
|
||||
except ConfigFileExistsError:
|
||||
@@ -338,7 +338,7 @@ async def get_filter_file_raw(
|
||||
"""
|
||||
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
||||
try:
|
||||
return await file_config_service.get_filter_file(config_dir, name)
|
||||
return await raw_config_io_service.get_filter_file(config_dir, name)
|
||||
except ConfigFileNameError as exc:
|
||||
raise _bad_request(str(exc)) from exc
|
||||
except ConfigFileNotFoundError:
|
||||
@@ -373,7 +373,7 @@ async def write_filter_file(
|
||||
"""
|
||||
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
||||
try:
|
||||
await file_config_service.write_filter_file(config_dir, name, body)
|
||||
await raw_config_io_service.write_filter_file(config_dir, name, body)
|
||||
except ConfigFileNameError as exc:
|
||||
raise _bad_request(str(exc)) from exc
|
||||
except ConfigFileNotFoundError:
|
||||
@@ -412,7 +412,7 @@ async def create_filter_file(
|
||||
"""
|
||||
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
||||
try:
|
||||
filename = await file_config_service.create_filter_file(config_dir, body)
|
||||
filename = await raw_config_io_service.create_filter_file(config_dir, body)
|
||||
except ConfigFileNameError as exc:
|
||||
raise _bad_request(str(exc)) from exc
|
||||
except ConfigFileExistsError:
|
||||
@@ -454,13 +454,13 @@ async def list_action_files(
|
||||
"""
|
||||
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
||||
try:
|
||||
return await file_config_service.list_action_files(config_dir)
|
||||
return await raw_config_io_service.list_action_files(config_dir)
|
||||
except ConfigDirError as exc:
|
||||
raise _service_unavailable(str(exc)) from exc
|
||||
|
||||
|
||||
@router.get(
|
||||
"/actions/{name}",
|
||||
"/actions/{name}/raw",
|
||||
response_model=ConfFileContent,
|
||||
summary="Return an action definition file with its content",
|
||||
)
|
||||
@@ -486,7 +486,7 @@ async def get_action_file(
|
||||
"""
|
||||
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
||||
try:
|
||||
return await file_config_service.get_action_file(config_dir, name)
|
||||
return await raw_config_io_service.get_action_file(config_dir, name)
|
||||
except ConfigFileNameError as exc:
|
||||
raise _bad_request(str(exc)) from exc
|
||||
except ConfigFileNotFoundError:
|
||||
@@ -496,7 +496,7 @@ async def get_action_file(
|
||||
|
||||
|
||||
@router.put(
|
||||
"/actions/{name}",
|
||||
"/actions/{name}/raw",
|
||||
status_code=status.HTTP_204_NO_CONTENT,
|
||||
summary="Update an action definition file",
|
||||
)
|
||||
@@ -521,7 +521,7 @@ async def write_action_file(
|
||||
"""
|
||||
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
||||
try:
|
||||
await file_config_service.write_action_file(config_dir, name, body)
|
||||
await raw_config_io_service.write_action_file(config_dir, name, body)
|
||||
except ConfigFileNameError as exc:
|
||||
raise _bad_request(str(exc)) from exc
|
||||
except ConfigFileNotFoundError:
|
||||
@@ -560,7 +560,7 @@ async def create_action_file(
|
||||
"""
|
||||
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
||||
try:
|
||||
filename = await file_config_service.create_action_file(config_dir, body)
|
||||
filename = await raw_config_io_service.create_action_file(config_dir, body)
|
||||
except ConfigFileNameError as exc:
|
||||
raise _bad_request(str(exc)) from exc
|
||||
except ConfigFileExistsError:
|
||||
@@ -613,7 +613,7 @@ async def get_parsed_filter(
|
||||
"""
|
||||
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
||||
try:
|
||||
return await file_config_service.get_parsed_filter_file(config_dir, name)
|
||||
return await raw_config_io_service.get_parsed_filter_file(config_dir, name)
|
||||
except ConfigFileNameError as exc:
|
||||
raise _bad_request(str(exc)) from exc
|
||||
except ConfigFileNotFoundError:
|
||||
@@ -651,7 +651,7 @@ async def update_parsed_filter(
|
||||
"""
|
||||
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
||||
try:
|
||||
await file_config_service.update_parsed_filter_file(config_dir, name, body)
|
||||
await raw_config_io_service.update_parsed_filter_file(config_dir, name, body)
|
||||
except ConfigFileNameError as exc:
|
||||
raise _bad_request(str(exc)) from exc
|
||||
except ConfigFileNotFoundError:
|
||||
@@ -698,7 +698,7 @@ async def get_parsed_action(
|
||||
"""
|
||||
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
||||
try:
|
||||
return await file_config_service.get_parsed_action_file(config_dir, name)
|
||||
return await raw_config_io_service.get_parsed_action_file(config_dir, name)
|
||||
except ConfigFileNameError as exc:
|
||||
raise _bad_request(str(exc)) from exc
|
||||
except ConfigFileNotFoundError:
|
||||
@@ -736,7 +736,7 @@ async def update_parsed_action(
|
||||
"""
|
||||
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
||||
try:
|
||||
await file_config_service.update_parsed_action_file(config_dir, name, body)
|
||||
await raw_config_io_service.update_parsed_action_file(config_dir, name, body)
|
||||
except ConfigFileNameError as exc:
|
||||
raise _bad_request(str(exc)) from exc
|
||||
except ConfigFileNotFoundError:
|
||||
@@ -783,7 +783,7 @@ async def get_parsed_jail_file(
|
||||
"""
|
||||
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
||||
try:
|
||||
return await file_config_service.get_parsed_jail_file(config_dir, filename)
|
||||
return await raw_config_io_service.get_parsed_jail_file(config_dir, filename)
|
||||
except ConfigFileNameError as exc:
|
||||
raise _bad_request(str(exc)) from exc
|
||||
except ConfigFileNotFoundError:
|
||||
@@ -821,7 +821,7 @@ async def update_parsed_jail_file(
|
||||
"""
|
||||
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
||||
try:
|
||||
await file_config_service.update_parsed_jail_file(config_dir, filename, body)
|
||||
await raw_config_io_service.update_parsed_jail_file(config_dir, filename, body)
|
||||
except ConfigFileNameError as exc:
|
||||
raise _bad_request(str(exc)) from exc
|
||||
except ConfigFileNotFoundError:
|
||||
|
||||
@@ -13,11 +13,13 @@ from typing import TYPE_CHECKING, Annotated
|
||||
if TYPE_CHECKING:
|
||||
import aiohttp
|
||||
|
||||
from app.services.jail_service import IpLookupResult
|
||||
|
||||
import aiosqlite
|
||||
from fastapi import APIRouter, Depends, HTTPException, Path, Request, status
|
||||
|
||||
from app.dependencies import AuthDep, get_db
|
||||
from app.models.geo import GeoCacheStatsResponse, GeoDetail, IpLookupResponse
|
||||
from app.models.geo import GeoCacheStatsResponse, GeoDetail, GeoInfo, IpLookupResponse
|
||||
from app.services import geo_service, jail_service
|
||||
from app.utils.fail2ban_client import Fail2BanConnectionError
|
||||
|
||||
@@ -61,7 +63,7 @@ async def lookup_ip(
|
||||
return await geo_service.lookup(addr, http_session)
|
||||
|
||||
try:
|
||||
result = await jail_service.lookup_ip(
|
||||
result: IpLookupResult = await jail_service.lookup_ip(
|
||||
socket_path,
|
||||
ip,
|
||||
geo_enricher=_enricher,
|
||||
@@ -77,9 +79,9 @@ async def lookup_ip(
|
||||
detail=f"Cannot reach fail2ban: {exc}",
|
||||
) from exc
|
||||
|
||||
raw_geo = result.get("geo")
|
||||
raw_geo = result["geo"]
|
||||
geo_detail: GeoDetail | None = None
|
||||
if raw_geo is not None:
|
||||
if isinstance(raw_geo, GeoInfo):
|
||||
geo_detail = GeoDetail(
|
||||
country_code=raw_geo.country_code,
|
||||
country_name=raw_geo.country_name,
|
||||
@@ -153,12 +155,7 @@ async def re_resolve_geo(
|
||||
that were retried.
|
||||
"""
|
||||
# Collect all IPs in geo_cache that still lack a country code.
|
||||
unresolved: list[str] = []
|
||||
async with db.execute(
|
||||
"SELECT ip FROM geo_cache WHERE country_code IS NULL"
|
||||
) as cur:
|
||||
async for row in cur:
|
||||
unresolved.append(str(row[0]))
|
||||
unresolved = await geo_service.get_unresolved_ips(db)
|
||||
|
||||
if not unresolved:
|
||||
return {"resolved": 0, "total": 0}
|
||||
|
||||
@@ -1,21 +1,37 @@
|
||||
"""Health check router.
|
||||
|
||||
A lightweight ``GET /api/health`` endpoint that verifies the application
|
||||
is running and can serve requests. It does not probe fail2ban — that
|
||||
responsibility belongs to the health service (Stage 4).
|
||||
is running and can serve requests. Also reports the cached fail2ban liveness
|
||||
state so monitoring tools and Docker health checks can observe daemon status
|
||||
without probing the socket directly.
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import APIRouter, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from app.models.server import ServerStatus
|
||||
|
||||
router: APIRouter = APIRouter(prefix="/api", tags=["Health"])
|
||||
|
||||
|
||||
@router.get("/health", summary="Application health check")
|
||||
async def health_check() -> JSONResponse:
|
||||
"""Return a 200 response confirming the API is operational.
|
||||
async def health_check(request: Request) -> JSONResponse:
|
||||
"""Return 200 with application and fail2ban status.
|
||||
|
||||
HTTP 200 is always returned so Docker health checks do not restart the
|
||||
backend container when fail2ban is temporarily offline. The
|
||||
``fail2ban`` field in the body indicates the daemon's current state.
|
||||
|
||||
Args:
|
||||
request: FastAPI request (used to read cached server status).
|
||||
|
||||
Returns:
|
||||
A JSON object with ``{"status": "ok"}``.
|
||||
A JSON object with ``{"status": "ok", "fail2ban": "online"|"offline"}``.
|
||||
"""
|
||||
return JSONResponse(content={"status": "ok"})
|
||||
cached: ServerStatus = getattr(
|
||||
request.app.state, "server_status", ServerStatus(online=False)
|
||||
)
|
||||
return JSONResponse(content={
|
||||
"status": "ok",
|
||||
"fail2ban": "online" if cached.online else "offline",
|
||||
})
|
||||
|
||||
@@ -15,7 +15,7 @@ Routes
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Literal
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import aiohttp
|
||||
@@ -23,7 +23,7 @@ if TYPE_CHECKING:
|
||||
from fastapi import APIRouter, HTTPException, Query, Request
|
||||
|
||||
from app.dependencies import AuthDep
|
||||
from app.models.ban import TimeRange
|
||||
from app.models.ban import BanOrigin, TimeRange
|
||||
from app.models.history import HistoryListResponse, IpDetailResponse
|
||||
from app.services import geo_service, history_service
|
||||
|
||||
@@ -52,6 +52,14 @@ async def get_history(
|
||||
default=None,
|
||||
description="Restrict results to IPs matching this prefix.",
|
||||
),
|
||||
origin: BanOrigin | None = Query(
|
||||
default=None,
|
||||
description="Filter by ban origin: 'blocklist' or 'selfblock'. Omit for all.",
|
||||
),
|
||||
source: Literal["fail2ban", "archive"] = Query(
|
||||
default="fail2ban",
|
||||
description="Data source: 'fail2ban' or 'archive'.",
|
||||
),
|
||||
page: int = Query(default=1, ge=1, description="1-based page number."),
|
||||
page_size: int = Query(
|
||||
default=_DEFAULT_PAGE_SIZE,
|
||||
@@ -89,9 +97,48 @@ async def get_history(
|
||||
range_=range,
|
||||
jail=jail,
|
||||
ip_filter=ip,
|
||||
origin=origin,
|
||||
source=source,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
geo_enricher=_enricher,
|
||||
db=request.app.state.db,
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/archive",
|
||||
response_model=HistoryListResponse,
|
||||
summary="Return a paginated list of archived historical bans",
|
||||
)
|
||||
async def get_history_archive(
|
||||
request: Request,
|
||||
_auth: AuthDep,
|
||||
range: TimeRange | None = Query(
|
||||
default=None,
|
||||
description="Optional time-range filter. Omit for all-time.",
|
||||
),
|
||||
jail: str | None = Query(default=None, description="Restrict results to this jail name."),
|
||||
ip: str | None = Query(default=None, description="Restrict results to IPs matching this prefix."),
|
||||
page: int = Query(default=1, ge=1, description="1-based page number."),
|
||||
page_size: int = Query(default=_DEFAULT_PAGE_SIZE, ge=1, le=500, description="Items per page (max 500)."),
|
||||
) -> HistoryListResponse:
|
||||
socket_path: str = request.app.state.settings.fail2ban_socket
|
||||
http_session: aiohttp.ClientSession = request.app.state.http_session
|
||||
|
||||
async def _enricher(addr: str) -> geo_service.GeoInfo | None:
|
||||
return await geo_service.lookup(addr, http_session)
|
||||
|
||||
return await history_service.list_history(
|
||||
socket_path,
|
||||
range_=range,
|
||||
jail=jail,
|
||||
ip_filter=ip,
|
||||
source="archive",
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
geo_enricher=_enricher,
|
||||
db=request.app.state.db,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ Provides CRUD and control operations for fail2ban jails:
|
||||
|
||||
* ``GET /api/jails`` — list all jails
|
||||
* ``GET /api/jails/{name}`` — full detail for one jail
|
||||
* ``GET /api/jails/{name}/banned`` — paginated currently-banned IPs for one jail
|
||||
* ``POST /api/jails/{name}/start`` — start a jail
|
||||
* ``POST /api/jails/{name}/stop`` — stop a jail
|
||||
* ``POST /api/jails/{name}/idle`` — toggle idle mode
|
||||
@@ -23,14 +24,15 @@ from typing import Annotated
|
||||
from fastapi import APIRouter, Body, HTTPException, Path, Request, status
|
||||
|
||||
from app.dependencies import AuthDep
|
||||
from app.models.ban import JailBannedIpsResponse
|
||||
from app.models.jail import (
|
||||
IgnoreIpRequest,
|
||||
JailCommandResponse,
|
||||
JailDetailResponse,
|
||||
JailListResponse,
|
||||
)
|
||||
from app.services import jail_service
|
||||
from app.services.jail_service import JailNotFoundError, JailOperationError
|
||||
from app.services import geo_service, jail_service
|
||||
from app.exceptions import JailNotFoundError, JailOperationError
|
||||
from app.utils.fail2ban_client import Fail2BanConnectionError
|
||||
|
||||
router: APIRouter = APIRouter(prefix="/api/jails", tags=["Jails"])
|
||||
@@ -540,3 +542,75 @@ async def toggle_ignore_self(
|
||||
raise _conflict(str(exc)) from exc
|
||||
except Fail2BanConnectionError as exc:
|
||||
raise _bad_gateway(exc) from exc
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Currently banned IPs (paginated)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.get(
|
||||
"/{name}/banned",
|
||||
response_model=JailBannedIpsResponse,
|
||||
summary="Return paginated currently-banned IPs for a single jail",
|
||||
)
|
||||
async def get_jail_banned_ips(
|
||||
request: Request,
|
||||
_auth: AuthDep,
|
||||
name: _NamePath,
|
||||
page: int = 1,
|
||||
page_size: int = 25,
|
||||
search: str | None = None,
|
||||
) -> JailBannedIpsResponse:
|
||||
"""Return a paginated list of IPs currently banned by a specific jail.
|
||||
|
||||
The full ban list is fetched from the fail2ban socket, filtered by the
|
||||
optional *search* substring, sliced to the requested page, and then
|
||||
geo-enriched exclusively for that page slice.
|
||||
|
||||
Args:
|
||||
request: Incoming request (used to access ``app.state``).
|
||||
_auth: Validated session — enforces authentication.
|
||||
name: Jail name.
|
||||
page: 1-based page number (default 1, min 1).
|
||||
page_size: Items per page (default 25, max 100).
|
||||
search: Optional case-insensitive substring filter on the IP address.
|
||||
|
||||
Returns:
|
||||
:class:`~app.models.ban.JailBannedIpsResponse` with the paginated bans.
|
||||
|
||||
Raises:
|
||||
HTTPException: 400 when *page* or *page_size* are out of range.
|
||||
HTTPException: 404 when the jail does not exist.
|
||||
HTTPException: 502 when fail2ban is unreachable.
|
||||
"""
|
||||
if page < 1:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="page must be >= 1.",
|
||||
)
|
||||
if not (1 <= page_size <= 100):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="page_size must be between 1 and 100.",
|
||||
)
|
||||
|
||||
socket_path: str = request.app.state.settings.fail2ban_socket
|
||||
http_session = getattr(request.app.state, "http_session", None)
|
||||
app_db = getattr(request.app.state, "db", None)
|
||||
|
||||
try:
|
||||
return await jail_service.get_jail_banned_ips(
|
||||
socket_path=socket_path,
|
||||
jail_name=name,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
search=search,
|
||||
geo_batch_lookup=geo_service.lookup_batch,
|
||||
http_session=http_session,
|
||||
app_db=app_db,
|
||||
)
|
||||
except JailNotFoundError:
|
||||
raise _not_found(name) from None
|
||||
except Fail2BanConnectionError as exc:
|
||||
raise _bad_gateway(exc) from exc
|
||||
|
||||
@@ -15,7 +15,7 @@ from fastapi import APIRouter, HTTPException, Request, status
|
||||
from app.dependencies import AuthDep
|
||||
from app.models.server import ServerSettingsResponse, ServerSettingsUpdate
|
||||
from app.services import server_service
|
||||
from app.services.server_service import ServerOperationError
|
||||
from app.exceptions import ServerOperationError
|
||||
from app.utils.fail2ban_client import Fail2BanConnectionError
|
||||
|
||||
router: APIRouter = APIRouter(prefix="/api/server", tags=["Server"])
|
||||
|
||||
1070
backend/app/services/action_config_service.py
Normal file
1070
backend/app/services/action_config_service.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -20,7 +20,7 @@ if TYPE_CHECKING:
|
||||
from app.models.auth import Session
|
||||
|
||||
from app.repositories import session_repo
|
||||
from app.services import setup_service
|
||||
from app.utils.setup_utils import get_password_hash
|
||||
from app.utils.time_utils import add_minutes, utc_now
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
@@ -65,7 +65,7 @@ async def login(
|
||||
Raises:
|
||||
ValueError: If the password is incorrect or no password hash is stored.
|
||||
"""
|
||||
stored_hash = await setup_service.get_password_hash(db)
|
||||
stored_hash = await get_password_hash(db)
|
||||
if stored_hash is None:
|
||||
log.warning("bangui_login_no_hash")
|
||||
raise ValueError("No password is configured — run setup first.")
|
||||
|
||||
@@ -11,12 +11,9 @@ so BanGUI never modifies or locks the fail2ban database.
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
from datetime import UTC, datetime
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import aiosqlite
|
||||
import structlog
|
||||
|
||||
from app.models.ban import (
|
||||
@@ -31,15 +28,21 @@ from app.models.ban import (
|
||||
BanTrendResponse,
|
||||
DashboardBanItem,
|
||||
DashboardBanListResponse,
|
||||
JailBanCount,
|
||||
TimeRange,
|
||||
_derive_origin,
|
||||
bucket_count,
|
||||
)
|
||||
from app.utils.fail2ban_client import Fail2BanClient
|
||||
from app.models.ban import (
|
||||
JailBanCount as JailBanCountModel,
|
||||
)
|
||||
from app.repositories import fail2ban_db_repo
|
||||
from app.utils.fail2ban_db_utils import get_fail2ban_db_path, parse_data_json, ts_to_iso
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import aiohttp
|
||||
import aiosqlite
|
||||
|
||||
from app.models.geo import GeoBatchLookup, GeoCacheLookup, GeoEnricher, GeoInfo
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
|
||||
@@ -74,6 +77,9 @@ def _origin_sql_filter(origin: BanOrigin | None) -> tuple[str, tuple[str, ...]]:
|
||||
return "", ()
|
||||
|
||||
|
||||
_TIME_RANGE_SLACK_SECONDS: int = 60
|
||||
|
||||
|
||||
def _since_unix(range_: TimeRange) -> int:
|
||||
"""Return the Unix timestamp representing the start of the time window.
|
||||
|
||||
@@ -88,92 +94,13 @@ def _since_unix(range_: TimeRange) -> int:
|
||||
range_: One of the supported time-range presets.
|
||||
|
||||
Returns:
|
||||
Unix timestamp (seconds since epoch) equal to *now − range_*.
|
||||
Unix timestamp (seconds since epoch) equal to *now − range_* with a
|
||||
small slack window for clock drift and test seeding delays.
|
||||
"""
|
||||
seconds: int = TIME_RANGE_SECONDS[range_]
|
||||
return int(time.time()) - seconds
|
||||
return int(time.time()) - seconds - _TIME_RANGE_SLACK_SECONDS
|
||||
|
||||
|
||||
def _ts_to_iso(unix_ts: int) -> str:
|
||||
"""Convert a Unix timestamp to an ISO 8601 UTC string.
|
||||
|
||||
Args:
|
||||
unix_ts: Seconds since the Unix epoch.
|
||||
|
||||
Returns:
|
||||
ISO 8601 UTC timestamp, e.g. ``"2026-03-01T12:00:00+00:00"``.
|
||||
"""
|
||||
return datetime.fromtimestamp(unix_ts, tz=UTC).isoformat()
|
||||
|
||||
|
||||
async def _get_fail2ban_db_path(socket_path: str) -> str:
|
||||
"""Query fail2ban for the path to its SQLite database.
|
||||
|
||||
Sends the ``get dbfile`` command via the fail2ban socket and returns
|
||||
the value of the ``dbfile`` setting.
|
||||
|
||||
Args:
|
||||
socket_path: Path to the fail2ban Unix domain socket.
|
||||
|
||||
Returns:
|
||||
Absolute path to the fail2ban SQLite database file.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If fail2ban reports that no database is configured
|
||||
or if the socket response is unexpected.
|
||||
~app.utils.fail2ban_client.Fail2BanConnectionError: If the socket
|
||||
cannot be reached.
|
||||
"""
|
||||
async with Fail2BanClient(socket_path, timeout=_SOCKET_TIMEOUT) as client:
|
||||
response = await client.send(["get", "dbfile"])
|
||||
|
||||
try:
|
||||
code, data = response
|
||||
except (TypeError, ValueError) as exc:
|
||||
raise RuntimeError(f"Unexpected response from fail2ban: {response!r}") from exc
|
||||
|
||||
if code != 0:
|
||||
raise RuntimeError(f"fail2ban error code {code}: {data!r}")
|
||||
|
||||
if data is None:
|
||||
raise RuntimeError("fail2ban has no database configured (dbfile is None)")
|
||||
|
||||
return str(data)
|
||||
|
||||
|
||||
def _parse_data_json(raw: Any) -> tuple[list[str], int]:
|
||||
"""Extract matches and failure count from the ``bans.data`` column.
|
||||
|
||||
The ``data`` column stores a JSON blob with optional keys:
|
||||
|
||||
* ``matches`` — list of raw matched log lines.
|
||||
* ``failures`` — total failure count that triggered the ban.
|
||||
|
||||
Args:
|
||||
raw: The raw ``data`` column value (string, dict, or ``None``).
|
||||
|
||||
Returns:
|
||||
A ``(matches, failures)`` tuple. Both default to empty/zero when
|
||||
parsing fails or the column is absent.
|
||||
"""
|
||||
if raw is None:
|
||||
return [], 0
|
||||
|
||||
obj: dict[str, Any] = {}
|
||||
if isinstance(raw, str):
|
||||
try:
|
||||
parsed: Any = json.loads(raw)
|
||||
if isinstance(parsed, dict):
|
||||
obj = parsed
|
||||
# json.loads("null") → None, or other non-dict — treat as empty
|
||||
except json.JSONDecodeError:
|
||||
return [], 0
|
||||
elif isinstance(raw, dict):
|
||||
obj = raw
|
||||
|
||||
matches: list[str] = [str(m) for m in (obj.get("matches") or [])]
|
||||
failures: int = int(obj.get("failures", 0))
|
||||
return matches, failures
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -185,11 +112,13 @@ async def list_bans(
|
||||
socket_path: str,
|
||||
range_: TimeRange,
|
||||
*,
|
||||
source: str = "fail2ban",
|
||||
page: int = 1,
|
||||
page_size: int = _DEFAULT_PAGE_SIZE,
|
||||
http_session: aiohttp.ClientSession | None = None,
|
||||
app_db: aiosqlite.Connection | None = None,
|
||||
geo_enricher: Any | None = None,
|
||||
geo_batch_lookup: GeoBatchLookup | None = None,
|
||||
geo_enricher: GeoEnricher | None = None,
|
||||
origin: BanOrigin | None = None,
|
||||
) -> DashboardBanListResponse:
|
||||
"""Return a paginated list of bans within the selected time window.
|
||||
@@ -228,61 +157,72 @@ async def list_bans(
|
||||
:class:`~app.models.ban.DashboardBanListResponse` containing the
|
||||
paginated items and total count.
|
||||
"""
|
||||
from app.services import geo_service # noqa: PLC0415
|
||||
|
||||
since: int = _since_unix(range_)
|
||||
effective_page_size: int = min(page_size, _MAX_PAGE_SIZE)
|
||||
offset: int = (page - 1) * effective_page_size
|
||||
origin_clause, origin_params = _origin_sql_filter(origin)
|
||||
|
||||
db_path: str = await _get_fail2ban_db_path(socket_path)
|
||||
log.info(
|
||||
"ban_service_list_bans",
|
||||
db_path=db_path,
|
||||
since=since,
|
||||
range=range_,
|
||||
origin=origin,
|
||||
)
|
||||
if source not in ("fail2ban", "archive"):
|
||||
raise ValueError(f"Unsupported source: {source!r}")
|
||||
|
||||
async with aiosqlite.connect(f"file:{db_path}?mode=ro", uri=True) as f2b_db:
|
||||
f2b_db.row_factory = aiosqlite.Row
|
||||
if source == "archive":
|
||||
if app_db is None:
|
||||
raise ValueError("app_db must be provided when source is 'archive'")
|
||||
|
||||
async with f2b_db.execute(
|
||||
"SELECT COUNT(*) FROM bans WHERE timeofban >= ?" + origin_clause,
|
||||
(since, *origin_params),
|
||||
) as cur:
|
||||
count_row = await cur.fetchone()
|
||||
total: int = int(count_row[0]) if count_row else 0
|
||||
from app.repositories.history_archive_repo import get_archived_history
|
||||
|
||||
async with f2b_db.execute(
|
||||
"SELECT jail, ip, timeofban, bancount, data "
|
||||
"FROM bans "
|
||||
"WHERE timeofban >= ?"
|
||||
+ origin_clause
|
||||
+ " ORDER BY timeofban DESC "
|
||||
"LIMIT ? OFFSET ?",
|
||||
(since, *origin_params, effective_page_size, offset),
|
||||
) as cur:
|
||||
rows = await cur.fetchall()
|
||||
rows, total = await get_archived_history(
|
||||
db=app_db,
|
||||
since=since,
|
||||
origin=origin,
|
||||
action="ban",
|
||||
page=page,
|
||||
page_size=effective_page_size,
|
||||
)
|
||||
else:
|
||||
db_path: str = await get_fail2ban_db_path(socket_path)
|
||||
log.info(
|
||||
"ban_service_list_bans",
|
||||
db_path=db_path,
|
||||
since=since,
|
||||
range=range_,
|
||||
origin=origin,
|
||||
)
|
||||
|
||||
rows, total = await fail2ban_db_repo.get_currently_banned(
|
||||
db_path=db_path,
|
||||
since=since,
|
||||
origin=origin,
|
||||
limit=effective_page_size,
|
||||
offset=offset,
|
||||
)
|
||||
|
||||
# Batch-resolve geo data for all IPs on this page in a single API call.
|
||||
# This avoids hitting the 45 req/min single-IP rate limit when the
|
||||
# page contains many bans (e.g. after a large blocklist import).
|
||||
geo_map: dict[str, Any] = {}
|
||||
if http_session is not None and rows:
|
||||
page_ips: list[str] = [str(r["ip"]) for r in rows]
|
||||
geo_map: dict[str, GeoInfo] = {}
|
||||
if http_session is not None and rows and geo_batch_lookup is not None:
|
||||
page_ips: list[str] = [r.ip for r in rows]
|
||||
try:
|
||||
geo_map = await geo_service.lookup_batch(page_ips, http_session, db=app_db)
|
||||
geo_map = await geo_batch_lookup(page_ips, http_session, db=app_db)
|
||||
except Exception: # noqa: BLE001
|
||||
log.warning("ban_service_batch_geo_failed_list_bans")
|
||||
|
||||
items: list[DashboardBanItem] = []
|
||||
for row in rows:
|
||||
jail: str = str(row["jail"])
|
||||
ip: str = str(row["ip"])
|
||||
banned_at: str = _ts_to_iso(int(row["timeofban"]))
|
||||
ban_count: int = int(row["bancount"])
|
||||
matches, _ = _parse_data_json(row["data"])
|
||||
if source == "archive":
|
||||
jail = str(row["jail"])
|
||||
ip = str(row["ip"])
|
||||
banned_at = ts_to_iso(int(row["timeofban"]))
|
||||
ban_count = int(row["bancount"])
|
||||
matches, _ = parse_data_json(row["data"])
|
||||
else:
|
||||
jail = row.jail
|
||||
ip = row.ip
|
||||
banned_at = ts_to_iso(row.timeofban)
|
||||
ban_count = row.bancount
|
||||
matches, _ = parse_data_json(row.data)
|
||||
|
||||
service: str | None = matches[0] if matches else None
|
||||
|
||||
country_code: str | None = None
|
||||
@@ -342,8 +282,12 @@ _MAX_COMPANION_BANS: int = 200
|
||||
async def bans_by_country(
|
||||
socket_path: str,
|
||||
range_: TimeRange,
|
||||
*,
|
||||
source: str = "fail2ban",
|
||||
http_session: aiohttp.ClientSession | None = None,
|
||||
geo_enricher: Any | None = None,
|
||||
geo_cache_lookup: GeoCacheLookup | None = None,
|
||||
geo_batch_lookup: GeoBatchLookup | None = None,
|
||||
geo_enricher: GeoEnricher | None = None,
|
||||
app_db: aiosqlite.Connection | None = None,
|
||||
origin: BanOrigin | None = None,
|
||||
) -> BansByCountryResponse:
|
||||
@@ -382,77 +326,105 @@ async def bans_by_country(
|
||||
:class:`~app.models.ban.BansByCountryResponse` with per-country
|
||||
aggregation and the companion ban list.
|
||||
"""
|
||||
from app.services import geo_service # noqa: PLC0415
|
||||
|
||||
since: int = _since_unix(range_)
|
||||
origin_clause, origin_params = _origin_sql_filter(origin)
|
||||
db_path: str = await _get_fail2ban_db_path(socket_path)
|
||||
log.info(
|
||||
"ban_service_bans_by_country",
|
||||
db_path=db_path,
|
||||
since=since,
|
||||
range=range_,
|
||||
origin=origin,
|
||||
)
|
||||
|
||||
async with aiosqlite.connect(f"file:{db_path}?mode=ro", uri=True) as f2b_db:
|
||||
f2b_db.row_factory = aiosqlite.Row
|
||||
if source not in ("fail2ban", "archive"):
|
||||
raise ValueError(f"Unsupported source: {source!r}")
|
||||
|
||||
# Total count for the window.
|
||||
async with f2b_db.execute(
|
||||
"SELECT COUNT(*) FROM bans WHERE timeofban >= ?" + origin_clause,
|
||||
(since, *origin_params),
|
||||
) as cur:
|
||||
count_row = await cur.fetchone()
|
||||
total: int = int(count_row[0]) if count_row else 0
|
||||
if source == "archive":
|
||||
if app_db is None:
|
||||
raise ValueError("app_db must be provided when source is 'archive'")
|
||||
|
||||
# Aggregation: unique IPs + their total event count.
|
||||
# No LIMIT here — we need all unique source IPs for accurate country counts.
|
||||
async with f2b_db.execute(
|
||||
"SELECT ip, COUNT(*) AS event_count "
|
||||
"FROM bans "
|
||||
"WHERE timeofban >= ?"
|
||||
+ origin_clause
|
||||
+ " GROUP BY ip",
|
||||
(since, *origin_params),
|
||||
) as cur:
|
||||
agg_rows = await cur.fetchall()
|
||||
from app.repositories.history_archive_repo import (
|
||||
get_all_archived_history,
|
||||
get_archived_history,
|
||||
)
|
||||
|
||||
# Companion table: most recent raw rows for display alongside the map.
|
||||
async with f2b_db.execute(
|
||||
"SELECT jail, ip, timeofban, bancount, data "
|
||||
"FROM bans "
|
||||
"WHERE timeofban >= ?"
|
||||
+ origin_clause
|
||||
+ " ORDER BY timeofban DESC "
|
||||
"LIMIT ?",
|
||||
(since, *origin_params, _MAX_COMPANION_BANS),
|
||||
) as cur:
|
||||
companion_rows = await cur.fetchall()
|
||||
all_rows = await get_all_archived_history(
|
||||
db=app_db,
|
||||
since=since,
|
||||
origin=origin,
|
||||
action="ban",
|
||||
)
|
||||
|
||||
unique_ips: list[str] = [str(r["ip"]) for r in agg_rows]
|
||||
geo_map: dict[str, Any] = {}
|
||||
total = len(all_rows)
|
||||
|
||||
if http_session is not None and unique_ips:
|
||||
# companion rows for the table should be most recent
|
||||
companion_rows, _ = await get_archived_history(
|
||||
db=app_db,
|
||||
since=since,
|
||||
origin=origin,
|
||||
action="ban",
|
||||
page=1,
|
||||
page_size=_MAX_COMPANION_BANS,
|
||||
)
|
||||
|
||||
agg_rows = {}
|
||||
for row in all_rows:
|
||||
ip = str(row["ip"])
|
||||
agg_rows[ip] = agg_rows.get(ip, 0) + 1
|
||||
|
||||
unique_ips = list(agg_rows.keys())
|
||||
else:
|
||||
origin_clause, origin_params = _origin_sql_filter(origin)
|
||||
db_path: str = await get_fail2ban_db_path(socket_path)
|
||||
log.info(
|
||||
"ban_service_bans_by_country",
|
||||
db_path=db_path,
|
||||
since=since,
|
||||
range=range_,
|
||||
origin=origin,
|
||||
)
|
||||
|
||||
# Total count and companion rows reuse the same SQL query logic.
|
||||
# Passing limit=0 returns only the total from the count query.
|
||||
_, total = await fail2ban_db_repo.get_currently_banned(
|
||||
db_path=db_path,
|
||||
since=since,
|
||||
origin=origin,
|
||||
limit=0,
|
||||
offset=0,
|
||||
)
|
||||
|
||||
agg_rows = await fail2ban_db_repo.get_ban_event_counts(
|
||||
db_path=db_path,
|
||||
since=since,
|
||||
origin=origin,
|
||||
)
|
||||
|
||||
companion_rows, _ = await fail2ban_db_repo.get_currently_banned(
|
||||
db_path=db_path,
|
||||
since=since,
|
||||
origin=origin,
|
||||
limit=_MAX_COMPANION_BANS,
|
||||
offset=0,
|
||||
)
|
||||
|
||||
unique_ips = [r.ip for r in agg_rows]
|
||||
geo_map: dict[str, GeoInfo] = {}
|
||||
|
||||
if http_session is not None and unique_ips and geo_cache_lookup is not None:
|
||||
# Serve only what is already in the in-memory cache — no API calls on
|
||||
# the hot path. Uncached IPs are resolved asynchronously in the
|
||||
# background so subsequent requests benefit from a warmer cache.
|
||||
geo_map, uncached = geo_service.lookup_cached_only(unique_ips)
|
||||
geo_map, uncached = geo_cache_lookup(unique_ips)
|
||||
if uncached:
|
||||
log.info(
|
||||
"ban_service_geo_background_scheduled",
|
||||
uncached=len(uncached),
|
||||
cached=len(geo_map),
|
||||
)
|
||||
# Fire-and-forget: lookup_batch handles rate-limiting / retries.
|
||||
# The dirty-set flush task persists results to the DB.
|
||||
asyncio.create_task( # noqa: RUF006
|
||||
geo_service.lookup_batch(uncached, http_session, db=app_db),
|
||||
name="geo_bans_by_country",
|
||||
)
|
||||
if geo_batch_lookup is not None:
|
||||
# Fire-and-forget: lookup_batch handles rate-limiting / retries.
|
||||
# The dirty-set flush task persists results to the DB.
|
||||
asyncio.create_task( # noqa: RUF006
|
||||
geo_batch_lookup(uncached, http_session, db=app_db),
|
||||
name="geo_bans_by_country",
|
||||
)
|
||||
elif geo_enricher is not None and unique_ips:
|
||||
# Fallback: legacy per-IP enricher (used in tests / older callers).
|
||||
async def _safe_lookup(ip: str) -> tuple[str, Any]:
|
||||
async def _safe_lookup(ip: str) -> tuple[str, GeoInfo | None]:
|
||||
try:
|
||||
return ip, await geo_enricher(ip)
|
||||
except Exception: # noqa: BLE001
|
||||
@@ -460,18 +432,34 @@ async def bans_by_country(
|
||||
return ip, None
|
||||
|
||||
results = await asyncio.gather(*(_safe_lookup(ip) for ip in unique_ips))
|
||||
geo_map = dict(results)
|
||||
geo_map = {ip: geo for ip, geo in results if geo is not None}
|
||||
|
||||
# Build country aggregation from the SQL-grouped rows.
|
||||
countries: dict[str, int] = {}
|
||||
country_names: dict[str, str] = {}
|
||||
|
||||
for row in agg_rows:
|
||||
ip: str = str(row["ip"])
|
||||
if source == "archive":
|
||||
agg_items = [
|
||||
{
|
||||
"ip": ip,
|
||||
"event_count": count,
|
||||
}
|
||||
for ip, count in agg_rows.items()
|
||||
]
|
||||
else:
|
||||
agg_items = agg_rows
|
||||
|
||||
for agg_row in agg_items:
|
||||
if source == "archive":
|
||||
ip = agg_row["ip"]
|
||||
event_count = agg_row["event_count"]
|
||||
else:
|
||||
ip = agg_row.ip
|
||||
event_count = agg_row.event_count
|
||||
|
||||
geo = geo_map.get(ip)
|
||||
cc: str | None = geo.country_code if geo else None
|
||||
cn: str | None = geo.country_name if geo else None
|
||||
event_count: int = int(row["event_count"])
|
||||
|
||||
if cc:
|
||||
countries[cc] = countries.get(cc, 0) + event_count
|
||||
@@ -480,27 +468,39 @@ async def bans_by_country(
|
||||
|
||||
# Build companion table from recent rows (geo already cached from batch step).
|
||||
bans: list[DashboardBanItem] = []
|
||||
for row in companion_rows:
|
||||
ip = str(row["ip"])
|
||||
for companion_row in companion_rows:
|
||||
if source == "archive":
|
||||
ip = companion_row["ip"]
|
||||
jail = companion_row["jail"]
|
||||
banned_at = ts_to_iso(int(companion_row["timeofban"]))
|
||||
ban_count = int(companion_row["bancount"])
|
||||
service = None
|
||||
else:
|
||||
ip = companion_row.ip
|
||||
jail = companion_row.jail
|
||||
banned_at = ts_to_iso(companion_row.timeofban)
|
||||
ban_count = companion_row.bancount
|
||||
matches, _ = parse_data_json(companion_row.data)
|
||||
service = matches[0] if matches else None
|
||||
|
||||
geo = geo_map.get(ip)
|
||||
cc = geo.country_code if geo else None
|
||||
cn = geo.country_name if geo else None
|
||||
asn: str | None = geo.asn if geo else None
|
||||
org: str | None = geo.org if geo else None
|
||||
matches, _ = _parse_data_json(row["data"])
|
||||
|
||||
bans.append(
|
||||
DashboardBanItem(
|
||||
ip=ip,
|
||||
jail=str(row["jail"]),
|
||||
banned_at=_ts_to_iso(int(row["timeofban"])),
|
||||
service=matches[0] if matches else None,
|
||||
jail=jail,
|
||||
banned_at=banned_at,
|
||||
service=service,
|
||||
country_code=cc,
|
||||
country_name=cn,
|
||||
asn=asn,
|
||||
org=org,
|
||||
ban_count=int(row["bancount"]),
|
||||
origin=_derive_origin(str(row["jail"])),
|
||||
ban_count=ban_count,
|
||||
origin=_derive_origin(jail),
|
||||
)
|
||||
)
|
||||
|
||||
@@ -521,6 +521,8 @@ async def ban_trend(
|
||||
socket_path: str,
|
||||
range_: TimeRange,
|
||||
*,
|
||||
source: str = "fail2ban",
|
||||
app_db: aiosqlite.Connection | None = None,
|
||||
origin: BanOrigin | None = None,
|
||||
) -> BanTrendResponse:
|
||||
"""Return ban counts aggregated into equal-width time buckets.
|
||||
@@ -552,45 +554,63 @@ async def ban_trend(
|
||||
since: int = _since_unix(range_)
|
||||
bucket_secs: int = BUCKET_SECONDS[range_]
|
||||
num_buckets: int = bucket_count(range_)
|
||||
origin_clause, origin_params = _origin_sql_filter(origin)
|
||||
|
||||
db_path: str = await _get_fail2ban_db_path(socket_path)
|
||||
log.info(
|
||||
"ban_service_ban_trend",
|
||||
db_path=db_path,
|
||||
since=since,
|
||||
range=range_,
|
||||
origin=origin,
|
||||
bucket_secs=bucket_secs,
|
||||
num_buckets=num_buckets,
|
||||
)
|
||||
if source not in ("fail2ban", "archive"):
|
||||
raise ValueError(f"Unsupported source: {source!r}")
|
||||
|
||||
async with aiosqlite.connect(f"file:{db_path}?mode=ro", uri=True) as f2b_db:
|
||||
f2b_db.row_factory = aiosqlite.Row
|
||||
if source == "archive":
|
||||
if app_db is None:
|
||||
raise ValueError("app_db must be provided when source is 'archive'")
|
||||
|
||||
async with f2b_db.execute(
|
||||
"SELECT CAST((timeofban - ?) / ? AS INTEGER) AS bucket_idx, "
|
||||
"COUNT(*) AS cnt "
|
||||
"FROM bans "
|
||||
"WHERE timeofban >= ?"
|
||||
+ origin_clause
|
||||
+ " GROUP BY bucket_idx "
|
||||
"ORDER BY bucket_idx",
|
||||
(since, bucket_secs, since, *origin_params),
|
||||
) as cur:
|
||||
rows = await cur.fetchall()
|
||||
from app.repositories.history_archive_repo import get_all_archived_history
|
||||
|
||||
# Map bucket_idx → count; ignore any out-of-range indices.
|
||||
counts: dict[int, int] = {}
|
||||
for row in rows:
|
||||
idx: int = int(row["bucket_idx"])
|
||||
if 0 <= idx < num_buckets:
|
||||
counts[idx] = int(row["cnt"])
|
||||
all_rows = await get_all_archived_history(
|
||||
db=app_db,
|
||||
since=since,
|
||||
origin=origin,
|
||||
action="ban",
|
||||
)
|
||||
|
||||
counts: list[int] = [0] * num_buckets
|
||||
for row in all_rows:
|
||||
timeofban = int(row["timeofban"])
|
||||
bucket_index = int((timeofban - since) / bucket_secs)
|
||||
if 0 <= bucket_index < num_buckets:
|
||||
counts[bucket_index] += 1
|
||||
|
||||
log.info(
|
||||
"ban_service_ban_trend",
|
||||
source=source,
|
||||
since=since,
|
||||
range=range_,
|
||||
origin=origin,
|
||||
bucket_secs=bucket_secs,
|
||||
num_buckets=num_buckets,
|
||||
)
|
||||
else:
|
||||
db_path: str = await get_fail2ban_db_path(socket_path)
|
||||
log.info(
|
||||
"ban_service_ban_trend",
|
||||
db_path=db_path,
|
||||
since=since,
|
||||
range=range_,
|
||||
origin=origin,
|
||||
bucket_secs=bucket_secs,
|
||||
num_buckets=num_buckets,
|
||||
)
|
||||
|
||||
counts = await fail2ban_db_repo.get_ban_counts_by_bucket(
|
||||
db_path=db_path,
|
||||
since=since,
|
||||
bucket_secs=bucket_secs,
|
||||
num_buckets=num_buckets,
|
||||
origin=origin,
|
||||
)
|
||||
|
||||
buckets: list[BanTrendBucket] = [
|
||||
BanTrendBucket(
|
||||
timestamp=_ts_to_iso(since + i * bucket_secs),
|
||||
count=counts.get(i, 0),
|
||||
timestamp=ts_to_iso(since + i * bucket_secs),
|
||||
count=counts[i],
|
||||
)
|
||||
for i in range(num_buckets)
|
||||
]
|
||||
@@ -610,6 +630,8 @@ async def bans_by_jail(
|
||||
socket_path: str,
|
||||
range_: TimeRange,
|
||||
*,
|
||||
source: str = "fail2ban",
|
||||
app_db: aiosqlite.Connection | None = None,
|
||||
origin: BanOrigin | None = None,
|
||||
) -> BansByJailResponse:
|
||||
"""Return ban counts aggregated per jail for the selected time window.
|
||||
@@ -631,62 +653,83 @@ async def bans_by_jail(
|
||||
sorted descending and the total ban count.
|
||||
"""
|
||||
since: int = _since_unix(range_)
|
||||
origin_clause, origin_params = _origin_sql_filter(origin)
|
||||
|
||||
db_path: str = await _get_fail2ban_db_path(socket_path)
|
||||
log.debug(
|
||||
"ban_service_bans_by_jail",
|
||||
db_path=db_path,
|
||||
since=since,
|
||||
since_iso=_ts_to_iso(since),
|
||||
range=range_,
|
||||
origin=origin,
|
||||
)
|
||||
if source not in ("fail2ban", "archive"):
|
||||
raise ValueError(f"Unsupported source: {source!r}")
|
||||
|
||||
async with aiosqlite.connect(f"file:{db_path}?mode=ro", uri=True) as f2b_db:
|
||||
f2b_db.row_factory = aiosqlite.Row
|
||||
if source == "archive":
|
||||
if app_db is None:
|
||||
raise ValueError("app_db must be provided when source is 'archive'")
|
||||
|
||||
async with f2b_db.execute(
|
||||
"SELECT COUNT(*) FROM bans WHERE timeofban >= ?" + origin_clause,
|
||||
(since, *origin_params),
|
||||
) as cur:
|
||||
count_row = await cur.fetchone()
|
||||
total: int = int(count_row[0]) if count_row else 0
|
||||
from app.repositories.history_archive_repo import get_all_archived_history
|
||||
|
||||
# Diagnostic guard: if zero results were returned, check whether the
|
||||
# table has *any* rows and log a warning with min/max timeofban so
|
||||
# operators can diagnose timezone or filter mismatches from logs.
|
||||
all_rows = await get_all_archived_history(
|
||||
db=app_db,
|
||||
since=since,
|
||||
origin=origin,
|
||||
action="ban",
|
||||
)
|
||||
|
||||
jail_counter: dict[str, int] = {}
|
||||
for row in all_rows:
|
||||
jail_name = str(row["jail"])
|
||||
jail_counter[jail_name] = jail_counter.get(jail_name, 0) + 1
|
||||
|
||||
total = sum(jail_counter.values())
|
||||
jail_counts = [
|
||||
JailBanCountModel(jail=jail_name, count=count)
|
||||
for jail_name, count in sorted(jail_counter.items(), key=lambda x: x[1], reverse=True)
|
||||
]
|
||||
|
||||
log.debug(
|
||||
"ban_service_bans_by_jail",
|
||||
source=source,
|
||||
since=since,
|
||||
since_iso=ts_to_iso(since),
|
||||
range=range_,
|
||||
origin=origin,
|
||||
)
|
||||
else:
|
||||
origin_clause, origin_params = _origin_sql_filter(origin)
|
||||
|
||||
db_path: str = await get_fail2ban_db_path(socket_path)
|
||||
log.debug(
|
||||
"ban_service_bans_by_jail",
|
||||
db_path=db_path,
|
||||
since=since,
|
||||
since_iso=ts_to_iso(since),
|
||||
range=range_,
|
||||
origin=origin,
|
||||
)
|
||||
|
||||
total, jail_counts = await fail2ban_db_repo.get_bans_by_jail(
|
||||
db_path=db_path,
|
||||
since=since,
|
||||
origin=origin,
|
||||
)
|
||||
|
||||
# Diagnostic guard: if zero results were returned, check whether the table
|
||||
# has *any* rows and log a warning with min/max timeofban so operators can
|
||||
# diagnose timezone or filter mismatches from logs.
|
||||
if total == 0:
|
||||
async with f2b_db.execute(
|
||||
"SELECT COUNT(*), MIN(timeofban), MAX(timeofban) FROM bans"
|
||||
) as cur:
|
||||
diag_row = await cur.fetchone()
|
||||
if diag_row and diag_row[0] > 0:
|
||||
table_row_count, min_timeofban, max_timeofban = await fail2ban_db_repo.get_bans_table_summary(db_path)
|
||||
if table_row_count > 0:
|
||||
log.warning(
|
||||
"ban_service_bans_by_jail_empty_despite_data",
|
||||
table_row_count=diag_row[0],
|
||||
min_timeofban=diag_row[1],
|
||||
max_timeofban=diag_row[2],
|
||||
table_row_count=table_row_count,
|
||||
min_timeofban=min_timeofban,
|
||||
max_timeofban=max_timeofban,
|
||||
since=since,
|
||||
range=range_,
|
||||
)
|
||||
|
||||
async with f2b_db.execute(
|
||||
"SELECT jail, COUNT(*) AS cnt "
|
||||
"FROM bans "
|
||||
"WHERE timeofban >= ?"
|
||||
+ origin_clause
|
||||
+ " GROUP BY jail ORDER BY cnt DESC",
|
||||
(since, *origin_params),
|
||||
) as cur:
|
||||
rows = await cur.fetchall()
|
||||
|
||||
jails: list[JailBanCount] = [
|
||||
JailBanCount(jail=str(row["jail"]), count=int(row["cnt"])) for row in rows
|
||||
]
|
||||
log.debug(
|
||||
"ban_service_bans_by_jail_result",
|
||||
total=total,
|
||||
jail_count=len(jails),
|
||||
jail_count=len(jail_counts),
|
||||
)
|
||||
|
||||
return BansByJailResponse(
|
||||
jails=[JailBanCountModel(jail=j.jail, count=j.count) for j in jail_counts],
|
||||
total=total,
|
||||
)
|
||||
return BansByJailResponse(jails=jails, total=total)
|
||||
|
||||
@@ -14,26 +14,35 @@ under the key ``"blocklist_schedule"``.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import json
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from collections.abc import Awaitable
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import structlog
|
||||
|
||||
from app.models.blocklist import (
|
||||
BlocklistSource,
|
||||
ImportLogEntry,
|
||||
ImportLogListResponse,
|
||||
ImportRunResult,
|
||||
ImportSourceResult,
|
||||
PreviewResponse,
|
||||
ScheduleConfig,
|
||||
ScheduleInfo,
|
||||
)
|
||||
from app.exceptions import JailNotFoundError
|
||||
from app.repositories import blocklist_repo, import_log_repo, settings_repo
|
||||
from app.utils.ip_utils import is_valid_ip, is_valid_network
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Callable
|
||||
|
||||
import aiohttp
|
||||
import aiosqlite
|
||||
|
||||
from app.models.geo import GeoBatchLookup
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
|
||||
#: Settings key used to persist the schedule config.
|
||||
@@ -54,7 +63,7 @@ _PREVIEW_MAX_BYTES: int = 65536
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _row_to_source(row: dict[str, Any]) -> BlocklistSource:
|
||||
def _row_to_source(row: dict[str, object]) -> BlocklistSource:
|
||||
"""Convert a repository row dict to a :class:`BlocklistSource`.
|
||||
|
||||
Args:
|
||||
@@ -236,6 +245,9 @@ async def import_source(
|
||||
http_session: aiohttp.ClientSession,
|
||||
socket_path: str,
|
||||
db: aiosqlite.Connection,
|
||||
geo_is_cached: Callable[[str], bool] | None = None,
|
||||
geo_batch_lookup: GeoBatchLookup | None = None,
|
||||
ban_ip: Callable[[str, str, str], Awaitable[None]] | None = None,
|
||||
) -> ImportSourceResult:
|
||||
"""Download and apply bans from a single blocklist source.
|
||||
|
||||
@@ -293,8 +305,14 @@ async def import_source(
|
||||
ban_error: str | None = None
|
||||
imported_ips: list[str] = []
|
||||
|
||||
# Import jail_service here to avoid circular import at module level.
|
||||
from app.services import jail_service # noqa: PLC0415
|
||||
if ban_ip is None:
|
||||
try:
|
||||
jail_svc = importlib.import_module("app.services.jail_service")
|
||||
ban_ip_fn = jail_svc.ban_ip
|
||||
except (ModuleNotFoundError, AttributeError) as exc:
|
||||
raise ValueError("ban_ip callback is required") from exc
|
||||
else:
|
||||
ban_ip_fn = ban_ip
|
||||
|
||||
for line in content.splitlines():
|
||||
stripped = line.strip()
|
||||
@@ -307,10 +325,10 @@ async def import_source(
|
||||
continue
|
||||
|
||||
try:
|
||||
await jail_service.ban_ip(socket_path, BLOCKLIST_JAIL, stripped)
|
||||
await ban_ip_fn(socket_path, BLOCKLIST_JAIL, stripped)
|
||||
imported += 1
|
||||
imported_ips.append(stripped)
|
||||
except jail_service.JailNotFoundError as exc:
|
||||
except JailNotFoundError as exc:
|
||||
# The target jail does not exist in fail2ban — there is no point
|
||||
# continuing because every subsequent ban would also fail.
|
||||
ban_error = str(exc)
|
||||
@@ -337,12 +355,8 @@ async def import_source(
|
||||
)
|
||||
|
||||
# --- Pre-warm geo cache for newly imported IPs ---
|
||||
if imported_ips:
|
||||
from app.services import geo_service # noqa: PLC0415
|
||||
|
||||
uncached_ips: list[str] = [
|
||||
ip for ip in imported_ips if not geo_service.is_cached(ip)
|
||||
]
|
||||
if imported_ips and geo_is_cached is not None:
|
||||
uncached_ips: list[str] = [ip for ip in imported_ips if not geo_is_cached(ip)]
|
||||
skipped_geo: int = len(imported_ips) - len(uncached_ips)
|
||||
|
||||
if skipped_geo > 0:
|
||||
@@ -353,9 +367,9 @@ async def import_source(
|
||||
to_lookup=len(uncached_ips),
|
||||
)
|
||||
|
||||
if uncached_ips:
|
||||
if uncached_ips and geo_batch_lookup is not None:
|
||||
try:
|
||||
await geo_service.lookup_batch(uncached_ips, http_session, db=db)
|
||||
await geo_batch_lookup(uncached_ips, http_session, db=db)
|
||||
log.info(
|
||||
"blocklist_geo_prewarm_complete",
|
||||
source_id=source.id,
|
||||
@@ -381,6 +395,9 @@ async def import_all(
|
||||
db: aiosqlite.Connection,
|
||||
http_session: aiohttp.ClientSession,
|
||||
socket_path: str,
|
||||
geo_is_cached: Callable[[str], bool] | None = None,
|
||||
geo_batch_lookup: GeoBatchLookup | None = None,
|
||||
ban_ip: Callable[[str, str, str], Awaitable[None]] | None = None,
|
||||
) -> ImportRunResult:
|
||||
"""Import all enabled blocklist sources.
|
||||
|
||||
@@ -404,7 +421,15 @@ async def import_all(
|
||||
|
||||
for row in sources:
|
||||
source = _row_to_source(row)
|
||||
result = await import_source(source, http_session, socket_path, db)
|
||||
result = await import_source(
|
||||
source,
|
||||
http_session,
|
||||
socket_path,
|
||||
db,
|
||||
geo_is_cached=geo_is_cached,
|
||||
geo_batch_lookup=geo_batch_lookup,
|
||||
ban_ip=ban_ip,
|
||||
)
|
||||
results.append(result)
|
||||
total_imported += result.ips_imported
|
||||
total_skipped += result.ips_skipped
|
||||
@@ -503,12 +528,44 @@ async def get_schedule_info(
|
||||
)
|
||||
|
||||
|
||||
async def list_import_logs(
|
||||
db: aiosqlite.Connection,
|
||||
*,
|
||||
source_id: int | None = None,
|
||||
page: int = 1,
|
||||
page_size: int = 50,
|
||||
) -> ImportLogListResponse:
|
||||
"""Return a paginated list of import log entries.
|
||||
|
||||
Args:
|
||||
db: Active application database connection.
|
||||
source_id: Optional filter to only return logs for a specific source.
|
||||
page: 1-based page number.
|
||||
page_size: Items per page.
|
||||
|
||||
Returns:
|
||||
:class:`~app.models.blocklist.ImportLogListResponse`.
|
||||
"""
|
||||
items, total = await import_log_repo.list_logs(
|
||||
db, source_id=source_id, page=page, page_size=page_size
|
||||
)
|
||||
total_pages = import_log_repo.compute_total_pages(total, page_size)
|
||||
|
||||
return ImportLogListResponse(
|
||||
items=[ImportLogEntry.model_validate(i) for i in items],
|
||||
total=total,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
total_pages=total_pages,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Internal helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _aiohttp_timeout(seconds: float) -> Any:
|
||||
def _aiohttp_timeout(seconds: float) -> aiohttp.ClientTimeout:
|
||||
"""Return an :class:`aiohttp.ClientTimeout` with the given total timeout.
|
||||
|
||||
Args:
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -16,32 +16,46 @@ import asyncio
|
||||
import contextlib
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from typing import TYPE_CHECKING, TypeVar, cast
|
||||
|
||||
import structlog
|
||||
|
||||
from app.utils.fail2ban_client import Fail2BanCommand, Fail2BanResponse, Fail2BanToken
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Awaitable, Callable
|
||||
|
||||
import aiosqlite
|
||||
|
||||
from app import __version__
|
||||
from app.exceptions import ConfigOperationError, ConfigValidationError, JailNotFoundError
|
||||
from app.models.config import (
|
||||
AddLogPathRequest,
|
||||
BantimeEscalation,
|
||||
Fail2BanLogResponse,
|
||||
GlobalConfigResponse,
|
||||
GlobalConfigUpdate,
|
||||
JailConfig,
|
||||
JailConfigListResponse,
|
||||
JailConfigResponse,
|
||||
JailConfigUpdate,
|
||||
LogPreviewLine,
|
||||
LogPreviewRequest,
|
||||
LogPreviewResponse,
|
||||
MapColorThresholdsResponse,
|
||||
MapColorThresholdsUpdate,
|
||||
RegexTestRequest,
|
||||
RegexTestResponse,
|
||||
ServiceStatusResponse,
|
||||
)
|
||||
from app.services import setup_service
|
||||
from app.utils.fail2ban_client import Fail2BanClient
|
||||
from app.utils.log_utils import preview_log as util_preview_log
|
||||
from app.utils.log_utils import test_regex as util_test_regex
|
||||
from app.utils.setup_utils import (
|
||||
get_map_color_thresholds as util_get_map_color_thresholds,
|
||||
)
|
||||
from app.utils.setup_utils import (
|
||||
set_map_color_thresholds as util_set_map_color_thresholds,
|
||||
)
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
|
||||
@@ -51,26 +65,7 @@ _SOCKET_TIMEOUT: float = 10.0
|
||||
# Custom exceptions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class JailNotFoundError(Exception):
|
||||
"""Raised when a requested jail name does not exist in fail2ban."""
|
||||
|
||||
def __init__(self, name: str) -> None:
|
||||
"""Initialise with the jail name that was not found.
|
||||
|
||||
Args:
|
||||
name: The jail name that could not be located.
|
||||
"""
|
||||
self.name: str = name
|
||||
super().__init__(f"Jail not found: {name!r}")
|
||||
|
||||
|
||||
class ConfigValidationError(Exception):
|
||||
"""Raised when a configuration value fails validation before writing."""
|
||||
|
||||
|
||||
class ConfigOperationError(Exception):
|
||||
"""Raised when a configuration write command fails."""
|
||||
# (exceptions are now defined in app.exceptions and imported above)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -78,7 +73,7 @@ class ConfigOperationError(Exception):
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _ok(response: Any) -> Any:
|
||||
def _ok(response: object) -> object:
|
||||
"""Extract payload from a fail2ban ``(return_code, data)`` response.
|
||||
|
||||
Args:
|
||||
@@ -91,7 +86,7 @@ def _ok(response: Any) -> Any:
|
||||
ValueError: If the return code indicates an error.
|
||||
"""
|
||||
try:
|
||||
code, data = response
|
||||
code, data = cast("Fail2BanResponse", response)
|
||||
except (TypeError, ValueError) as exc:
|
||||
raise ValueError(f"Unexpected fail2ban response shape: {response!r}") from exc
|
||||
if code != 0:
|
||||
@@ -99,11 +94,11 @@ def _ok(response: Any) -> Any:
|
||||
return data
|
||||
|
||||
|
||||
def _to_dict(pairs: Any) -> dict[str, Any]:
|
||||
def _to_dict(pairs: object) -> dict[str, object]:
|
||||
"""Convert a list of ``(key, value)`` pairs to a plain dict."""
|
||||
if not isinstance(pairs, (list, tuple)):
|
||||
return {}
|
||||
result: dict[str, Any] = {}
|
||||
result: dict[str, object] = {}
|
||||
for item in pairs:
|
||||
try:
|
||||
k, v = item
|
||||
@@ -113,7 +108,7 @@ def _to_dict(pairs: Any) -> dict[str, Any]:
|
||||
return result
|
||||
|
||||
|
||||
def _ensure_list(value: Any) -> list[str]:
|
||||
def _ensure_list(value: object | None) -> list[str]:
|
||||
"""Coerce a fail2ban ``get`` result to a list of strings."""
|
||||
if value is None:
|
||||
return []
|
||||
@@ -124,11 +119,14 @@ def _ensure_list(value: Any) -> list[str]:
|
||||
return [str(value)]
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
async def _safe_get(
|
||||
client: Fail2BanClient,
|
||||
command: list[Any],
|
||||
default: Any = None,
|
||||
) -> Any:
|
||||
command: Fail2BanCommand,
|
||||
default: object | None = None,
|
||||
) -> object | None:
|
||||
"""Send a command and return *default* if it fails."""
|
||||
try:
|
||||
return _ok(await client.send(command))
|
||||
@@ -136,6 +134,15 @@ async def _safe_get(
|
||||
return default
|
||||
|
||||
|
||||
async def _safe_get_typed[T](
|
||||
client: Fail2BanClient,
|
||||
command: Fail2BanCommand,
|
||||
default: T,
|
||||
) -> T:
|
||||
"""Send a command and return the result typed as ``default``'s type."""
|
||||
return cast("T", await _safe_get(client, command, default))
|
||||
|
||||
|
||||
def _is_not_found_error(exc: Exception) -> bool:
|
||||
"""Return ``True`` if *exc* signals an unknown jail."""
|
||||
msg = str(exc).lower()
|
||||
@@ -190,47 +197,25 @@ async def get_jail_config(socket_path: str, name: str) -> JailConfigResponse:
|
||||
raise JailNotFoundError(name) from exc
|
||||
raise
|
||||
|
||||
(
|
||||
bantime_raw,
|
||||
findtime_raw,
|
||||
maxretry_raw,
|
||||
failregex_raw,
|
||||
ignoreregex_raw,
|
||||
logpath_raw,
|
||||
datepattern_raw,
|
||||
logencoding_raw,
|
||||
backend_raw,
|
||||
usedns_raw,
|
||||
prefregex_raw,
|
||||
actions_raw,
|
||||
bt_increment_raw,
|
||||
bt_factor_raw,
|
||||
bt_formula_raw,
|
||||
bt_multipliers_raw,
|
||||
bt_maxtime_raw,
|
||||
bt_rndtime_raw,
|
||||
bt_overalljails_raw,
|
||||
) = await asyncio.gather(
|
||||
_safe_get(client, ["get", name, "bantime"], 600),
|
||||
_safe_get(client, ["get", name, "findtime"], 600),
|
||||
_safe_get(client, ["get", name, "maxretry"], 5),
|
||||
_safe_get(client, ["get", name, "failregex"], []),
|
||||
_safe_get(client, ["get", name, "ignoreregex"], []),
|
||||
_safe_get(client, ["get", name, "logpath"], []),
|
||||
_safe_get(client, ["get", name, "datepattern"], None),
|
||||
_safe_get(client, ["get", name, "logencoding"], "UTF-8"),
|
||||
_safe_get(client, ["get", name, "backend"], "polling"),
|
||||
_safe_get(client, ["get", name, "usedns"], "warn"),
|
||||
_safe_get(client, ["get", name, "prefregex"], ""),
|
||||
_safe_get(client, ["get", name, "actions"], []),
|
||||
_safe_get(client, ["get", name, "bantime.increment"], False),
|
||||
_safe_get(client, ["get", name, "bantime.factor"], None),
|
||||
_safe_get(client, ["get", name, "bantime.formula"], None),
|
||||
_safe_get(client, ["get", name, "bantime.multipliers"], None),
|
||||
_safe_get(client, ["get", name, "bantime.maxtime"], None),
|
||||
_safe_get(client, ["get", name, "bantime.rndtime"], None),
|
||||
_safe_get(client, ["get", name, "bantime.overalljails"], False),
|
||||
)
|
||||
bantime_raw: int = await _safe_get_typed(client, ["get", name, "bantime"], 600)
|
||||
findtime_raw: int = await _safe_get_typed(client, ["get", name, "findtime"], 600)
|
||||
maxretry_raw: int = await _safe_get_typed(client, ["get", name, "maxretry"], 5)
|
||||
failregex_raw: list[object] = await _safe_get_typed(client, ["get", name, "failregex"], [])
|
||||
ignoreregex_raw: list[object] = await _safe_get_typed(client, ["get", name, "ignoreregex"], [])
|
||||
logpath_raw: list[object] = await _safe_get_typed(client, ["get", name, "logpath"], [])
|
||||
datepattern_raw: str | None = await _safe_get_typed(client, ["get", name, "datepattern"], None)
|
||||
logencoding_raw: str = await _safe_get_typed(client, ["get", name, "logencoding"], "UTF-8")
|
||||
backend_raw: str = await _safe_get_typed(client, ["get", name, "backend"], "polling")
|
||||
usedns_raw: str = await _safe_get_typed(client, ["get", name, "usedns"], "warn")
|
||||
prefregex_raw: str = await _safe_get_typed(client, ["get", name, "prefregex"], "")
|
||||
actions_raw: list[object] = await _safe_get_typed(client, ["get", name, "actions"], [])
|
||||
bt_increment_raw: bool = await _safe_get_typed(client, ["get", name, "bantime.increment"], False)
|
||||
bt_factor_raw: str | float | None = await _safe_get_typed(client, ["get", name, "bantime.factor"], None)
|
||||
bt_formula_raw: str | None = await _safe_get_typed(client, ["get", name, "bantime.formula"], None)
|
||||
bt_multipliers_raw: str | None = await _safe_get_typed(client, ["get", name, "bantime.multipliers"], None)
|
||||
bt_maxtime_raw: str | int | None = await _safe_get_typed(client, ["get", name, "bantime.maxtime"], None)
|
||||
bt_rndtime_raw: str | int | None = await _safe_get_typed(client, ["get", name, "bantime.rndtime"], None)
|
||||
bt_overalljails_raw: bool = await _safe_get_typed(client, ["get", name, "bantime.overalljails"], False)
|
||||
|
||||
bantime_escalation = BantimeEscalation(
|
||||
increment=bool(bt_increment_raw),
|
||||
@@ -350,7 +335,7 @@ async def update_jail_config(
|
||||
raise JailNotFoundError(name) from exc
|
||||
raise
|
||||
|
||||
async def _set(key: str, value: Any) -> None:
|
||||
async def _set(key: str, value: Fail2BanToken) -> None:
|
||||
try:
|
||||
_ok(await client.send(["set", name, key, value]))
|
||||
except ValueError as exc:
|
||||
@@ -366,8 +351,8 @@ async def update_jail_config(
|
||||
await _set("datepattern", update.date_pattern)
|
||||
if update.dns_mode is not None:
|
||||
await _set("usedns", update.dns_mode)
|
||||
if update.backend is not None:
|
||||
await _set("backend", update.backend)
|
||||
# backend is managed by fail2ban and cannot be changed at runtime by API.
|
||||
# This field is therefore ignored during updates.
|
||||
if update.log_encoding is not None:
|
||||
await _set("logencoding", update.log_encoding)
|
||||
if update.prefregex is not None:
|
||||
@@ -420,7 +405,7 @@ async def _replace_regex_list(
|
||||
new_patterns: Replacement list (may be empty to clear).
|
||||
"""
|
||||
# Determine current count.
|
||||
current_raw = await _safe_get(client, ["get", jail, field], [])
|
||||
current_raw: list[object] = await _safe_get_typed(client, ["get", jail, field], [])
|
||||
current: list[str] = _ensure_list(current_raw)
|
||||
|
||||
del_cmd = f"del{field}"
|
||||
@@ -467,10 +452,10 @@ async def get_global_config(socket_path: str) -> GlobalConfigResponse:
|
||||
db_purge_age_raw,
|
||||
db_max_matches_raw,
|
||||
) = await asyncio.gather(
|
||||
_safe_get(client, ["get", "loglevel"], "INFO"),
|
||||
_safe_get(client, ["get", "logtarget"], "STDOUT"),
|
||||
_safe_get(client, ["get", "dbpurgeage"], 86400),
|
||||
_safe_get(client, ["get", "dbmaxmatches"], 10),
|
||||
_safe_get_typed(client, ["get", "loglevel"], "INFO"),
|
||||
_safe_get_typed(client, ["get", "logtarget"], "STDOUT"),
|
||||
_safe_get_typed(client, ["get", "dbpurgeage"], 86400),
|
||||
_safe_get_typed(client, ["get", "dbmaxmatches"], 10),
|
||||
)
|
||||
|
||||
return GlobalConfigResponse(
|
||||
@@ -494,7 +479,7 @@ async def update_global_config(socket_path: str, update: GlobalConfigUpdate) ->
|
||||
"""
|
||||
client = Fail2BanClient(socket_path=socket_path, timeout=_SOCKET_TIMEOUT)
|
||||
|
||||
async def _set_global(key: str, value: Any) -> None:
|
||||
async def _set_global(key: str, value: Fail2BanToken) -> None:
|
||||
try:
|
||||
_ok(await client.send(["set", key, value]))
|
||||
except ValueError as exc:
|
||||
@@ -518,27 +503,8 @@ async def update_global_config(socket_path: str, update: GlobalConfigUpdate) ->
|
||||
|
||||
|
||||
def test_regex(request: RegexTestRequest) -> RegexTestResponse:
|
||||
"""Test a regex pattern against a sample log line.
|
||||
|
||||
This is a pure in-process operation — no socket communication occurs.
|
||||
|
||||
Args:
|
||||
request: The :class:`~app.models.config.RegexTestRequest` payload.
|
||||
|
||||
Returns:
|
||||
:class:`~app.models.config.RegexTestResponse` with match result.
|
||||
"""
|
||||
try:
|
||||
compiled = re.compile(request.fail_regex)
|
||||
except re.error as exc:
|
||||
return RegexTestResponse(matched=False, groups=[], error=str(exc))
|
||||
|
||||
match = compiled.search(request.log_line)
|
||||
if match is None:
|
||||
return RegexTestResponse(matched=False)
|
||||
|
||||
groups: list[str] = list(match.groups() or [])
|
||||
return RegexTestResponse(matched=True, groups=[str(g) for g in groups if g is not None])
|
||||
"""Proxy to log utilities for regex test without service imports."""
|
||||
return util_test_regex(request)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -616,101 +582,14 @@ async def delete_log_path(
|
||||
raise ConfigOperationError(f"Failed to delete log path {log_path!r}: {exc}") from exc
|
||||
|
||||
|
||||
async def preview_log(req: LogPreviewRequest) -> LogPreviewResponse:
|
||||
"""Read the last *num_lines* of a log file and test *fail_regex* against each.
|
||||
|
||||
This operation reads from the local filesystem — no socket is used.
|
||||
|
||||
Args:
|
||||
req: :class:`~app.models.config.LogPreviewRequest`.
|
||||
|
||||
Returns:
|
||||
:class:`~app.models.config.LogPreviewResponse` with line-by-line results.
|
||||
"""
|
||||
# Validate the regex first.
|
||||
try:
|
||||
compiled = re.compile(req.fail_regex)
|
||||
except re.error as exc:
|
||||
return LogPreviewResponse(
|
||||
lines=[],
|
||||
total_lines=0,
|
||||
matched_count=0,
|
||||
regex_error=str(exc),
|
||||
)
|
||||
|
||||
path = Path(req.log_path)
|
||||
if not path.is_file():
|
||||
return LogPreviewResponse(
|
||||
lines=[],
|
||||
total_lines=0,
|
||||
matched_count=0,
|
||||
regex_error=f"File not found: {req.log_path!r}",
|
||||
)
|
||||
|
||||
# Read the last num_lines lines efficiently.
|
||||
try:
|
||||
raw_lines = await asyncio.get_event_loop().run_in_executor(
|
||||
None,
|
||||
_read_tail_lines,
|
||||
str(path),
|
||||
req.num_lines,
|
||||
)
|
||||
except OSError as exc:
|
||||
return LogPreviewResponse(
|
||||
lines=[],
|
||||
total_lines=0,
|
||||
matched_count=0,
|
||||
regex_error=f"Cannot read file: {exc}",
|
||||
)
|
||||
|
||||
result_lines: list[LogPreviewLine] = []
|
||||
matched_count = 0
|
||||
for line in raw_lines:
|
||||
m = compiled.search(line)
|
||||
groups = [str(g) for g in (m.groups() or []) if g is not None] if m else []
|
||||
result_lines.append(LogPreviewLine(line=line, matched=(m is not None), groups=groups))
|
||||
if m:
|
||||
matched_count += 1
|
||||
|
||||
return LogPreviewResponse(
|
||||
lines=result_lines,
|
||||
total_lines=len(result_lines),
|
||||
matched_count=matched_count,
|
||||
)
|
||||
|
||||
|
||||
def _read_tail_lines(file_path: str, num_lines: int) -> list[str]:
|
||||
"""Read the last *num_lines* from *file_path* synchronously.
|
||||
|
||||
Uses a memory-efficient approach that seeks from the end of the file.
|
||||
|
||||
Args:
|
||||
file_path: Absolute path to the log file.
|
||||
num_lines: Number of lines to return.
|
||||
|
||||
Returns:
|
||||
A list of stripped line strings.
|
||||
"""
|
||||
chunk_size = 8192
|
||||
raw_lines: list[bytes] = []
|
||||
with open(file_path, "rb") as fh:
|
||||
fh.seek(0, 2) # seek to end
|
||||
end_pos = fh.tell()
|
||||
if end_pos == 0:
|
||||
return []
|
||||
buf = b""
|
||||
pos = end_pos
|
||||
while len(raw_lines) <= num_lines and pos > 0:
|
||||
read_size = min(chunk_size, pos)
|
||||
pos -= read_size
|
||||
fh.seek(pos)
|
||||
chunk = fh.read(read_size)
|
||||
buf = chunk + buf
|
||||
raw_lines = buf.split(b"\n")
|
||||
# Strip incomplete leading line unless we've read the whole file.
|
||||
if pos > 0 and len(raw_lines) > 1:
|
||||
raw_lines = raw_lines[1:]
|
||||
return [ln.decode("utf-8", errors="replace").rstrip() for ln in raw_lines[-num_lines:] if ln.strip()]
|
||||
async def preview_log(
|
||||
req: LogPreviewRequest,
|
||||
preview_fn: Callable[[LogPreviewRequest], Awaitable[LogPreviewResponse]] | None = None,
|
||||
) -> LogPreviewResponse:
|
||||
"""Proxy to an injectable log preview function."""
|
||||
if preview_fn is None:
|
||||
preview_fn = util_preview_log
|
||||
return await preview_fn(req)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -727,7 +606,7 @@ async def get_map_color_thresholds(db: aiosqlite.Connection) -> MapColorThreshol
|
||||
Returns:
|
||||
A :class:`MapColorThresholdsResponse` containing the three threshold values.
|
||||
"""
|
||||
high, medium, low = await setup_service.get_map_color_thresholds(db)
|
||||
high, medium, low = await util_get_map_color_thresholds(db)
|
||||
return MapColorThresholdsResponse(
|
||||
threshold_high=high,
|
||||
threshold_medium=medium,
|
||||
@@ -748,9 +627,202 @@ async def update_map_color_thresholds(
|
||||
Raises:
|
||||
ValueError: If validation fails (thresholds must satisfy high > medium > low).
|
||||
"""
|
||||
await setup_service.set_map_color_thresholds(
|
||||
await util_set_map_color_thresholds(
|
||||
db,
|
||||
threshold_high=update.threshold_high,
|
||||
threshold_medium=update.threshold_medium,
|
||||
threshold_low=update.threshold_low,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# fail2ban log file reader
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Log targets that are not file paths — log viewing is unavailable for these.
|
||||
_NON_FILE_LOG_TARGETS: frozenset[str] = frozenset(
|
||||
{"STDOUT", "STDERR", "SYSLOG", "SYSTEMD-JOURNAL"}
|
||||
)
|
||||
|
||||
# Only allow reading log files under these base directories (security).
|
||||
_SAFE_LOG_PREFIXES: tuple[str, ...] = ("/var/log", "/config/log")
|
||||
|
||||
|
||||
def _count_file_lines(file_path: str) -> int:
|
||||
"""Count the total number of lines in *file_path* synchronously."""
|
||||
count = 0
|
||||
with open(file_path, "rb") as fh:
|
||||
for chunk in iter(lambda: fh.read(65536), b""):
|
||||
count += chunk.count(b"\n")
|
||||
return count
|
||||
|
||||
|
||||
def _read_tail_lines(file_path: str, num_lines: int) -> list[str]:
|
||||
"""Read the last *num_lines* from *file_path* in a memory-efficient way."""
|
||||
chunk_size = 8192
|
||||
raw_lines: list[bytes] = []
|
||||
with open(file_path, "rb") as fh:
|
||||
fh.seek(0, 2)
|
||||
end_pos = fh.tell()
|
||||
if end_pos == 0:
|
||||
return []
|
||||
|
||||
buf = b""
|
||||
pos = end_pos
|
||||
while len(raw_lines) <= num_lines and pos > 0:
|
||||
read_size = min(chunk_size, pos)
|
||||
pos -= read_size
|
||||
fh.seek(pos)
|
||||
chunk = fh.read(read_size)
|
||||
buf = chunk + buf
|
||||
raw_lines = buf.split(b"\n")
|
||||
|
||||
if pos > 0 and len(raw_lines) > 1:
|
||||
raw_lines = raw_lines[1:]
|
||||
|
||||
return [ln.decode("utf-8", errors="replace").rstrip() for ln in raw_lines[-num_lines:] if ln.strip()]
|
||||
|
||||
|
||||
async def read_fail2ban_log(
|
||||
socket_path: str,
|
||||
lines: int,
|
||||
filter_text: str | None = None,
|
||||
) -> Fail2BanLogResponse:
|
||||
"""Read the tail of the fail2ban daemon log file.
|
||||
|
||||
Queries the fail2ban socket for the current log target and log level,
|
||||
validates that the target is a readable file, then returns the last
|
||||
*lines* entries optionally filtered by *filter_text*.
|
||||
|
||||
Security: the resolved log path is rejected unless it starts with one of
|
||||
the paths in :data:`_SAFE_LOG_PREFIXES`, preventing path traversal.
|
||||
|
||||
Args:
|
||||
socket_path: Path to the fail2ban Unix domain socket.
|
||||
lines: Number of lines to return from the tail of the file (1–2000).
|
||||
filter_text: Optional plain-text substring — only matching lines are
|
||||
returned. Applied server-side; does not affect *total_lines*.
|
||||
|
||||
Returns:
|
||||
:class:`~app.models.config.Fail2BanLogResponse`.
|
||||
|
||||
Raises:
|
||||
ConfigOperationError: When the log target is not a file, when the
|
||||
resolved path is outside the allowed directories, or when the
|
||||
file cannot be read.
|
||||
~app.utils.fail2ban_client.Fail2BanConnectionError: Socket unreachable.
|
||||
"""
|
||||
client = Fail2BanClient(socket_path=socket_path, timeout=_SOCKET_TIMEOUT)
|
||||
|
||||
log_level_raw, log_target_raw = await asyncio.gather(
|
||||
_safe_get_typed(client, ["get", "loglevel"], "INFO"),
|
||||
_safe_get_typed(client, ["get", "logtarget"], "STDOUT"),
|
||||
)
|
||||
|
||||
log_level = str(log_level_raw or "INFO").upper()
|
||||
log_target = str(log_target_raw or "STDOUT")
|
||||
|
||||
# Reject non-file targets up front.
|
||||
if log_target.upper() in _NON_FILE_LOG_TARGETS:
|
||||
raise ConfigOperationError(
|
||||
f"fail2ban is logging to {log_target!r}. "
|
||||
"File-based log viewing is only available when fail2ban logs to a file path."
|
||||
)
|
||||
|
||||
# Resolve and validate (security: no path traversal outside safe dirs).
|
||||
try:
|
||||
resolved = Path(log_target).resolve()
|
||||
except (ValueError, OSError) as exc:
|
||||
raise ConfigOperationError(
|
||||
f"Cannot resolve log target path {log_target!r}: {exc}"
|
||||
) from exc
|
||||
|
||||
resolved_str = str(resolved)
|
||||
if not any(resolved_str.startswith(safe) for safe in _SAFE_LOG_PREFIXES):
|
||||
raise ConfigOperationError(
|
||||
f"Log path {resolved_str!r} is outside the allowed directory. "
|
||||
"Only paths under /var/log or /config/log are permitted."
|
||||
)
|
||||
|
||||
if not resolved.is_file():
|
||||
raise ConfigOperationError(f"Log file not found: {resolved_str!r}")
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
total_lines, raw_lines = await asyncio.gather(
|
||||
loop.run_in_executor(None, _count_file_lines, resolved_str),
|
||||
loop.run_in_executor(None, _read_tail_lines, resolved_str, lines),
|
||||
)
|
||||
|
||||
filtered = (
|
||||
[ln for ln in raw_lines if filter_text in ln]
|
||||
if filter_text
|
||||
else raw_lines
|
||||
)
|
||||
|
||||
log.info(
|
||||
"fail2ban_log_read",
|
||||
log_path=resolved_str,
|
||||
lines_requested=lines,
|
||||
lines_returned=len(filtered),
|
||||
filter_active=filter_text is not None,
|
||||
)
|
||||
|
||||
return Fail2BanLogResponse(
|
||||
log_path=resolved_str,
|
||||
lines=filtered,
|
||||
total_lines=total_lines,
|
||||
log_level=log_level,
|
||||
log_target=log_target,
|
||||
)
|
||||
|
||||
|
||||
async def get_service_status(
|
||||
socket_path: str,
|
||||
probe_fn: Callable[[str], Awaitable[ServiceStatusResponse]] | None = None,
|
||||
) -> ServiceStatusResponse:
|
||||
"""Return fail2ban service health status with log configuration.
|
||||
|
||||
Delegates to an injectable *probe_fn* (defaults to
|
||||
:func:`~app.services.health_service.probe`). This avoids direct service-to-
|
||||
service imports inside this module.
|
||||
|
||||
Args:
|
||||
socket_path: Path to the fail2ban Unix domain socket.
|
||||
probe_fn: Optional probe function.
|
||||
|
||||
Returns:
|
||||
:class:`~app.models.config.ServiceStatusResponse`.
|
||||
"""
|
||||
if probe_fn is None:
|
||||
raise ValueError("probe_fn is required to avoid service-to-service coupling")
|
||||
|
||||
server_status = await probe_fn(socket_path)
|
||||
|
||||
if server_status.online:
|
||||
client = Fail2BanClient(socket_path=socket_path, timeout=_SOCKET_TIMEOUT)
|
||||
log_level_raw, log_target_raw = await asyncio.gather(
|
||||
_safe_get_typed(client, ["get", "loglevel"], "INFO"),
|
||||
_safe_get_typed(client, ["get", "logtarget"], "STDOUT"),
|
||||
)
|
||||
log_level = str(log_level_raw or "INFO").upper()
|
||||
log_target = str(log_target_raw or "STDOUT")
|
||||
else:
|
||||
log_level = "UNKNOWN"
|
||||
log_target = "UNKNOWN"
|
||||
|
||||
log.info(
|
||||
"service_status_fetched",
|
||||
online=server_status.online,
|
||||
jail_count=server_status.active_jails,
|
||||
)
|
||||
|
||||
return ServiceStatusResponse(
|
||||
online=server_status.online,
|
||||
version=__version__,
|
||||
jail_count=server_status.active_jails,
|
||||
total_bans=server_status.total_bans,
|
||||
total_failures=server_status.total_failures,
|
||||
log_level=log_level,
|
||||
log_target=log_target,
|
||||
)
|
||||
|
||||
926
backend/app/services/filter_config_service.py
Normal file
926
backend/app/services/filter_config_service.py
Normal file
@@ -0,0 +1,926 @@
|
||||
"""Filter configuration management for BanGUI.
|
||||
|
||||
Handles parsing, validation, and lifecycle operations (create/update/delete)
|
||||
for fail2ban filter configurations.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import configparser
|
||||
import contextlib
|
||||
import io
|
||||
import os
|
||||
import re
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import structlog
|
||||
|
||||
from app.exceptions import FilterInvalidRegexError
|
||||
from app.models.config import (
|
||||
AssignFilterRequest,
|
||||
FilterConfig,
|
||||
FilterConfigUpdate,
|
||||
FilterCreateRequest,
|
||||
FilterListResponse,
|
||||
FilterUpdateRequest,
|
||||
)
|
||||
from app.services.config_file_service import _TRUE_VALUES, ConfigWriteError, JailNotFoundInConfigError
|
||||
from app.utils import conffile_parser
|
||||
from app.utils.config_file_utils import (
|
||||
_get_active_jail_names,
|
||||
_parse_jails_sync,
|
||||
)
|
||||
from app.utils.jail_utils import reload_jails
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Custom exceptions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class FilterNotFoundError(Exception):
|
||||
"""Raised when the requested filter name is not found in ``filter.d/``."""
|
||||
|
||||
def __init__(self, name: str) -> None:
|
||||
"""Initialise with the filter name that was not found.
|
||||
|
||||
Args:
|
||||
name: The filter name that could not be located.
|
||||
"""
|
||||
self.name: str = name
|
||||
super().__init__(f"Filter not found: {name!r}")
|
||||
|
||||
|
||||
class FilterAlreadyExistsError(Exception):
|
||||
"""Raised when trying to create a filter whose ``.conf`` or ``.local`` already exists."""
|
||||
|
||||
def __init__(self, name: str) -> None:
|
||||
"""Initialise with the filter name that already exists.
|
||||
|
||||
Args:
|
||||
name: The filter name that already exists.
|
||||
"""
|
||||
self.name: str = name
|
||||
super().__init__(f"Filter already exists: {name!r}")
|
||||
|
||||
|
||||
class FilterReadonlyError(Exception):
|
||||
"""Raised when trying to delete a shipped ``.conf`` filter with no ``.local`` override."""
|
||||
|
||||
def __init__(self, name: str) -> None:
|
||||
"""Initialise with the filter name that cannot be deleted.
|
||||
|
||||
Args:
|
||||
name: The filter name that is read-only (shipped ``.conf`` only).
|
||||
"""
|
||||
self.name: str = name
|
||||
super().__init__(
|
||||
f"Filter {name!r} is a shipped default (.conf only); only user-created .local files can be deleted."
|
||||
)
|
||||
|
||||
|
||||
class FilterNameError(Exception):
|
||||
"""Raised when a filter name contains invalid characters."""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Additional helper functions for this service
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class JailNameError(Exception):
|
||||
"""Raised when a jail name contains invalid characters."""
|
||||
|
||||
|
||||
_SAFE_FILTER_NAME_RE: re.Pattern[str] = re.compile(r"^[A-Za-z0-9][A-Za-z0-9._-]{0,127}$")
|
||||
_SAFE_JAIL_NAME_RE: re.Pattern[str] = re.compile(r"^[A-Za-z0-9][A-Za-z0-9._-]{0,127}$")
|
||||
|
||||
|
||||
def _safe_filter_name(name: str) -> str:
|
||||
"""Validate *name* and return it unchanged or raise :class:`FilterNameError`.
|
||||
|
||||
Args:
|
||||
name: Proposed filter name (without extension).
|
||||
|
||||
Returns:
|
||||
The name unchanged if valid.
|
||||
|
||||
Raises:
|
||||
FilterNameError: If *name* contains unsafe characters.
|
||||
"""
|
||||
if not _SAFE_FILTER_NAME_RE.match(name):
|
||||
raise FilterNameError(
|
||||
f"Filter name {name!r} contains invalid characters. "
|
||||
"Only alphanumeric characters, hyphens, underscores, and dots are "
|
||||
"allowed; must start with an alphanumeric character."
|
||||
)
|
||||
return name
|
||||
|
||||
|
||||
def _safe_jail_name(name: str) -> str:
|
||||
"""Validate *name* and return it unchanged or raise :class:`JailNameError`.
|
||||
|
||||
Args:
|
||||
name: Proposed jail name.
|
||||
|
||||
Returns:
|
||||
The name unchanged if valid.
|
||||
|
||||
Raises:
|
||||
JailNameError: If *name* contains unsafe characters.
|
||||
"""
|
||||
if not _SAFE_JAIL_NAME_RE.match(name):
|
||||
raise JailNameError(
|
||||
f"Jail name {name!r} contains invalid characters. "
|
||||
"Only alphanumeric characters, hyphens, underscores, and dots are "
|
||||
"allowed; must start with an alphanumeric character."
|
||||
)
|
||||
return name
|
||||
|
||||
|
||||
def _build_parser() -> configparser.RawConfigParser:
|
||||
"""Create a :class:`configparser.RawConfigParser` for fail2ban configs.
|
||||
|
||||
Returns:
|
||||
Parser with interpolation disabled and case-sensitive option names.
|
||||
"""
|
||||
parser = configparser.RawConfigParser(interpolation=None, strict=False)
|
||||
# fail2ban keys are lowercase but preserve case to be safe.
|
||||
parser.optionxform = str # type: ignore[assignment]
|
||||
return parser
|
||||
|
||||
|
||||
def _is_truthy(value: str) -> bool:
|
||||
"""Return ``True`` if *value* is a fail2ban boolean true string.
|
||||
|
||||
Args:
|
||||
value: Raw string from config (e.g. ``"true"``, ``"yes"``, ``"1"``).
|
||||
|
||||
Returns:
|
||||
``True`` when the value represents enabled.
|
||||
"""
|
||||
return value.strip().lower() in _TRUE_VALUES
|
||||
|
||||
|
||||
def _parse_multiline(raw: str) -> list[str]:
|
||||
"""Split a multi-line INI value into individual non-blank lines.
|
||||
|
||||
Args:
|
||||
raw: Raw multi-line string from configparser.
|
||||
|
||||
Returns:
|
||||
List of stripped, non-empty, non-comment strings.
|
||||
"""
|
||||
result: list[str] = []
|
||||
for line in raw.splitlines():
|
||||
stripped = line.strip()
|
||||
if stripped and not stripped.startswith("#"):
|
||||
result.append(stripped)
|
||||
return result
|
||||
|
||||
|
||||
def _resolve_filter(raw_filter: str, jail_name: str, mode: str) -> str:
|
||||
"""Resolve fail2ban variable placeholders in a filter string.
|
||||
|
||||
Handles the common default ``%(__name__)s[mode=%(mode)s]`` pattern that
|
||||
fail2ban uses so the filter name displayed to the user is readable.
|
||||
|
||||
Args:
|
||||
raw_filter: Raw ``filter`` value from config (may contain ``%()s``).
|
||||
jail_name: The jail's section name, used to substitute ``%(__name__)s``.
|
||||
mode: The jail's ``mode`` value, used to substitute ``%(mode)s``.
|
||||
|
||||
Returns:
|
||||
Human-readable filter string.
|
||||
"""
|
||||
result = raw_filter.replace("%(__name__)s", jail_name)
|
||||
result = result.replace("%(mode)s", mode)
|
||||
return result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Internal helpers - from config_file_service for local use
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _set_jail_local_key_sync(
|
||||
config_dir: Path,
|
||||
jail_name: str,
|
||||
key: str,
|
||||
value: str,
|
||||
) -> None:
|
||||
"""Update ``jail.d/{jail_name}.local`` to set a single key in the jail section.
|
||||
|
||||
If the ``.local`` file already exists it is read, the key is updated (or
|
||||
added), and the file is written back atomically without disturbing other
|
||||
settings. If the file does not exist a new one is created containing
|
||||
only the BanGUI header comment, the jail section, and the requested key.
|
||||
|
||||
Args:
|
||||
config_dir: The fail2ban configuration root directory.
|
||||
jail_name: Validated jail name (used as section name and filename stem).
|
||||
key: Config key to set inside the jail section.
|
||||
value: Config value to assign.
|
||||
|
||||
Raises:
|
||||
ConfigWriteError: If writing fails.
|
||||
"""
|
||||
jail_d = config_dir / "jail.d"
|
||||
try:
|
||||
jail_d.mkdir(parents=True, exist_ok=True)
|
||||
except OSError as exc:
|
||||
raise ConfigWriteError(f"Cannot create jail.d directory: {exc}") from exc
|
||||
|
||||
local_path = jail_d / f"{jail_name}.local"
|
||||
|
||||
parser = _build_parser()
|
||||
if local_path.is_file():
|
||||
try:
|
||||
parser.read(str(local_path), encoding="utf-8")
|
||||
except (configparser.Error, OSError) as exc:
|
||||
log.warning(
|
||||
"jail_local_read_for_update_error",
|
||||
jail=jail_name,
|
||||
error=str(exc),
|
||||
)
|
||||
|
||||
if not parser.has_section(jail_name):
|
||||
parser.add_section(jail_name)
|
||||
parser.set(jail_name, key, value)
|
||||
|
||||
# Serialize: write a BanGUI header then the parser output.
|
||||
buf = io.StringIO()
|
||||
buf.write("# Managed by BanGUI — do not edit manually\n\n")
|
||||
parser.write(buf)
|
||||
content = buf.getvalue()
|
||||
|
||||
try:
|
||||
with tempfile.NamedTemporaryFile(
|
||||
mode="w",
|
||||
encoding="utf-8",
|
||||
dir=jail_d,
|
||||
delete=False,
|
||||
suffix=".tmp",
|
||||
) as tmp:
|
||||
tmp.write(content)
|
||||
tmp_name = tmp.name
|
||||
os.replace(tmp_name, local_path)
|
||||
except OSError as exc:
|
||||
with contextlib.suppress(OSError):
|
||||
os.unlink(tmp_name) # noqa: F821
|
||||
raise ConfigWriteError(f"Failed to write {local_path}: {exc}") from exc
|
||||
|
||||
log.info(
|
||||
"jail_local_key_set",
|
||||
jail=jail_name,
|
||||
key=key,
|
||||
path=str(local_path),
|
||||
)
|
||||
|
||||
|
||||
def _extract_filter_base_name(filter_raw: str) -> str:
|
||||
"""Extract the base filter name from a raw fail2ban filter string.
|
||||
|
||||
fail2ban jail configs may specify a filter with an optional mode suffix,
|
||||
e.g. ``sshd``, ``sshd[mode=aggressive]``, or
|
||||
``%(__name__)s[mode=%(mode)s]``. This function strips the ``[…]`` mode
|
||||
block and any leading/trailing whitespace to return just the file-system
|
||||
base name used to look up ``filter.d/{name}.conf``.
|
||||
|
||||
Args:
|
||||
filter_raw: Raw ``filter`` value from a jail config (already
|
||||
with ``%(__name__)s`` substituted by the caller).
|
||||
|
||||
Returns:
|
||||
Base filter name, e.g. ``"sshd"``.
|
||||
"""
|
||||
bracket = filter_raw.find("[")
|
||||
if bracket != -1:
|
||||
return filter_raw[:bracket].strip()
|
||||
return filter_raw.strip()
|
||||
|
||||
|
||||
def _build_filter_to_jails_map(
|
||||
all_jails: dict[str, dict[str, str]],
|
||||
active_names: set[str],
|
||||
) -> dict[str, list[str]]:
|
||||
"""Return a mapping of filter base name → list of active jail names.
|
||||
|
||||
Iterates over every jail whose name is in *active_names*, resolves its
|
||||
``filter`` config key, and records the jail against the base filter name.
|
||||
|
||||
Args:
|
||||
all_jails: Merged jail config dict — ``{jail_name: {key: value}}``.
|
||||
active_names: Set of jail names currently running in fail2ban.
|
||||
|
||||
Returns:
|
||||
``{filter_base_name: [jail_name, …]}``.
|
||||
"""
|
||||
mapping: dict[str, list[str]] = {}
|
||||
for jail_name, settings in all_jails.items():
|
||||
if jail_name not in active_names:
|
||||
continue
|
||||
raw_filter = settings.get("filter", "")
|
||||
mode = settings.get("mode", "normal")
|
||||
resolved = _resolve_filter(raw_filter, jail_name, mode) if raw_filter else jail_name
|
||||
base = _extract_filter_base_name(resolved)
|
||||
if base:
|
||||
mapping.setdefault(base, []).append(jail_name)
|
||||
return mapping
|
||||
|
||||
|
||||
def _parse_filters_sync(
|
||||
filter_d: Path,
|
||||
) -> list[tuple[str, str, str, bool, str]]:
|
||||
"""Synchronously scan ``filter.d/`` and return per-filter tuples.
|
||||
|
||||
Each tuple contains:
|
||||
|
||||
- ``name`` — filter base name (``"sshd"``).
|
||||
- ``filename`` — actual filename (``"sshd.conf"`` or ``"sshd.local"``).
|
||||
- ``content`` — merged file content (``conf`` overridden by ``local``).
|
||||
- ``has_local`` — whether a ``.local`` override exists alongside a ``.conf``.
|
||||
- ``source_path`` — absolute path to the primary (``conf``) source file, or
|
||||
to the ``.local`` file for user-created (local-only) filters.
|
||||
|
||||
Also discovers ``.local``-only files (user-created filters with no
|
||||
corresponding ``.conf``). These are returned with ``has_local = False``
|
||||
and ``source_path`` pointing to the ``.local`` file itself.
|
||||
|
||||
Args:
|
||||
filter_d: Path to the ``filter.d`` directory.
|
||||
|
||||
Returns:
|
||||
List of ``(name, filename, content, has_local, source_path)`` tuples,
|
||||
sorted by name.
|
||||
"""
|
||||
if not filter_d.is_dir():
|
||||
log.warning("filter_d_not_found", path=str(filter_d))
|
||||
return []
|
||||
|
||||
conf_names: set[str] = set()
|
||||
results: list[tuple[str, str, str, bool, str]] = []
|
||||
|
||||
# ---- .conf-based filters (with optional .local override) ----------------
|
||||
for conf_path in sorted(filter_d.glob("*.conf")):
|
||||
if not conf_path.is_file():
|
||||
continue
|
||||
name = conf_path.stem
|
||||
filename = conf_path.name
|
||||
conf_names.add(name)
|
||||
local_path = conf_path.with_suffix(".local")
|
||||
has_local = local_path.is_file()
|
||||
|
||||
try:
|
||||
content = conf_path.read_text(encoding="utf-8")
|
||||
except OSError as exc:
|
||||
log.warning("filter_read_error", name=name, path=str(conf_path), error=str(exc))
|
||||
continue
|
||||
|
||||
if has_local:
|
||||
try:
|
||||
local_content = local_path.read_text(encoding="utf-8")
|
||||
# Append local content after conf so configparser reads local
|
||||
# values last (higher priority).
|
||||
content = content + "\n" + local_content
|
||||
except OSError as exc:
|
||||
log.warning(
|
||||
"filter_local_read_error",
|
||||
name=name,
|
||||
path=str(local_path),
|
||||
error=str(exc),
|
||||
)
|
||||
|
||||
results.append((name, filename, content, has_local, str(conf_path)))
|
||||
|
||||
# ---- .local-only filters (user-created, no corresponding .conf) ----------
|
||||
for local_path in sorted(filter_d.glob("*.local")):
|
||||
if not local_path.is_file():
|
||||
continue
|
||||
name = local_path.stem
|
||||
if name in conf_names:
|
||||
# Already covered above as a .conf filter with a .local override.
|
||||
continue
|
||||
try:
|
||||
content = local_path.read_text(encoding="utf-8")
|
||||
except OSError as exc:
|
||||
log.warning(
|
||||
"filter_local_read_error",
|
||||
name=name,
|
||||
path=str(local_path),
|
||||
error=str(exc),
|
||||
)
|
||||
continue
|
||||
results.append((name, local_path.name, content, False, str(local_path)))
|
||||
|
||||
results.sort(key=lambda t: t[0])
|
||||
log.debug("filters_scanned", count=len(results), filter_d=str(filter_d))
|
||||
return results
|
||||
|
||||
|
||||
def _validate_regex_patterns(patterns: list[str]) -> None:
|
||||
"""Validate each pattern in *patterns* using Python's ``re`` module.
|
||||
|
||||
Args:
|
||||
patterns: List of regex strings to validate.
|
||||
|
||||
Raises:
|
||||
FilterInvalidRegexError: If any pattern fails to compile.
|
||||
"""
|
||||
for pattern in patterns:
|
||||
try:
|
||||
re.compile(pattern)
|
||||
except re.error as exc:
|
||||
raise FilterInvalidRegexError(pattern, str(exc)) from exc
|
||||
|
||||
|
||||
def _write_filter_local_sync(filter_d: Path, name: str, content: str) -> None:
|
||||
"""Write *content* to ``filter.d/{name}.local`` atomically.
|
||||
|
||||
The write is atomic: content is written to a temp file first, then
|
||||
renamed into place. The ``filter.d/`` directory is created if absent.
|
||||
|
||||
Args:
|
||||
filter_d: Path to the ``filter.d`` directory.
|
||||
name: Validated filter base name (used as filename stem).
|
||||
content: Full serialized filter content to write.
|
||||
|
||||
Raises:
|
||||
ConfigWriteError: If writing fails.
|
||||
"""
|
||||
try:
|
||||
filter_d.mkdir(parents=True, exist_ok=True)
|
||||
except OSError as exc:
|
||||
raise ConfigWriteError(f"Cannot create filter.d directory: {exc}") from exc
|
||||
|
||||
local_path = filter_d / f"{name}.local"
|
||||
try:
|
||||
with tempfile.NamedTemporaryFile(
|
||||
mode="w",
|
||||
encoding="utf-8",
|
||||
dir=filter_d,
|
||||
delete=False,
|
||||
suffix=".tmp",
|
||||
) as tmp:
|
||||
tmp.write(content)
|
||||
tmp_name = tmp.name
|
||||
os.replace(tmp_name, local_path)
|
||||
except OSError as exc:
|
||||
with contextlib.suppress(OSError):
|
||||
os.unlink(tmp_name) # noqa: F821
|
||||
raise ConfigWriteError(f"Failed to write {local_path}: {exc}") from exc
|
||||
|
||||
log.info("filter_local_written", filter=name, path=str(local_path))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public API — filter discovery
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def list_filters(
|
||||
config_dir: str,
|
||||
socket_path: str,
|
||||
) -> FilterListResponse:
|
||||
"""Return all available filters from ``filter.d/`` with active/inactive status.
|
||||
|
||||
Scans ``{config_dir}/filter.d/`` for ``.conf`` files, merges any
|
||||
corresponding ``.local`` overrides, parses each file into a
|
||||
:class:`~app.models.config.FilterConfig`, and cross-references with the
|
||||
currently running jails to determine which filters are active.
|
||||
|
||||
A filter is considered *active* when its base name matches the ``filter``
|
||||
field of at least one currently running jail.
|
||||
|
||||
Args:
|
||||
config_dir: Absolute path to the fail2ban configuration directory.
|
||||
socket_path: Path to the fail2ban Unix domain socket.
|
||||
|
||||
Returns:
|
||||
:class:`~app.models.config.FilterListResponse` with all filters
|
||||
sorted alphabetically, active ones carrying non-empty
|
||||
``used_by_jails`` lists.
|
||||
"""
|
||||
filter_d = Path(config_dir) / "filter.d"
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
# Run the synchronous scan in a thread-pool executor.
|
||||
raw_filters: list[tuple[str, str, str, bool, str]] = await loop.run_in_executor(None, _parse_filters_sync, filter_d)
|
||||
|
||||
# Fetch active jail names and their configs concurrently.
|
||||
all_jails_result, active_names = await asyncio.gather(
|
||||
loop.run_in_executor(None, _parse_jails_sync, Path(config_dir)),
|
||||
_get_active_jail_names(socket_path),
|
||||
)
|
||||
all_jails, _source_files = all_jails_result
|
||||
|
||||
filter_to_jails = _build_filter_to_jails_map(all_jails, active_names)
|
||||
|
||||
filters: list[FilterConfig] = []
|
||||
for name, filename, content, has_local, source_path in raw_filters:
|
||||
cfg = conffile_parser.parse_filter_file(content, name=name, filename=filename)
|
||||
used_by = sorted(filter_to_jails.get(name, []))
|
||||
filters.append(
|
||||
FilterConfig(
|
||||
name=cfg.name,
|
||||
filename=cfg.filename,
|
||||
before=cfg.before,
|
||||
after=cfg.after,
|
||||
variables=cfg.variables,
|
||||
prefregex=cfg.prefregex,
|
||||
failregex=cfg.failregex,
|
||||
ignoreregex=cfg.ignoreregex,
|
||||
maxlines=cfg.maxlines,
|
||||
datepattern=cfg.datepattern,
|
||||
journalmatch=cfg.journalmatch,
|
||||
active=len(used_by) > 0,
|
||||
used_by_jails=used_by,
|
||||
source_file=source_path,
|
||||
has_local_override=has_local,
|
||||
)
|
||||
)
|
||||
|
||||
log.info("filters_listed", total=len(filters), active=sum(1 for f in filters if f.active))
|
||||
return FilterListResponse(filters=filters, total=len(filters))
|
||||
|
||||
|
||||
async def get_filter(
|
||||
config_dir: str,
|
||||
socket_path: str,
|
||||
name: str,
|
||||
) -> FilterConfig:
|
||||
"""Return a single filter from ``filter.d/`` with active/inactive status.
|
||||
|
||||
Reads ``{config_dir}/filter.d/{name}.conf``, merges any ``.local``
|
||||
override, and enriches the parsed :class:`~app.models.config.FilterConfig`
|
||||
with ``active``, ``used_by_jails``, ``source_file``, and
|
||||
``has_local_override``.
|
||||
|
||||
Args:
|
||||
config_dir: Absolute path to the fail2ban configuration directory.
|
||||
socket_path: Path to the fail2ban Unix domain socket.
|
||||
name: Filter base name (e.g. ``"sshd"`` or ``"sshd.conf"``).
|
||||
|
||||
Returns:
|
||||
:class:`~app.models.config.FilterConfig` with status fields populated.
|
||||
|
||||
Raises:
|
||||
FilterNotFoundError: If no ``{name}.conf`` or ``{name}.local`` file
|
||||
exists in ``filter.d/``.
|
||||
"""
|
||||
# Normalise — strip extension if provided (.conf=5 chars, .local=6 chars).
|
||||
if name.endswith(".conf"):
|
||||
base_name = name[:-5]
|
||||
elif name.endswith(".local"):
|
||||
base_name = name[:-6]
|
||||
else:
|
||||
base_name = name
|
||||
|
||||
filter_d = Path(config_dir) / "filter.d"
|
||||
conf_path = filter_d / f"{base_name}.conf"
|
||||
local_path = filter_d / f"{base_name}.local"
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
def _read() -> tuple[str, bool, str]:
|
||||
"""Read filter content and return (content, has_local_override, source_path)."""
|
||||
has_local = local_path.is_file()
|
||||
if conf_path.is_file():
|
||||
content = conf_path.read_text(encoding="utf-8")
|
||||
if has_local:
|
||||
try:
|
||||
content += "\n" + local_path.read_text(encoding="utf-8")
|
||||
except OSError as exc:
|
||||
log.warning(
|
||||
"filter_local_read_error",
|
||||
name=base_name,
|
||||
path=str(local_path),
|
||||
error=str(exc),
|
||||
)
|
||||
return content, has_local, str(conf_path)
|
||||
elif has_local:
|
||||
# Local-only filter: created by the user, no shipped .conf base.
|
||||
content = local_path.read_text(encoding="utf-8")
|
||||
return content, False, str(local_path)
|
||||
else:
|
||||
raise FilterNotFoundError(base_name)
|
||||
|
||||
content, has_local, source_path = await loop.run_in_executor(None, _read)
|
||||
|
||||
cfg = conffile_parser.parse_filter_file(content, name=base_name, filename=f"{base_name}.conf")
|
||||
|
||||
all_jails_result, active_names = await asyncio.gather(
|
||||
loop.run_in_executor(None, _parse_jails_sync, Path(config_dir)),
|
||||
_get_active_jail_names(socket_path),
|
||||
)
|
||||
all_jails, _source_files = all_jails_result
|
||||
filter_to_jails = _build_filter_to_jails_map(all_jails, active_names)
|
||||
|
||||
used_by = sorted(filter_to_jails.get(base_name, []))
|
||||
log.info("filter_fetched", name=base_name, active=len(used_by) > 0)
|
||||
return FilterConfig(
|
||||
name=cfg.name,
|
||||
filename=cfg.filename,
|
||||
before=cfg.before,
|
||||
after=cfg.after,
|
||||
variables=cfg.variables,
|
||||
prefregex=cfg.prefregex,
|
||||
failregex=cfg.failregex,
|
||||
ignoreregex=cfg.ignoreregex,
|
||||
maxlines=cfg.maxlines,
|
||||
datepattern=cfg.datepattern,
|
||||
journalmatch=cfg.journalmatch,
|
||||
active=len(used_by) > 0,
|
||||
used_by_jails=used_by,
|
||||
source_file=source_path,
|
||||
has_local_override=has_local,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public API — filter write operations
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def update_filter(
|
||||
config_dir: str,
|
||||
socket_path: str,
|
||||
name: str,
|
||||
req: FilterUpdateRequest,
|
||||
do_reload: bool = False,
|
||||
) -> FilterConfig:
|
||||
"""Update a filter's ``.local`` override with new regex/pattern values.
|
||||
|
||||
Reads the current merged configuration for *name* (``conf`` + any existing
|
||||
``local``), applies the non-``None`` fields in *req* on top of it, and
|
||||
writes the resulting definition to ``filter.d/{name}.local``. The
|
||||
original ``.conf`` file is never modified.
|
||||
|
||||
All regex patterns in *req* are validated with Python's ``re`` module
|
||||
before any write occurs.
|
||||
|
||||
Args:
|
||||
config_dir: Absolute path to the fail2ban configuration directory.
|
||||
socket_path: Path to the fail2ban Unix domain socket.
|
||||
name: Filter base name (e.g. ``"sshd"`` or ``"sshd.conf"``).
|
||||
req: Partial update — only non-``None`` fields are applied.
|
||||
do_reload: When ``True``, trigger a full fail2ban reload after writing.
|
||||
|
||||
Returns:
|
||||
:class:`~app.models.config.FilterConfig` reflecting the updated state.
|
||||
|
||||
Raises:
|
||||
FilterNameError: If *name* contains invalid characters.
|
||||
FilterNotFoundError: If no ``{name}.conf`` or ``{name}.local`` exists.
|
||||
FilterInvalidRegexError: If any supplied regex pattern is invalid.
|
||||
ConfigWriteError: If writing the ``.local`` file fails.
|
||||
"""
|
||||
base_name = name[:-5] if name.endswith(".conf") or name.endswith(".local") else name
|
||||
_safe_filter_name(base_name)
|
||||
|
||||
# Validate regex patterns before touching the filesystem.
|
||||
patterns: list[str] = []
|
||||
if req.failregex is not None:
|
||||
patterns.extend(req.failregex)
|
||||
if req.ignoreregex is not None:
|
||||
patterns.extend(req.ignoreregex)
|
||||
_validate_regex_patterns(patterns)
|
||||
|
||||
# Fetch the current merged config (raises FilterNotFoundError if absent).
|
||||
current = await get_filter(config_dir, socket_path, base_name)
|
||||
|
||||
# Build a FilterConfigUpdate from the request fields.
|
||||
update = FilterConfigUpdate(
|
||||
failregex=req.failregex,
|
||||
ignoreregex=req.ignoreregex,
|
||||
datepattern=req.datepattern,
|
||||
journalmatch=req.journalmatch,
|
||||
)
|
||||
|
||||
merged = conffile_parser.merge_filter_update(current, update)
|
||||
content = conffile_parser.serialize_filter_config(merged)
|
||||
|
||||
filter_d = Path(config_dir) / "filter.d"
|
||||
loop = asyncio.get_event_loop()
|
||||
await loop.run_in_executor(None, _write_filter_local_sync, filter_d, base_name, content)
|
||||
|
||||
if do_reload:
|
||||
try:
|
||||
await reload_jails(socket_path)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
log.warning(
|
||||
"reload_after_filter_update_failed",
|
||||
filter=base_name,
|
||||
error=str(exc),
|
||||
)
|
||||
|
||||
log.info("filter_updated", filter=base_name, reload=do_reload)
|
||||
return await get_filter(config_dir, socket_path, base_name)
|
||||
|
||||
|
||||
async def create_filter(
|
||||
config_dir: str,
|
||||
socket_path: str,
|
||||
req: FilterCreateRequest,
|
||||
do_reload: bool = False,
|
||||
) -> FilterConfig:
|
||||
"""Create a brand-new user-defined filter in ``filter.d/{name}.local``.
|
||||
|
||||
No ``.conf`` is written; fail2ban loads ``.local`` files directly. If a
|
||||
``.conf`` or ``.local`` file already exists for the requested name, a
|
||||
:class:`FilterAlreadyExistsError` is raised.
|
||||
|
||||
All regex patterns are validated with Python's ``re`` module before
|
||||
writing.
|
||||
|
||||
Args:
|
||||
config_dir: Absolute path to the fail2ban configuration directory.
|
||||
socket_path: Path to the fail2ban Unix domain socket.
|
||||
req: Filter name and definition fields.
|
||||
do_reload: When ``True``, trigger a full fail2ban reload after writing.
|
||||
|
||||
Returns:
|
||||
:class:`~app.models.config.FilterConfig` for the newly created filter.
|
||||
|
||||
Raises:
|
||||
FilterNameError: If ``req.name`` contains invalid characters.
|
||||
FilterAlreadyExistsError: If a ``.conf`` or ``.local`` already exists.
|
||||
FilterInvalidRegexError: If any regex pattern is invalid.
|
||||
ConfigWriteError: If writing fails.
|
||||
"""
|
||||
_safe_filter_name(req.name)
|
||||
|
||||
filter_d = Path(config_dir) / "filter.d"
|
||||
conf_path = filter_d / f"{req.name}.conf"
|
||||
local_path = filter_d / f"{req.name}.local"
|
||||
|
||||
def _check_not_exists() -> None:
|
||||
if conf_path.is_file() or local_path.is_file():
|
||||
raise FilterAlreadyExistsError(req.name)
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
await loop.run_in_executor(None, _check_not_exists)
|
||||
|
||||
# Validate regex patterns.
|
||||
patterns: list[str] = list(req.failregex) + list(req.ignoreregex)
|
||||
_validate_regex_patterns(patterns)
|
||||
|
||||
# Build a FilterConfig and serialise it.
|
||||
cfg = FilterConfig(
|
||||
name=req.name,
|
||||
filename=f"{req.name}.local",
|
||||
failregex=req.failregex,
|
||||
ignoreregex=req.ignoreregex,
|
||||
prefregex=req.prefregex,
|
||||
datepattern=req.datepattern,
|
||||
journalmatch=req.journalmatch,
|
||||
)
|
||||
content = conffile_parser.serialize_filter_config(cfg)
|
||||
|
||||
await loop.run_in_executor(None, _write_filter_local_sync, filter_d, req.name, content)
|
||||
|
||||
if do_reload:
|
||||
try:
|
||||
await reload_jails(socket_path)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
log.warning(
|
||||
"reload_after_filter_create_failed",
|
||||
filter=req.name,
|
||||
error=str(exc),
|
||||
)
|
||||
|
||||
log.info("filter_created", filter=req.name, reload=do_reload)
|
||||
# Re-fetch to get the canonical FilterConfig (source_file, active, etc.).
|
||||
return await get_filter(config_dir, socket_path, req.name)
|
||||
|
||||
|
||||
async def delete_filter(
|
||||
config_dir: str,
|
||||
name: str,
|
||||
) -> None:
|
||||
"""Delete a user-created filter's ``.local`` file.
|
||||
|
||||
Deletion rules:
|
||||
- If only a ``.conf`` file exists (shipped default, no user override) →
|
||||
:class:`FilterReadonlyError`.
|
||||
- If a ``.local`` file exists (whether or not a ``.conf`` also exists) →
|
||||
the ``.local`` file is deleted. The shipped ``.conf`` is never touched.
|
||||
- If neither file exists → :class:`FilterNotFoundError`.
|
||||
|
||||
Args:
|
||||
config_dir: Absolute path to the fail2ban configuration directory.
|
||||
name: Filter base name (e.g. ``"sshd"``).
|
||||
|
||||
Raises:
|
||||
FilterNameError: If *name* contains invalid characters.
|
||||
FilterNotFoundError: If no filter file is found for *name*.
|
||||
FilterReadonlyError: If only a shipped ``.conf`` exists (no ``.local``).
|
||||
ConfigWriteError: If deletion of the ``.local`` file fails.
|
||||
"""
|
||||
base_name = name[:-5] if name.endswith(".conf") or name.endswith(".local") else name
|
||||
_safe_filter_name(base_name)
|
||||
|
||||
filter_d = Path(config_dir) / "filter.d"
|
||||
conf_path = filter_d / f"{base_name}.conf"
|
||||
local_path = filter_d / f"{base_name}.local"
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
def _delete() -> None:
|
||||
has_conf = conf_path.is_file()
|
||||
has_local = local_path.is_file()
|
||||
|
||||
if not has_conf and not has_local:
|
||||
raise FilterNotFoundError(base_name)
|
||||
|
||||
if has_conf and not has_local:
|
||||
# Shipped default — nothing user-writable to remove.
|
||||
raise FilterReadonlyError(base_name)
|
||||
|
||||
try:
|
||||
local_path.unlink()
|
||||
except OSError as exc:
|
||||
raise ConfigWriteError(f"Failed to delete {local_path}: {exc}") from exc
|
||||
|
||||
log.info("filter_local_deleted", filter=base_name, path=str(local_path))
|
||||
|
||||
await loop.run_in_executor(None, _delete)
|
||||
|
||||
|
||||
async def assign_filter_to_jail(
|
||||
config_dir: str,
|
||||
socket_path: str,
|
||||
jail_name: str,
|
||||
req: AssignFilterRequest,
|
||||
do_reload: bool = False,
|
||||
) -> None:
|
||||
"""Assign a filter to a jail by updating the jail's ``.local`` file.
|
||||
|
||||
Writes ``filter = {req.filter_name}`` into the ``[{jail_name}]`` section
|
||||
of ``jail.d/{jail_name}.local``. If the ``.local`` file already contains
|
||||
other settings for this jail they are preserved.
|
||||
|
||||
Args:
|
||||
config_dir: Absolute path to the fail2ban configuration directory.
|
||||
socket_path: Path to the fail2ban Unix domain socket.
|
||||
jail_name: Name of the jail to update.
|
||||
req: Request containing the filter name to assign.
|
||||
do_reload: When ``True``, trigger a full fail2ban reload after writing.
|
||||
|
||||
Raises:
|
||||
JailNameError: If *jail_name* contains invalid characters.
|
||||
FilterNameError: If ``req.filter_name`` contains invalid characters.
|
||||
JailNotFoundError: If *jail_name* is not defined in any config file.
|
||||
FilterNotFoundError: If ``req.filter_name`` does not exist in
|
||||
``filter.d/``.
|
||||
ConfigWriteError: If writing fails.
|
||||
"""
|
||||
_safe_jail_name(jail_name)
|
||||
_safe_filter_name(req.filter_name)
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
# Verify the jail exists in config.
|
||||
all_jails, _src = await loop.run_in_executor(None, _parse_jails_sync, Path(config_dir))
|
||||
if jail_name not in all_jails:
|
||||
raise JailNotFoundInConfigError(jail_name)
|
||||
|
||||
# Verify the filter exists (conf or local).
|
||||
filter_d = Path(config_dir) / "filter.d"
|
||||
|
||||
def _check_filter() -> None:
|
||||
conf_exists = (filter_d / f"{req.filter_name}.conf").is_file()
|
||||
local_exists = (filter_d / f"{req.filter_name}.local").is_file()
|
||||
if not conf_exists and not local_exists:
|
||||
raise FilterNotFoundError(req.filter_name)
|
||||
|
||||
await loop.run_in_executor(None, _check_filter)
|
||||
|
||||
await loop.run_in_executor(
|
||||
None,
|
||||
_set_jail_local_key_sync,
|
||||
Path(config_dir),
|
||||
jail_name,
|
||||
"filter",
|
||||
req.filter_name,
|
||||
)
|
||||
|
||||
if do_reload:
|
||||
try:
|
||||
await reload_jails(socket_path)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
log.warning(
|
||||
"reload_after_assign_filter_failed",
|
||||
jail=jail_name,
|
||||
filter=req.filter_name,
|
||||
error=str(exc),
|
||||
)
|
||||
|
||||
log.info(
|
||||
"filter_assigned_to_jail",
|
||||
jail=jail_name,
|
||||
filter=req.filter_name,
|
||||
reload=do_reload,
|
||||
)
|
||||
@@ -20,9 +20,7 @@ Usage::
|
||||
|
||||
import aiohttp
|
||||
import aiosqlite
|
||||
from app.services import geo_service
|
||||
|
||||
# warm the cache from the persistent store at startup
|
||||
# Use the geo_service directly in application startup
|
||||
async with aiosqlite.connect("bangui.db") as db:
|
||||
await geo_service.load_cache_from_db(db)
|
||||
|
||||
@@ -30,7 +28,8 @@ Usage::
|
||||
# single lookup
|
||||
info = await geo_service.lookup("1.2.3.4", session)
|
||||
if info:
|
||||
print(info.country_code) # "DE"
|
||||
# info.country_code == "DE"
|
||||
... # use the GeoInfo object in your application
|
||||
|
||||
# bulk lookup (more efficient for large sets)
|
||||
geo_map = await geo_service.lookup_batch(["1.2.3.4", "5.6.7.8"], session)
|
||||
@@ -40,12 +39,14 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import aiohttp
|
||||
import structlog
|
||||
|
||||
from app.models.geo import GeoInfo
|
||||
from app.repositories import geo_cache_repo
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import aiosqlite
|
||||
import geoip2.database
|
||||
@@ -90,32 +91,6 @@ _BATCH_DELAY: float = 1.5
|
||||
#: transient error (e.g. connection reset due to rate limiting).
|
||||
_BATCH_MAX_RETRIES: int = 2
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Domain model
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
|
||||
class GeoInfo:
|
||||
"""Geographical and network metadata for a single IP address.
|
||||
|
||||
All fields default to ``None`` when the information is unavailable or
|
||||
the lookup fails gracefully.
|
||||
"""
|
||||
|
||||
country_code: str | None
|
||||
"""ISO 3166-1 alpha-2 country code, e.g. ``"DE"``."""
|
||||
|
||||
country_name: str | None
|
||||
"""Human-readable country name, e.g. ``"Germany"``."""
|
||||
|
||||
asn: str | None
|
||||
"""Autonomous System Number string, e.g. ``"AS3320"``."""
|
||||
|
||||
org: str | None
|
||||
"""Organisation name associated with the IP, e.g. ``"Deutsche Telekom"``."""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Internal cache
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -184,11 +159,7 @@ async def cache_stats(db: aiosqlite.Connection) -> dict[str, int]:
|
||||
Dict with keys ``cache_size``, ``unresolved``, ``neg_cache_size``,
|
||||
and ``dirty_size``.
|
||||
"""
|
||||
async with db.execute(
|
||||
"SELECT COUNT(*) FROM geo_cache WHERE country_code IS NULL"
|
||||
) as cur:
|
||||
row = await cur.fetchone()
|
||||
unresolved: int = int(row[0]) if row else 0
|
||||
unresolved = await geo_cache_repo.count_unresolved(db)
|
||||
|
||||
return {
|
||||
"cache_size": len(_cache),
|
||||
@@ -198,6 +169,24 @@ async def cache_stats(db: aiosqlite.Connection) -> dict[str, int]:
|
||||
}
|
||||
|
||||
|
||||
async def count_unresolved(db: aiosqlite.Connection) -> int:
|
||||
"""Return the number of unresolved entries in the persistent geo cache."""
|
||||
|
||||
return await geo_cache_repo.count_unresolved(db)
|
||||
|
||||
|
||||
async def get_unresolved_ips(db: aiosqlite.Connection) -> list[str]:
|
||||
"""Return geo cache IPs where the country code has not yet been resolved.
|
||||
|
||||
Args:
|
||||
db: Open BanGUI application database connection.
|
||||
|
||||
Returns:
|
||||
List of IP addresses that are candidates for re-resolution.
|
||||
"""
|
||||
return await geo_cache_repo.get_unresolved_ips(db)
|
||||
|
||||
|
||||
def init_geoip(mmdb_path: str | None) -> None:
|
||||
"""Initialise the MaxMind GeoLite2-Country database reader.
|
||||
|
||||
@@ -268,21 +257,18 @@ async def load_cache_from_db(db: aiosqlite.Connection) -> None:
|
||||
database (not the fail2ban database).
|
||||
"""
|
||||
count = 0
|
||||
async with db.execute(
|
||||
"SELECT ip, country_code, country_name, asn, org FROM geo_cache"
|
||||
) as cur:
|
||||
async for row in cur:
|
||||
ip: str = str(row[0])
|
||||
country_code: str | None = row[1]
|
||||
if country_code is None:
|
||||
continue
|
||||
_cache[ip] = GeoInfo(
|
||||
country_code=country_code,
|
||||
country_name=row[2],
|
||||
asn=row[3],
|
||||
org=row[4],
|
||||
)
|
||||
count += 1
|
||||
for row in await geo_cache_repo.load_all(db):
|
||||
country_code: str | None = row["country_code"]
|
||||
if country_code is None:
|
||||
continue
|
||||
ip: str = row["ip"]
|
||||
_cache[ip] = GeoInfo(
|
||||
country_code=country_code,
|
||||
country_name=row["country_name"],
|
||||
asn=row["asn"],
|
||||
org=row["org"],
|
||||
)
|
||||
count += 1
|
||||
log.info("geo_cache_loaded_from_db", entries=count)
|
||||
|
||||
|
||||
@@ -301,18 +287,13 @@ async def _persist_entry(
|
||||
ip: IP address string.
|
||||
info: Resolved geo data to persist.
|
||||
"""
|
||||
await db.execute(
|
||||
"""
|
||||
INSERT INTO geo_cache (ip, country_code, country_name, asn, org)
|
||||
VALUES (?, ?, ?, ?, ?)
|
||||
ON CONFLICT(ip) DO UPDATE SET
|
||||
country_code = excluded.country_code,
|
||||
country_name = excluded.country_name,
|
||||
asn = excluded.asn,
|
||||
org = excluded.org,
|
||||
cached_at = strftime('%Y-%m-%dT%H:%M:%fZ', 'now')
|
||||
""",
|
||||
(ip, info.country_code, info.country_name, info.asn, info.org),
|
||||
await geo_cache_repo.upsert_entry(
|
||||
db=db,
|
||||
ip=ip,
|
||||
country_code=info.country_code,
|
||||
country_name=info.country_name,
|
||||
asn=info.asn,
|
||||
org=info.org,
|
||||
)
|
||||
|
||||
|
||||
@@ -326,10 +307,7 @@ async def _persist_neg_entry(db: aiosqlite.Connection, ip: str) -> None:
|
||||
db: BanGUI application database connection.
|
||||
ip: IP address string whose resolution failed.
|
||||
"""
|
||||
await db.execute(
|
||||
"INSERT OR IGNORE INTO geo_cache (ip) VALUES (?)",
|
||||
(ip,),
|
||||
)
|
||||
await geo_cache_repo.upsert_neg_entry(db=db, ip=ip)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -585,19 +563,7 @@ async def lookup_batch(
|
||||
if db is not None:
|
||||
if pos_rows:
|
||||
try:
|
||||
await db.executemany(
|
||||
"""
|
||||
INSERT INTO geo_cache (ip, country_code, country_name, asn, org)
|
||||
VALUES (?, ?, ?, ?, ?)
|
||||
ON CONFLICT(ip) DO UPDATE SET
|
||||
country_code = excluded.country_code,
|
||||
country_name = excluded.country_name,
|
||||
asn = excluded.asn,
|
||||
org = excluded.org,
|
||||
cached_at = strftime('%Y-%m-%dT%H:%M:%fZ', 'now')
|
||||
""",
|
||||
pos_rows,
|
||||
)
|
||||
await geo_cache_repo.bulk_upsert_entries(db, pos_rows)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
log.warning(
|
||||
"geo_batch_persist_failed",
|
||||
@@ -606,10 +572,7 @@ async def lookup_batch(
|
||||
)
|
||||
if neg_ips:
|
||||
try:
|
||||
await db.executemany(
|
||||
"INSERT OR IGNORE INTO geo_cache (ip) VALUES (?)",
|
||||
[(ip,) for ip in neg_ips],
|
||||
)
|
||||
await geo_cache_repo.bulk_upsert_neg_entries(db, neg_ips)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
log.warning(
|
||||
"geo_batch_persist_neg_failed",
|
||||
@@ -792,19 +755,7 @@ async def flush_dirty(db: aiosqlite.Connection) -> int:
|
||||
return 0
|
||||
|
||||
try:
|
||||
await db.executemany(
|
||||
"""
|
||||
INSERT INTO geo_cache (ip, country_code, country_name, asn, org)
|
||||
VALUES (?, ?, ?, ?, ?)
|
||||
ON CONFLICT(ip) DO UPDATE SET
|
||||
country_code = excluded.country_code,
|
||||
country_name = excluded.country_name,
|
||||
asn = excluded.asn,
|
||||
org = excluded.org,
|
||||
cached_at = strftime('%Y-%m-%dT%H:%M:%fZ', 'now')
|
||||
""",
|
||||
rows,
|
||||
)
|
||||
await geo_cache_repo.bulk_upsert_entries(db, rows)
|
||||
await db.commit()
|
||||
except Exception as exc: # noqa: BLE001
|
||||
log.warning("geo_flush_dirty_failed", error=str(exc))
|
||||
|
||||
@@ -9,12 +9,17 @@ seconds by the background health-check task, not on every HTTP request.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
import structlog
|
||||
|
||||
from app.models.server import ServerStatus
|
||||
from app.utils.fail2ban_client import Fail2BanClient, Fail2BanConnectionError, Fail2BanProtocolError
|
||||
from app.utils.fail2ban_client import (
|
||||
Fail2BanClient,
|
||||
Fail2BanConnectionError,
|
||||
Fail2BanProtocolError,
|
||||
Fail2BanResponse,
|
||||
)
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
|
||||
@@ -25,7 +30,7 @@ log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
_SOCKET_TIMEOUT: float = 5.0
|
||||
|
||||
|
||||
def _ok(response: Any) -> Any:
|
||||
def _ok(response: object) -> object:
|
||||
"""Extract the payload from a fail2ban ``(return_code, data)`` response.
|
||||
|
||||
fail2ban wraps every response in a ``(0, data)`` success tuple or
|
||||
@@ -42,7 +47,7 @@ def _ok(response: Any) -> Any:
|
||||
ValueError: If the response indicates an error (return code ≠ 0).
|
||||
"""
|
||||
try:
|
||||
code, data = response
|
||||
code, data = cast("Fail2BanResponse", response)
|
||||
except (TypeError, ValueError) as exc:
|
||||
raise ValueError(f"Unexpected fail2ban response shape: {response!r}") from exc
|
||||
|
||||
@@ -52,7 +57,7 @@ def _ok(response: Any) -> Any:
|
||||
return data
|
||||
|
||||
|
||||
def _to_dict(pairs: Any) -> dict[str, Any]:
|
||||
def _to_dict(pairs: object) -> dict[str, object]:
|
||||
"""Convert a list of ``(key, value)`` pairs to a plain dict.
|
||||
|
||||
fail2ban returns structured data as lists of 2-tuples rather than dicts.
|
||||
@@ -66,7 +71,7 @@ def _to_dict(pairs: Any) -> dict[str, Any]:
|
||||
"""
|
||||
if not isinstance(pairs, (list, tuple)):
|
||||
return {}
|
||||
result: dict[str, Any] = {}
|
||||
result: dict[str, object] = {}
|
||||
for item in pairs:
|
||||
try:
|
||||
k, v = item
|
||||
@@ -119,7 +124,7 @@ async def probe(socket_path: str, timeout: float = _SOCKET_TIMEOUT) -> ServerSta
|
||||
# 3. Global status — jail count and names #
|
||||
# ------------------------------------------------------------------ #
|
||||
status_data = _to_dict(_ok(await client.send(["status"])))
|
||||
active_jails: int = int(status_data.get("Number of jail", 0) or 0)
|
||||
active_jails: int = int(str(status_data.get("Number of jail", 0) or 0))
|
||||
jail_list_raw: str = str(status_data.get("Jail list", "") or "").strip()
|
||||
jail_names: list[str] = (
|
||||
[j.strip() for j in jail_list_raw.split(",") if j.strip()]
|
||||
@@ -138,8 +143,8 @@ async def probe(socket_path: str, timeout: float = _SOCKET_TIMEOUT) -> ServerSta
|
||||
jail_resp = _to_dict(_ok(await client.send(["status", jail_name])))
|
||||
filter_stats = _to_dict(jail_resp.get("Filter") or [])
|
||||
action_stats = _to_dict(jail_resp.get("Actions") or [])
|
||||
total_failures += int(filter_stats.get("Currently failed", 0) or 0)
|
||||
total_bans += int(action_stats.get("Currently banned", 0) or 0)
|
||||
total_failures += int(str(filter_stats.get("Currently failed", 0) or 0))
|
||||
total_bans += int(str(action_stats.get("Currently banned", 0) or 0))
|
||||
except (ValueError, TypeError, KeyError) as exc:
|
||||
log.warning(
|
||||
"fail2ban_jail_status_parse_error",
|
||||
|
||||
@@ -11,19 +11,24 @@ modifies or locks the fail2ban database.
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import aiosqlite
|
||||
import structlog
|
||||
|
||||
from app.models.ban import TIME_RANGE_SECONDS, TimeRange
|
||||
if TYPE_CHECKING:
|
||||
import aiosqlite
|
||||
|
||||
from app.models.geo import GeoEnricher
|
||||
|
||||
from app.models.ban import TIME_RANGE_SECONDS, BanOrigin, TimeRange
|
||||
from app.models.history import (
|
||||
HistoryBanItem,
|
||||
HistoryListResponse,
|
||||
IpDetailResponse,
|
||||
IpTimelineEvent,
|
||||
)
|
||||
from app.services.ban_service import _get_fail2ban_db_path, _parse_data_json, _ts_to_iso
|
||||
from app.repositories import fail2ban_db_repo
|
||||
from app.utils.fail2ban_db_utils import get_fail2ban_db_path, parse_data_json, ts_to_iso
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
|
||||
@@ -59,9 +64,12 @@ async def list_history(
|
||||
range_: TimeRange | None = None,
|
||||
jail: str | None = None,
|
||||
ip_filter: str | None = None,
|
||||
origin: BanOrigin | None = None,
|
||||
source: str = "fail2ban",
|
||||
page: int = 1,
|
||||
page_size: int = _DEFAULT_PAGE_SIZE,
|
||||
geo_enricher: Any | None = None,
|
||||
geo_enricher: GeoEnricher | None = None,
|
||||
db: aiosqlite.Connection | None = None,
|
||||
) -> HistoryListResponse:
|
||||
"""Return a paginated list of historical ban records with optional filters.
|
||||
|
||||
@@ -84,28 +92,13 @@ async def list_history(
|
||||
and the total matching count.
|
||||
"""
|
||||
effective_page_size: int = min(page_size, _MAX_PAGE_SIZE)
|
||||
offset: int = (page - 1) * effective_page_size
|
||||
|
||||
# Build WHERE clauses dynamically.
|
||||
wheres: list[str] = []
|
||||
params: list[Any] = []
|
||||
|
||||
since: int | None = None
|
||||
if range_ is not None:
|
||||
since: int = _since_unix(range_)
|
||||
wheres.append("timeofban >= ?")
|
||||
params.append(since)
|
||||
since = _since_unix(range_)
|
||||
|
||||
if jail is not None:
|
||||
wheres.append("jail = ?")
|
||||
params.append(jail)
|
||||
|
||||
if ip_filter is not None:
|
||||
wheres.append("ip LIKE ?")
|
||||
params.append(f"{ip_filter}%")
|
||||
|
||||
where_sql: str = ("WHERE " + " AND ".join(wheres)) if wheres else ""
|
||||
|
||||
db_path: str = await _get_fail2ban_db_path(socket_path)
|
||||
db_path: str = await get_fail2ban_db_path(socket_path)
|
||||
log.info(
|
||||
"history_service_list",
|
||||
db_path=db_path,
|
||||
@@ -115,64 +108,111 @@ async def list_history(
|
||||
page=page,
|
||||
)
|
||||
|
||||
async with aiosqlite.connect(f"file:{db_path}?mode=ro", uri=True) as f2b_db:
|
||||
f2b_db.row_factory = aiosqlite.Row
|
||||
|
||||
async with f2b_db.execute(
|
||||
f"SELECT COUNT(*) FROM bans {where_sql}", # noqa: S608
|
||||
params,
|
||||
) as cur:
|
||||
count_row = await cur.fetchone()
|
||||
total: int = int(count_row[0]) if count_row else 0
|
||||
|
||||
async with f2b_db.execute(
|
||||
f"SELECT jail, ip, timeofban, bancount, data " # noqa: S608
|
||||
f"FROM bans {where_sql} "
|
||||
"ORDER BY timeofban DESC "
|
||||
"LIMIT ? OFFSET ?",
|
||||
[*params, effective_page_size, offset],
|
||||
) as cur:
|
||||
rows = await cur.fetchall()
|
||||
|
||||
items: list[HistoryBanItem] = []
|
||||
for row in rows:
|
||||
jail_name: str = str(row["jail"])
|
||||
ip: str = str(row["ip"])
|
||||
banned_at: str = _ts_to_iso(int(row["timeofban"]))
|
||||
ban_count: int = int(row["bancount"])
|
||||
matches, failures = _parse_data_json(row["data"])
|
||||
total: int
|
||||
|
||||
country_code: str | None = None
|
||||
country_name: str | None = None
|
||||
asn: str | None = None
|
||||
org: str | None = None
|
||||
if source == "archive":
|
||||
if db is None:
|
||||
raise ValueError("db must be provided when source is 'archive'")
|
||||
|
||||
if geo_enricher is not None:
|
||||
try:
|
||||
geo = await geo_enricher(ip)
|
||||
if geo is not None:
|
||||
country_code = geo.country_code
|
||||
country_name = geo.country_name
|
||||
asn = geo.asn
|
||||
org = geo.org
|
||||
except Exception: # noqa: BLE001
|
||||
log.warning("history_service_geo_lookup_failed", ip=ip)
|
||||
from app.repositories.history_archive_repo import get_archived_history
|
||||
|
||||
items.append(
|
||||
HistoryBanItem(
|
||||
ip=ip,
|
||||
jail=jail_name,
|
||||
banned_at=banned_at,
|
||||
ban_count=ban_count,
|
||||
failures=failures,
|
||||
matches=matches,
|
||||
country_code=country_code,
|
||||
country_name=country_name,
|
||||
asn=asn,
|
||||
org=org,
|
||||
)
|
||||
archived_rows, total = await get_archived_history(
|
||||
db=db,
|
||||
since=since,
|
||||
jail=jail,
|
||||
ip_filter=ip_filter,
|
||||
page=page,
|
||||
page_size=effective_page_size,
|
||||
)
|
||||
|
||||
for row in archived_rows:
|
||||
jail_name = row["jail"]
|
||||
ip = row["ip"]
|
||||
banned_at = ts_to_iso(int(row["timeofban"]))
|
||||
ban_count = int(row["bancount"])
|
||||
matches, failures = parse_data_json(row["data"])
|
||||
# archive records may include actions; we treat all as history
|
||||
|
||||
country_code = None
|
||||
country_name = None
|
||||
asn = None
|
||||
org = None
|
||||
|
||||
if geo_enricher is not None:
|
||||
try:
|
||||
geo = await geo_enricher(ip)
|
||||
if geo is not None:
|
||||
country_code = geo.country_code
|
||||
country_name = geo.country_name
|
||||
asn = geo.asn
|
||||
org = geo.org
|
||||
except Exception: # noqa: BLE001
|
||||
log.warning("history_service_geo_lookup_failed", ip=ip)
|
||||
|
||||
items.append(
|
||||
HistoryBanItem(
|
||||
ip=ip,
|
||||
jail=jail_name,
|
||||
banned_at=banned_at,
|
||||
ban_count=ban_count,
|
||||
failures=failures,
|
||||
matches=matches,
|
||||
country_code=country_code,
|
||||
country_name=country_name,
|
||||
asn=asn,
|
||||
org=org,
|
||||
)
|
||||
)
|
||||
else:
|
||||
rows, total = await fail2ban_db_repo.get_history_page(
|
||||
db_path=db_path,
|
||||
since=since,
|
||||
jail=jail,
|
||||
ip_filter=ip_filter,
|
||||
origin=origin,
|
||||
page=page,
|
||||
page_size=effective_page_size,
|
||||
)
|
||||
|
||||
for row in rows:
|
||||
jail_name: str = row.jail
|
||||
ip: str = row.ip
|
||||
banned_at: str = ts_to_iso(row.timeofban)
|
||||
ban_count: int = row.bancount
|
||||
matches, failures = parse_data_json(row.data)
|
||||
|
||||
country_code: str | None = None
|
||||
country_name: str | None = None
|
||||
asn: str | None = None
|
||||
org: str | None = None
|
||||
|
||||
if geo_enricher is not None:
|
||||
try:
|
||||
geo = await geo_enricher(ip)
|
||||
if geo is not None:
|
||||
country_code = geo.country_code
|
||||
country_name = geo.country_name
|
||||
asn = geo.asn
|
||||
org = geo.org
|
||||
except Exception: # noqa: BLE001
|
||||
log.warning("history_service_geo_lookup_failed", ip=ip)
|
||||
|
||||
items.append(
|
||||
HistoryBanItem(
|
||||
ip=ip,
|
||||
jail=jail_name,
|
||||
banned_at=banned_at,
|
||||
ban_count=ban_count,
|
||||
failures=failures,
|
||||
matches=matches,
|
||||
country_code=country_code,
|
||||
country_name=country_name,
|
||||
asn=asn,
|
||||
org=org,
|
||||
)
|
||||
)
|
||||
|
||||
return HistoryListResponse(
|
||||
items=items,
|
||||
total=total,
|
||||
@@ -185,7 +225,7 @@ async def get_ip_detail(
|
||||
socket_path: str,
|
||||
ip: str,
|
||||
*,
|
||||
geo_enricher: Any | None = None,
|
||||
geo_enricher: GeoEnricher | None = None,
|
||||
) -> IpDetailResponse | None:
|
||||
"""Return the full historical record for a single IP address.
|
||||
|
||||
@@ -202,19 +242,10 @@ async def get_ip_detail(
|
||||
:class:`~app.models.history.IpDetailResponse` if any records exist
|
||||
for *ip*, or ``None`` if the IP has no history in the database.
|
||||
"""
|
||||
db_path: str = await _get_fail2ban_db_path(socket_path)
|
||||
db_path: str = await get_fail2ban_db_path(socket_path)
|
||||
log.info("history_service_ip_detail", db_path=db_path, ip=ip)
|
||||
|
||||
async with aiosqlite.connect(f"file:{db_path}?mode=ro", uri=True) as f2b_db:
|
||||
f2b_db.row_factory = aiosqlite.Row
|
||||
async with f2b_db.execute(
|
||||
"SELECT jail, ip, timeofban, bancount, data "
|
||||
"FROM bans "
|
||||
"WHERE ip = ? "
|
||||
"ORDER BY timeofban DESC",
|
||||
(ip,),
|
||||
) as cur:
|
||||
rows = await cur.fetchall()
|
||||
rows = await fail2ban_db_repo.get_history_for_ip(db_path=db_path, ip=ip)
|
||||
|
||||
if not rows:
|
||||
return None
|
||||
@@ -223,10 +254,10 @@ async def get_ip_detail(
|
||||
total_failures: int = 0
|
||||
|
||||
for row in rows:
|
||||
jail_name: str = str(row["jail"])
|
||||
banned_at: str = _ts_to_iso(int(row["timeofban"]))
|
||||
ban_count: int = int(row["bancount"])
|
||||
matches, failures = _parse_data_json(row["data"])
|
||||
jail_name: str = row.jail
|
||||
banned_at: str = ts_to_iso(row.timeofban)
|
||||
ban_count: int = row.bancount
|
||||
matches, failures = parse_data_json(row.data)
|
||||
total_failures += failures
|
||||
timeline.append(
|
||||
IpTimelineEvent(
|
||||
|
||||
993
backend/app/services/jail_config_service.py
Normal file
993
backend/app/services/jail_config_service.py
Normal file
@@ -0,0 +1,993 @@
|
||||
"""Jail configuration management for BanGUI.
|
||||
|
||||
Handles parsing, validation, and lifecycle operations (activate/deactivate)
|
||||
for fail2ban jail configurations. Provides functions to discover inactive
|
||||
jails, validate their configurations before activation, and manage jail
|
||||
overrides in jail.d/*.local files.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import configparser
|
||||
import contextlib
|
||||
import io
|
||||
import os
|
||||
import re
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import cast
|
||||
|
||||
import structlog
|
||||
|
||||
from app.exceptions import JailNotFoundError
|
||||
from app.models.config import (
|
||||
ActivateJailRequest,
|
||||
InactiveJail,
|
||||
InactiveJailListResponse,
|
||||
JailActivationResponse,
|
||||
JailValidationResult,
|
||||
RollbackResponse,
|
||||
)
|
||||
from app.utils.config_file_utils import (
|
||||
_build_inactive_jail,
|
||||
_get_active_jail_names,
|
||||
_parse_jails_sync,
|
||||
_validate_jail_config_sync,
|
||||
)
|
||||
from app.utils.fail2ban_client import Fail2BanClient
|
||||
from app.utils.jail_utils import reload_jails
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Constants
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_SOCKET_TIMEOUT: float = 10.0
|
||||
|
||||
# Allowlist pattern for jail names used in path construction.
|
||||
_SAFE_JAIL_NAME_RE: re.Pattern[str] = re.compile(r"^[A-Za-z0-9][A-Za-z0-9._-]{0,127}$")
|
||||
|
||||
# Sections that are not jail definitions.
|
||||
_META_SECTIONS: frozenset[str] = frozenset({"INCLUDES", "DEFAULT"})
|
||||
|
||||
# True-ish values for the ``enabled`` key.
|
||||
_TRUE_VALUES: frozenset[str] = frozenset({"true", "yes", "1"})
|
||||
|
||||
# False-ish values for the ``enabled`` key.
|
||||
_FALSE_VALUES: frozenset[str] = frozenset({"false", "no", "0"})
|
||||
|
||||
# Seconds to wait between fail2ban liveness probes after a reload.
|
||||
_POST_RELOAD_PROBE_INTERVAL: float = 2.0
|
||||
|
||||
# Maximum number of post-reload probe attempts (initial attempt + retries).
|
||||
_POST_RELOAD_MAX_ATTEMPTS: int = 4
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Custom exceptions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class JailNotFoundInConfigError(Exception):
|
||||
"""Raised when the requested jail name is not defined in any config file."""
|
||||
|
||||
def __init__(self, name: str) -> None:
|
||||
"""Initialise with the jail name that was not found.
|
||||
|
||||
Args:
|
||||
name: The jail name that could not be located.
|
||||
"""
|
||||
self.name: str = name
|
||||
super().__init__(f"Jail not found in config files: {name!r}")
|
||||
|
||||
|
||||
class JailAlreadyActiveError(Exception):
|
||||
"""Raised when trying to activate a jail that is already active."""
|
||||
|
||||
def __init__(self, name: str) -> None:
|
||||
"""Initialise with the jail name.
|
||||
|
||||
Args:
|
||||
name: The jail that is already active.
|
||||
"""
|
||||
self.name: str = name
|
||||
super().__init__(f"Jail is already active: {name!r}")
|
||||
|
||||
|
||||
class JailAlreadyInactiveError(Exception):
|
||||
"""Raised when trying to deactivate a jail that is already inactive."""
|
||||
|
||||
def __init__(self, name: str) -> None:
|
||||
"""Initialise with the jail name.
|
||||
|
||||
Args:
|
||||
name: The jail that is already inactive.
|
||||
"""
|
||||
self.name: str = name
|
||||
super().__init__(f"Jail is already inactive: {name!r}")
|
||||
|
||||
|
||||
class JailNameError(Exception):
|
||||
"""Raised when a jail name contains invalid characters."""
|
||||
|
||||
|
||||
class ConfigWriteError(Exception):
|
||||
"""Raised when writing a ``.local`` override file fails."""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Internal helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _safe_jail_name(name: str) -> str:
|
||||
"""Validate *name* and return it unchanged or raise :class:`JailNameError`.
|
||||
|
||||
Args:
|
||||
name: Proposed jail name.
|
||||
|
||||
Returns:
|
||||
The name unchanged if valid.
|
||||
|
||||
Raises:
|
||||
JailNameError: If *name* contains unsafe characters.
|
||||
"""
|
||||
if not _SAFE_JAIL_NAME_RE.match(name):
|
||||
raise JailNameError(
|
||||
f"Jail name {name!r} contains invalid characters. "
|
||||
"Only alphanumeric characters, hyphens, underscores, and dots are "
|
||||
"allowed; must start with an alphanumeric character."
|
||||
)
|
||||
return name
|
||||
|
||||
|
||||
def _build_parser() -> configparser.RawConfigParser:
|
||||
"""Create a :class:`configparser.RawConfigParser` for fail2ban configs.
|
||||
|
||||
Returns:
|
||||
Parser with interpolation disabled and case-sensitive option names.
|
||||
"""
|
||||
parser = configparser.RawConfigParser(interpolation=None, strict=False)
|
||||
# fail2ban keys are lowercase but preserve case to be safe.
|
||||
parser.optionxform = str # type: ignore[assignment]
|
||||
return parser
|
||||
|
||||
|
||||
def _is_truthy(value: str) -> bool:
|
||||
"""Return ``True`` if *value* is a fail2ban boolean true string.
|
||||
|
||||
Args:
|
||||
value: Raw string from config (e.g. ``"true"``, ``"yes"``, ``"1"``).
|
||||
|
||||
Returns:
|
||||
``True`` when the value represents enabled.
|
||||
"""
|
||||
return value.strip().lower() in _TRUE_VALUES
|
||||
|
||||
|
||||
def _write_local_override_sync(
|
||||
config_dir: Path,
|
||||
jail_name: str,
|
||||
enabled: bool,
|
||||
overrides: dict[str, object],
|
||||
) -> None:
|
||||
"""Write a ``jail.d/{name}.local`` file atomically.
|
||||
|
||||
Always writes to ``jail.d/{jail_name}.local``. If the file already
|
||||
exists it is replaced entirely. The write is atomic: content is
|
||||
written to a temp file first, then renamed into place.
|
||||
|
||||
Args:
|
||||
config_dir: The fail2ban configuration root directory.
|
||||
jail_name: Validated jail name (used as filename stem).
|
||||
enabled: Value to write for ``enabled =``.
|
||||
overrides: Optional setting overrides (bantime, findtime, maxretry,
|
||||
port, logpath).
|
||||
|
||||
Raises:
|
||||
ConfigWriteError: If writing fails.
|
||||
"""
|
||||
jail_d = config_dir / "jail.d"
|
||||
try:
|
||||
jail_d.mkdir(parents=True, exist_ok=True)
|
||||
except OSError as exc:
|
||||
raise ConfigWriteError(f"Cannot create jail.d directory: {exc}") from exc
|
||||
|
||||
local_path = jail_d / f"{jail_name}.local"
|
||||
|
||||
lines: list[str] = [
|
||||
"# Managed by BanGUI — do not edit manually",
|
||||
"",
|
||||
f"[{jail_name}]",
|
||||
"",
|
||||
f"enabled = {'true' if enabled else 'false'}",
|
||||
# Provide explicit banaction defaults so fail2ban can resolve the
|
||||
# %(banaction)s interpolation used in the built-in action_ chain.
|
||||
"banaction = iptables-multiport",
|
||||
"banaction_allports = iptables-allports",
|
||||
]
|
||||
|
||||
if overrides.get("bantime") is not None:
|
||||
lines.append(f"bantime = {overrides['bantime']}")
|
||||
if overrides.get("findtime") is not None:
|
||||
lines.append(f"findtime = {overrides['findtime']}")
|
||||
if overrides.get("maxretry") is not None:
|
||||
lines.append(f"maxretry = {overrides['maxretry']}")
|
||||
if overrides.get("port") is not None:
|
||||
lines.append(f"port = {overrides['port']}")
|
||||
if overrides.get("logpath"):
|
||||
paths: list[str] = cast("list[str]", overrides["logpath"])
|
||||
if paths:
|
||||
lines.append(f"logpath = {paths[0]}")
|
||||
for p in paths[1:]:
|
||||
lines.append(f" {p}")
|
||||
|
||||
content = "\n".join(lines) + "\n"
|
||||
|
||||
try:
|
||||
with tempfile.NamedTemporaryFile(
|
||||
mode="w",
|
||||
encoding="utf-8",
|
||||
dir=jail_d,
|
||||
delete=False,
|
||||
suffix=".tmp",
|
||||
) as tmp:
|
||||
tmp.write(content)
|
||||
tmp_name = tmp.name
|
||||
os.replace(tmp_name, local_path)
|
||||
except OSError as exc:
|
||||
# Clean up temp file if rename failed.
|
||||
with contextlib.suppress(OSError):
|
||||
os.unlink(tmp_name) # noqa: F821 — only reachable when tmp_name is set
|
||||
raise ConfigWriteError(f"Failed to write {local_path}: {exc}") from exc
|
||||
|
||||
log.info(
|
||||
"jail_local_written",
|
||||
jail=jail_name,
|
||||
path=str(local_path),
|
||||
enabled=enabled,
|
||||
)
|
||||
|
||||
|
||||
def _restore_local_file_sync(local_path: Path, original_content: bytes | None) -> None:
|
||||
"""Restore a ``.local`` file to its pre-activation state.
|
||||
|
||||
If *original_content* is ``None``, the file is deleted (it did not exist
|
||||
before the activation). Otherwise the original bytes are written back
|
||||
atomically via a temp-file rename.
|
||||
|
||||
Args:
|
||||
local_path: Absolute path to the ``.local`` file to restore.
|
||||
original_content: Original raw bytes to write back, or ``None`` to
|
||||
delete the file.
|
||||
|
||||
Raises:
|
||||
ConfigWriteError: If the write or delete operation fails.
|
||||
"""
|
||||
if original_content is None:
|
||||
try:
|
||||
local_path.unlink(missing_ok=True)
|
||||
except OSError as exc:
|
||||
raise ConfigWriteError(f"Failed to delete {local_path} during rollback: {exc}") from exc
|
||||
return
|
||||
|
||||
tmp_name: str | None = None
|
||||
try:
|
||||
with tempfile.NamedTemporaryFile(
|
||||
mode="wb",
|
||||
dir=local_path.parent,
|
||||
delete=False,
|
||||
suffix=".tmp",
|
||||
) as tmp:
|
||||
tmp.write(original_content)
|
||||
tmp_name = tmp.name
|
||||
os.replace(tmp_name, local_path)
|
||||
except OSError as exc:
|
||||
with contextlib.suppress(OSError):
|
||||
if tmp_name is not None:
|
||||
os.unlink(tmp_name)
|
||||
raise ConfigWriteError(f"Failed to restore {local_path} during rollback: {exc}") from exc
|
||||
|
||||
|
||||
def _validate_regex_patterns(patterns: list[str]) -> None:
|
||||
"""Validate each pattern in *patterns* using Python's ``re`` module.
|
||||
|
||||
Args:
|
||||
patterns: List of regex strings to validate.
|
||||
|
||||
Raises:
|
||||
FilterInvalidRegexError: If any pattern fails to compile.
|
||||
"""
|
||||
for pattern in patterns:
|
||||
try:
|
||||
re.compile(pattern)
|
||||
except re.error as exc:
|
||||
# Import here to avoid circular dependency
|
||||
from app.exceptions import FilterInvalidRegexError
|
||||
raise FilterInvalidRegexError(pattern, str(exc)) from exc
|
||||
|
||||
|
||||
def _set_jail_local_key_sync(
|
||||
config_dir: Path,
|
||||
jail_name: str,
|
||||
key: str,
|
||||
value: str,
|
||||
) -> None:
|
||||
"""Update ``jail.d/{jail_name}.local`` to set a single key in the jail section.
|
||||
|
||||
If the ``.local`` file already exists it is read, the key is updated (or
|
||||
added), and the file is written back atomically without disturbing other
|
||||
settings. If the file does not exist a new one is created containing
|
||||
only the BanGUI header comment, the jail section, and the requested key.
|
||||
|
||||
Args:
|
||||
config_dir: The fail2ban configuration root directory.
|
||||
jail_name: Validated jail name (used as section name and filename stem).
|
||||
key: Config key to set inside the jail section.
|
||||
value: Config value to assign.
|
||||
|
||||
Raises:
|
||||
ConfigWriteError: If writing fails.
|
||||
"""
|
||||
jail_d = config_dir / "jail.d"
|
||||
try:
|
||||
jail_d.mkdir(parents=True, exist_ok=True)
|
||||
except OSError as exc:
|
||||
raise ConfigWriteError(f"Cannot create jail.d directory: {exc}") from exc
|
||||
|
||||
local_path = jail_d / f"{jail_name}.local"
|
||||
|
||||
parser = _build_parser()
|
||||
if local_path.is_file():
|
||||
try:
|
||||
parser.read(str(local_path), encoding="utf-8")
|
||||
except (configparser.Error, OSError) as exc:
|
||||
log.warning(
|
||||
"jail_local_read_for_update_error",
|
||||
jail=jail_name,
|
||||
error=str(exc),
|
||||
)
|
||||
|
||||
if not parser.has_section(jail_name):
|
||||
parser.add_section(jail_name)
|
||||
parser.set(jail_name, key, value)
|
||||
|
||||
# Serialize: write a BanGUI header then the parser output.
|
||||
buf = io.StringIO()
|
||||
buf.write("# Managed by BanGUI — do not edit manually\n\n")
|
||||
parser.write(buf)
|
||||
content = buf.getvalue()
|
||||
|
||||
try:
|
||||
with tempfile.NamedTemporaryFile(
|
||||
mode="w",
|
||||
encoding="utf-8",
|
||||
dir=jail_d,
|
||||
delete=False,
|
||||
suffix=".tmp",
|
||||
) as tmp:
|
||||
tmp.write(content)
|
||||
tmp_name = tmp.name
|
||||
os.replace(tmp_name, local_path)
|
||||
except OSError as exc:
|
||||
with contextlib.suppress(OSError):
|
||||
os.unlink(tmp_name) # noqa: F821
|
||||
raise ConfigWriteError(f"Failed to write {local_path}: {exc}") from exc
|
||||
|
||||
log.info(
|
||||
"jail_local_key_set",
|
||||
jail=jail_name,
|
||||
key=key,
|
||||
path=str(local_path),
|
||||
)
|
||||
|
||||
|
||||
async def _probe_fail2ban_running(socket_path: str) -> bool:
|
||||
"""Return ``True`` if the fail2ban socket responds to a ping.
|
||||
|
||||
Args:
|
||||
socket_path: Path to the fail2ban Unix domain socket.
|
||||
|
||||
Returns:
|
||||
``True`` when fail2ban is reachable, ``False`` otherwise.
|
||||
"""
|
||||
try:
|
||||
client = Fail2BanClient(socket_path=socket_path, timeout=5.0)
|
||||
resp = await client.send(["ping"])
|
||||
return isinstance(resp, (list, tuple)) and resp[0] == 0
|
||||
except Exception: # noqa: BLE001
|
||||
return False
|
||||
|
||||
|
||||
async def wait_for_fail2ban(
|
||||
socket_path: str,
|
||||
max_wait_seconds: float = 10.0,
|
||||
poll_interval: float = 2.0,
|
||||
) -> bool:
|
||||
"""Poll the fail2ban socket until it responds or the timeout expires.
|
||||
|
||||
Args:
|
||||
socket_path: Path to the fail2ban Unix domain socket.
|
||||
max_wait_seconds: Total time budget in seconds.
|
||||
poll_interval: Delay between probe attempts in seconds.
|
||||
|
||||
Returns:
|
||||
``True`` if fail2ban came online within the budget.
|
||||
"""
|
||||
elapsed = 0.0
|
||||
while elapsed < max_wait_seconds:
|
||||
if await _probe_fail2ban_running(socket_path):
|
||||
return True
|
||||
await asyncio.sleep(poll_interval)
|
||||
elapsed += poll_interval
|
||||
return False
|
||||
|
||||
|
||||
async def start_daemon(start_cmd_parts: list[str]) -> bool:
|
||||
"""Start the fail2ban daemon using *start_cmd_parts*.
|
||||
|
||||
Uses :func:`asyncio.create_subprocess_exec` (no shell interpretation)
|
||||
to avoid command injection.
|
||||
|
||||
Args:
|
||||
start_cmd_parts: Command and arguments, e.g.
|
||||
``["fail2ban-client", "start"]``.
|
||||
|
||||
Returns:
|
||||
``True`` when the process exited with code 0.
|
||||
"""
|
||||
if not start_cmd_parts:
|
||||
log.warning("fail2ban_start_cmd_empty")
|
||||
return False
|
||||
try:
|
||||
proc = await asyncio.create_subprocess_exec(
|
||||
*start_cmd_parts,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
)
|
||||
await asyncio.wait_for(proc.wait(), timeout=30.0)
|
||||
success = proc.returncode == 0
|
||||
if not success:
|
||||
log.warning(
|
||||
"fail2ban_start_cmd_nonzero",
|
||||
cmd=start_cmd_parts,
|
||||
returncode=proc.returncode,
|
||||
)
|
||||
return success
|
||||
except (TimeoutError, OSError) as exc:
|
||||
log.warning("fail2ban_start_cmd_error", cmd=start_cmd_parts, error=str(exc))
|
||||
return False
|
||||
|
||||
|
||||
# Shared functions from config_file_service are imported from app.utils.config_file_utils
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public API
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def list_inactive_jails(
|
||||
config_dir: str,
|
||||
socket_path: str,
|
||||
) -> InactiveJailListResponse:
|
||||
"""Return all jails defined in config files that are not currently active.
|
||||
|
||||
Parses ``jail.conf``, ``jail.local``, and ``jail.d/`` following the
|
||||
fail2ban merge order. A jail is considered inactive when:
|
||||
|
||||
- Its merged ``enabled`` value is ``false`` (or absent, which defaults to
|
||||
``false`` in fail2ban), **or**
|
||||
- Its ``enabled`` value is ``true`` in config but fail2ban does not report
|
||||
it as running.
|
||||
|
||||
Args:
|
||||
config_dir: Absolute path to the fail2ban configuration directory.
|
||||
socket_path: Path to the fail2ban Unix domain socket.
|
||||
|
||||
Returns:
|
||||
:class:`~app.models.config.InactiveJailListResponse` with all
|
||||
inactive jails.
|
||||
"""
|
||||
loop = asyncio.get_event_loop()
|
||||
parsed_result: tuple[dict[str, dict[str, str]], dict[str, str]] = await loop.run_in_executor(
|
||||
None, _parse_jails_sync, Path(config_dir)
|
||||
)
|
||||
all_jails, source_files = parsed_result
|
||||
active_names: set[str] = await _get_active_jail_names(socket_path)
|
||||
|
||||
inactive: list[InactiveJail] = []
|
||||
for jail_name, settings in sorted(all_jails.items()):
|
||||
if jail_name in active_names:
|
||||
# fail2ban reports this jail as running — skip it.
|
||||
continue
|
||||
|
||||
source = source_files.get(jail_name, config_dir)
|
||||
inactive.append(_build_inactive_jail(jail_name, settings, source, Path(config_dir)))
|
||||
|
||||
log.info(
|
||||
"inactive_jails_listed",
|
||||
total_defined=len(all_jails),
|
||||
active=len(active_names),
|
||||
inactive=len(inactive),
|
||||
)
|
||||
return InactiveJailListResponse(jails=inactive, total=len(inactive))
|
||||
|
||||
|
||||
async def activate_jail(
|
||||
config_dir: str,
|
||||
socket_path: str,
|
||||
name: str,
|
||||
req: ActivateJailRequest,
|
||||
) -> JailActivationResponse:
|
||||
"""Enable an inactive jail and reload fail2ban.
|
||||
|
||||
Performs pre-activation validation, writes ``enabled = true`` (plus any
|
||||
override values from *req*) to ``jail.d/{name}.local``, and triggers a
|
||||
full fail2ban reload. After the reload a multi-attempt health probe
|
||||
determines whether fail2ban (and the specific jail) are still running.
|
||||
|
||||
Args:
|
||||
config_dir: Absolute path to the fail2ban configuration directory.
|
||||
socket_path: Path to the fail2ban Unix domain socket.
|
||||
name: Name of the jail to activate. Must exist in the parsed config.
|
||||
req: Optional override values to write alongside ``enabled = true``.
|
||||
|
||||
Returns:
|
||||
:class:`~app.models.config.JailActivationResponse` including
|
||||
``fail2ban_running`` and ``validation_warnings`` fields.
|
||||
|
||||
Raises:
|
||||
JailNameError: If *name* contains invalid characters.
|
||||
JailNotFoundInConfigError: If *name* is not defined in any config file.
|
||||
JailAlreadyActiveError: If fail2ban already reports *name* as running.
|
||||
ConfigWriteError: If writing the ``.local`` file fails.
|
||||
~app.utils.fail2ban_client.Fail2BanConnectionError: If the fail2ban
|
||||
socket is unreachable during reload.
|
||||
"""
|
||||
_safe_jail_name(name)
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
all_jails, _source_files = await loop.run_in_executor(None, _parse_jails_sync, Path(config_dir))
|
||||
|
||||
if name not in all_jails:
|
||||
raise JailNotFoundInConfigError(name)
|
||||
|
||||
active_names = await _get_active_jail_names(socket_path)
|
||||
if name in active_names:
|
||||
raise JailAlreadyActiveError(name)
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# Pre-activation validation — collect warnings but do not block #
|
||||
# ---------------------------------------------------------------------- #
|
||||
validation_result: JailValidationResult = await loop.run_in_executor(
|
||||
None, _validate_jail_config_sync, Path(config_dir), name
|
||||
)
|
||||
warnings: list[str] = [f"{i.field}: {i.message}" for i in validation_result.issues]
|
||||
if warnings:
|
||||
log.warning(
|
||||
"jail_activation_validation_warnings",
|
||||
jail=name,
|
||||
warnings=warnings,
|
||||
)
|
||||
|
||||
# Block activation on critical validation failures (missing filter or logpath).
|
||||
blocking = [i for i in validation_result.issues if i.field in ("filter", "logpath")]
|
||||
if blocking:
|
||||
log.warning(
|
||||
"jail_activation_blocked",
|
||||
jail=name,
|
||||
issues=[f"{i.field}: {i.message}" for i in blocking],
|
||||
)
|
||||
return JailActivationResponse(
|
||||
name=name,
|
||||
active=False,
|
||||
fail2ban_running=True,
|
||||
validation_warnings=warnings,
|
||||
message=(f"Jail {name!r} cannot be activated: " + "; ".join(i.message for i in blocking)),
|
||||
)
|
||||
|
||||
overrides: dict[str, object] = {
|
||||
"bantime": req.bantime,
|
||||
"findtime": req.findtime,
|
||||
"maxretry": req.maxretry,
|
||||
"port": req.port,
|
||||
"logpath": req.logpath,
|
||||
}
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# Backup the existing .local file (if any) before overwriting it so that #
|
||||
# we can restore it if activation fails. #
|
||||
# ---------------------------------------------------------------------- #
|
||||
local_path = Path(config_dir) / "jail.d" / f"{name}.local"
|
||||
original_content: bytes | None = await loop.run_in_executor(
|
||||
None,
|
||||
lambda: local_path.read_bytes() if local_path.exists() else None,
|
||||
)
|
||||
|
||||
await loop.run_in_executor(
|
||||
None,
|
||||
_write_local_override_sync,
|
||||
Path(config_dir),
|
||||
name,
|
||||
True,
|
||||
overrides,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# Activation reload — if it fails, roll back immediately #
|
||||
# ---------------------------------------------------------------------- #
|
||||
try:
|
||||
await reload_jails(socket_path, include_jails=[name])
|
||||
except JailNotFoundError as exc:
|
||||
# Jail configuration is invalid (e.g. missing logpath that prevents
|
||||
# fail2ban from loading the jail). Roll back and provide a specific error.
|
||||
log.warning(
|
||||
"reload_after_activate_failed_jail_not_found",
|
||||
jail=name,
|
||||
error=str(exc),
|
||||
)
|
||||
recovered = await _rollback_activation_async(config_dir, name, socket_path, original_content)
|
||||
return JailActivationResponse(
|
||||
name=name,
|
||||
active=False,
|
||||
fail2ban_running=False,
|
||||
recovered=recovered,
|
||||
validation_warnings=warnings,
|
||||
message=(
|
||||
f"Jail {name!r} activation failed: {str(exc)}. "
|
||||
"Check that all logpath files exist and are readable. "
|
||||
"The configuration was "
|
||||
+ ("automatically recovered." if recovered else "not recovered — manual intervention is required.")
|
||||
),
|
||||
)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
log.warning("reload_after_activate_failed", jail=name, error=str(exc))
|
||||
recovered = await _rollback_activation_async(config_dir, name, socket_path, original_content)
|
||||
return JailActivationResponse(
|
||||
name=name,
|
||||
active=False,
|
||||
fail2ban_running=False,
|
||||
recovered=recovered,
|
||||
validation_warnings=warnings,
|
||||
message=(
|
||||
f"Jail {name!r} activation failed during reload and the "
|
||||
"configuration was "
|
||||
+ ("automatically recovered." if recovered else "not recovered — manual intervention is required.")
|
||||
),
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# Post-reload health probe with retries #
|
||||
# ---------------------------------------------------------------------- #
|
||||
fail2ban_running = False
|
||||
for attempt in range(_POST_RELOAD_MAX_ATTEMPTS):
|
||||
if attempt > 0:
|
||||
await asyncio.sleep(_POST_RELOAD_PROBE_INTERVAL)
|
||||
if await _probe_fail2ban_running(socket_path):
|
||||
fail2ban_running = True
|
||||
break
|
||||
|
||||
if not fail2ban_running:
|
||||
log.warning(
|
||||
"fail2ban_down_after_activate",
|
||||
jail=name,
|
||||
message="fail2ban socket unreachable after reload — initiating rollback.",
|
||||
)
|
||||
recovered = await _rollback_activation_async(config_dir, name, socket_path, original_content)
|
||||
return JailActivationResponse(
|
||||
name=name,
|
||||
active=False,
|
||||
fail2ban_running=False,
|
||||
recovered=recovered,
|
||||
validation_warnings=warnings,
|
||||
message=(
|
||||
f"Jail {name!r} activation failed: fail2ban stopped responding "
|
||||
"after reload. The configuration was "
|
||||
+ ("automatically recovered." if recovered else "not recovered — manual intervention is required.")
|
||||
),
|
||||
)
|
||||
|
||||
# Verify the jail actually started (config error may prevent it silently).
|
||||
post_reload_names = await _get_active_jail_names(socket_path)
|
||||
actually_running = name in post_reload_names
|
||||
if not actually_running:
|
||||
log.warning(
|
||||
"jail_activation_unverified",
|
||||
jail=name,
|
||||
message="Jail did not appear in running jails — initiating rollback.",
|
||||
)
|
||||
recovered = await _rollback_activation_async(config_dir, name, socket_path, original_content)
|
||||
return JailActivationResponse(
|
||||
name=name,
|
||||
active=False,
|
||||
fail2ban_running=True,
|
||||
recovered=recovered,
|
||||
validation_warnings=warnings,
|
||||
message=(
|
||||
f"Jail {name!r} was written to config but did not start after "
|
||||
"reload. The configuration was "
|
||||
+ ("automatically recovered." if recovered else "not recovered — manual intervention is required.")
|
||||
),
|
||||
)
|
||||
|
||||
log.info("jail_activated", jail=name)
|
||||
return JailActivationResponse(
|
||||
name=name,
|
||||
active=True,
|
||||
fail2ban_running=True,
|
||||
validation_warnings=warnings,
|
||||
message=f"Jail {name!r} activated successfully.",
|
||||
)
|
||||
|
||||
|
||||
async def _rollback_activation_async(
|
||||
config_dir: str,
|
||||
name: str,
|
||||
socket_path: str,
|
||||
original_content: bytes | None,
|
||||
) -> bool:
|
||||
"""Restore the pre-activation ``.local`` file and reload fail2ban.
|
||||
|
||||
Called internally by :func:`activate_jail` when the activation fails after
|
||||
the config file was already written. Tries to:
|
||||
|
||||
1. Restore the original file content (or delete the file if it was newly
|
||||
created by the activation attempt).
|
||||
2. Reload fail2ban so the daemon runs with the restored configuration.
|
||||
3. Probe fail2ban to confirm it came back up.
|
||||
|
||||
Args:
|
||||
config_dir: Absolute path to the fail2ban configuration directory.
|
||||
name: Name of the jail whose ``.local`` file should be restored.
|
||||
socket_path: Path to the fail2ban Unix domain socket.
|
||||
original_content: Raw bytes of the original ``.local`` file, or
|
||||
``None`` if the file did not exist before the activation.
|
||||
|
||||
Returns:
|
||||
``True`` if fail2ban is responsive again after the rollback, ``False``
|
||||
if recovery also failed.
|
||||
"""
|
||||
loop = asyncio.get_event_loop()
|
||||
local_path = Path(config_dir) / "jail.d" / f"{name}.local"
|
||||
|
||||
# Step 1 — restore original file (or delete it).
|
||||
try:
|
||||
await loop.run_in_executor(None, _restore_local_file_sync, local_path, original_content)
|
||||
log.info("jail_activation_rollback_file_restored", jail=name)
|
||||
except ConfigWriteError as exc:
|
||||
log.error("jail_activation_rollback_restore_failed", jail=name, error=str(exc))
|
||||
return False
|
||||
|
||||
# Step 2 — reload fail2ban with the restored config.
|
||||
try:
|
||||
await reload_jails(socket_path)
|
||||
log.info("jail_activation_rollback_reload_ok", jail=name)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
log.warning("jail_activation_rollback_reload_failed", jail=name, error=str(exc))
|
||||
return False
|
||||
|
||||
# Step 3 — wait for fail2ban to come back.
|
||||
for attempt in range(_POST_RELOAD_MAX_ATTEMPTS):
|
||||
if attempt > 0:
|
||||
await asyncio.sleep(_POST_RELOAD_PROBE_INTERVAL)
|
||||
if await _probe_fail2ban_running(socket_path):
|
||||
log.info("jail_activation_rollback_recovered", jail=name)
|
||||
return True
|
||||
|
||||
log.warning("jail_activation_rollback_still_down", jail=name)
|
||||
return False
|
||||
|
||||
|
||||
async def deactivate_jail(
|
||||
config_dir: str,
|
||||
socket_path: str,
|
||||
name: str,
|
||||
) -> JailActivationResponse:
|
||||
"""Disable an active jail and reload fail2ban.
|
||||
|
||||
Writes ``enabled = false`` to ``jail.d/{name}.local`` and triggers a
|
||||
full fail2ban reload so the jail stops immediately.
|
||||
|
||||
Args:
|
||||
config_dir: Absolute path to the fail2ban configuration directory.
|
||||
socket_path: Path to the fail2ban Unix domain socket.
|
||||
name: Name of the jail to deactivate. Must exist in the parsed config.
|
||||
|
||||
Returns:
|
||||
:class:`~app.models.config.JailActivationResponse`.
|
||||
|
||||
Raises:
|
||||
JailNameError: If *name* contains invalid characters.
|
||||
JailNotFoundInConfigError: If *name* is not defined in any config file.
|
||||
JailAlreadyInactiveError: If fail2ban already reports *name* as not
|
||||
running.
|
||||
ConfigWriteError: If writing the ``.local`` file fails.
|
||||
~app.utils.fail2ban_client.Fail2BanConnectionError: If the fail2ban
|
||||
socket is unreachable during reload.
|
||||
"""
|
||||
_safe_jail_name(name)
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
all_jails, _source_files = await loop.run_in_executor(None, _parse_jails_sync, Path(config_dir))
|
||||
|
||||
if name not in all_jails:
|
||||
raise JailNotFoundInConfigError(name)
|
||||
|
||||
active_names = await _get_active_jail_names(socket_path)
|
||||
if name not in active_names:
|
||||
raise JailAlreadyInactiveError(name)
|
||||
|
||||
await loop.run_in_executor(
|
||||
None,
|
||||
_write_local_override_sync,
|
||||
Path(config_dir),
|
||||
name,
|
||||
False,
|
||||
{},
|
||||
)
|
||||
|
||||
try:
|
||||
await reload_jails(socket_path, exclude_jails=[name])
|
||||
except Exception as exc: # noqa: BLE001
|
||||
log.warning("reload_after_deactivate_failed", jail=name, error=str(exc))
|
||||
|
||||
log.info("jail_deactivated", jail=name)
|
||||
return JailActivationResponse(
|
||||
name=name,
|
||||
active=False,
|
||||
message=f"Jail {name!r} deactivated successfully.",
|
||||
)
|
||||
|
||||
|
||||
async def delete_jail_local_override(
|
||||
config_dir: str,
|
||||
socket_path: str,
|
||||
name: str,
|
||||
) -> None:
|
||||
"""Delete the ``jail.d/{name}.local`` override file for an inactive jail.
|
||||
|
||||
This is the clean-up action shown in the config UI when an inactive jail
|
||||
still has a ``.local`` override file (e.g. ``enabled = false``). The
|
||||
file is deleted outright; no fail2ban reload is required because the jail
|
||||
is already inactive.
|
||||
|
||||
Args:
|
||||
config_dir: Absolute path to the fail2ban configuration directory.
|
||||
socket_path: Path to the fail2ban Unix domain socket.
|
||||
name: Name of the jail whose ``.local`` file should be removed.
|
||||
|
||||
Raises:
|
||||
JailNameError: If *name* contains invalid characters.
|
||||
JailNotFoundInConfigError: If *name* is not defined in any config file.
|
||||
JailAlreadyActiveError: If the jail is currently active (refusing to
|
||||
delete the live config file).
|
||||
ConfigWriteError: If the file cannot be deleted.
|
||||
"""
|
||||
_safe_jail_name(name)
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
all_jails, _source_files = await loop.run_in_executor(None, _parse_jails_sync, Path(config_dir))
|
||||
|
||||
if name not in all_jails:
|
||||
raise JailNotFoundInConfigError(name)
|
||||
|
||||
active_names = await _get_active_jail_names(socket_path)
|
||||
if name in active_names:
|
||||
raise JailAlreadyActiveError(name)
|
||||
|
||||
local_path = Path(config_dir) / "jail.d" / f"{name}.local"
|
||||
try:
|
||||
await loop.run_in_executor(None, lambda: local_path.unlink(missing_ok=True))
|
||||
except OSError as exc:
|
||||
raise ConfigWriteError(f"Failed to delete {local_path}: {exc}") from exc
|
||||
|
||||
log.info("jail_local_override_deleted", jail=name, path=str(local_path))
|
||||
|
||||
|
||||
async def validate_jail_config(
|
||||
config_dir: str,
|
||||
name: str,
|
||||
) -> JailValidationResult:
|
||||
"""Run pre-activation validation checks on a jail configuration.
|
||||
|
||||
Validates that referenced filter and action files exist in ``filter.d/``
|
||||
and ``action.d/``, that all regex patterns compile, and that declared log
|
||||
paths exist on disk.
|
||||
|
||||
Args:
|
||||
config_dir: Absolute path to the fail2ban configuration directory.
|
||||
name: Name of the jail to validate.
|
||||
|
||||
Returns:
|
||||
:class:`~app.models.config.JailValidationResult` with any issues found.
|
||||
|
||||
Raises:
|
||||
JailNameError: If *name* contains invalid characters.
|
||||
"""
|
||||
_safe_jail_name(name)
|
||||
loop = asyncio.get_event_loop()
|
||||
return await loop.run_in_executor(
|
||||
None,
|
||||
_validate_jail_config_sync,
|
||||
Path(config_dir),
|
||||
name,
|
||||
)
|
||||
|
||||
|
||||
async def rollback_jail(
|
||||
config_dir: str,
|
||||
socket_path: str,
|
||||
name: str,
|
||||
start_cmd_parts: list[str],
|
||||
) -> RollbackResponse:
|
||||
"""Disable a bad jail config and restart the fail2ban daemon.
|
||||
|
||||
Writes ``enabled = false`` to ``jail.d/{name}.local`` (works even when
|
||||
fail2ban is down — only a file write), then attempts to start the daemon
|
||||
with *start_cmd_parts*. Waits up to 10 seconds for the socket to respond.
|
||||
|
||||
Args:
|
||||
config_dir: Absolute path to the fail2ban configuration directory.
|
||||
socket_path: Path to the fail2ban Unix domain socket.
|
||||
name: Name of the jail to disable.
|
||||
start_cmd_parts: Argument list for the daemon start command, e.g.
|
||||
``["fail2ban-client", "start"]``.
|
||||
|
||||
Returns:
|
||||
:class:`~app.models.config.RollbackResponse`.
|
||||
|
||||
Raises:
|
||||
JailNameError: If *name* contains invalid characters.
|
||||
ConfigWriteError: If writing the ``.local`` file fails.
|
||||
"""
|
||||
_safe_jail_name(name)
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
# Write enabled=false — this must succeed even when fail2ban is down.
|
||||
await loop.run_in_executor(
|
||||
None,
|
||||
_write_local_override_sync,
|
||||
Path(config_dir),
|
||||
name,
|
||||
False,
|
||||
{},
|
||||
)
|
||||
log.info("jail_rolled_back_disabled", jail=name)
|
||||
|
||||
# Attempt to start the daemon.
|
||||
started = await start_daemon(start_cmd_parts)
|
||||
log.info("jail_rollback_start_attempted", jail=name, start_ok=started)
|
||||
|
||||
# Wait for the socket to come back.
|
||||
fail2ban_running = await wait_for_fail2ban(socket_path, max_wait_seconds=10.0, poll_interval=2.0)
|
||||
|
||||
active_jails = 0
|
||||
if fail2ban_running:
|
||||
names = await _get_active_jail_names(socket_path)
|
||||
active_jails = len(names)
|
||||
|
||||
if fail2ban_running:
|
||||
log.info("jail_rollback_success", jail=name, active_jails=active_jails)
|
||||
return RollbackResponse(
|
||||
jail_name=name,
|
||||
disabled=True,
|
||||
fail2ban_running=True,
|
||||
active_jails=active_jails,
|
||||
message=(f"Jail {name!r} disabled and fail2ban restarted successfully with {active_jails} active jail(s)."),
|
||||
)
|
||||
|
||||
log.warning("jail_rollback_fail2ban_still_down", jail=name)
|
||||
return RollbackResponse(
|
||||
jail_name=name,
|
||||
disabled=True,
|
||||
fail2ban_running=False,
|
||||
active_jails=0,
|
||||
message=(
|
||||
f"Jail {name!r} was disabled but fail2ban did not come back online. "
|
||||
"Check the fail2ban log for additional errors."
|
||||
),
|
||||
)
|
||||
@@ -14,11 +14,12 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
import contextlib
|
||||
import ipaddress
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING, TypedDict, cast
|
||||
|
||||
import structlog
|
||||
|
||||
from app.models.ban import ActiveBan, ActiveBanListResponse
|
||||
from app.exceptions import JailNotFoundError, JailOperationError
|
||||
from app.models.ban import ActiveBan, ActiveBanListResponse, JailBannedIpsResponse
|
||||
from app.models.config import BantimeEscalation
|
||||
from app.models.jail import (
|
||||
Jail,
|
||||
@@ -27,10 +28,36 @@ from app.models.jail import (
|
||||
JailStatus,
|
||||
JailSummary,
|
||||
)
|
||||
from app.utils.fail2ban_client import Fail2BanClient, Fail2BanConnectionError
|
||||
from app.utils.fail2ban_client import (
|
||||
Fail2BanClient,
|
||||
Fail2BanCommand,
|
||||
Fail2BanConnectionError,
|
||||
Fail2BanResponse,
|
||||
Fail2BanToken,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Awaitable
|
||||
|
||||
import aiohttp
|
||||
import aiosqlite
|
||||
|
||||
from app.models.geo import GeoBatchLookup, GeoEnricher, GeoInfo
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
|
||||
class IpLookupResult(TypedDict):
|
||||
"""Result returned by :func:`lookup_ip`.
|
||||
|
||||
This is intentionally a :class:`TypedDict` to provide precise typing for
|
||||
callers (e.g. routers) while keeping the implementation flexible.
|
||||
"""
|
||||
|
||||
ip: str
|
||||
currently_banned_in: list[str]
|
||||
geo: GeoInfo | None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Constants
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -43,34 +70,24 @@ _SOCKET_TIMEOUT: float = 10.0
|
||||
# ensures only one reload stream is in-flight at a time.
|
||||
_reload_all_lock: asyncio.Lock = asyncio.Lock()
|
||||
|
||||
# Capability detection for optional fail2ban transmitter commands (backend, idle).
|
||||
# These commands are not supported in all fail2ban versions. Caching the result
|
||||
# avoids sending unsupported commands every polling cycle and spamming the
|
||||
# fail2ban log with "Invalid command" errors.
|
||||
_backend_cmd_supported: bool | None = None
|
||||
_backend_cmd_lock: asyncio.Lock = asyncio.Lock()
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Custom exceptions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class JailNotFoundError(Exception):
|
||||
"""Raised when a requested jail name does not exist in fail2ban."""
|
||||
|
||||
def __init__(self, name: str) -> None:
|
||||
"""Initialise with the jail name that was not found.
|
||||
|
||||
Args:
|
||||
name: The jail name that could not be located.
|
||||
"""
|
||||
self.name: str = name
|
||||
super().__init__(f"Jail not found: {name!r}")
|
||||
|
||||
|
||||
class JailOperationError(Exception):
|
||||
"""Raised when a jail control command fails for a non-auth reason."""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Internal helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _ok(response: Any) -> Any:
|
||||
def _ok(response: object) -> object:
|
||||
"""Extract the payload from a fail2ban ``(return_code, data)`` response.
|
||||
|
||||
Args:
|
||||
@@ -83,7 +100,7 @@ def _ok(response: Any) -> Any:
|
||||
ValueError: If the response indicates an error (return code ≠ 0).
|
||||
"""
|
||||
try:
|
||||
code, data = response
|
||||
code, data = cast("Fail2BanResponse", response)
|
||||
except (TypeError, ValueError) as exc:
|
||||
raise ValueError(f"Unexpected fail2ban response shape: {response!r}") from exc
|
||||
|
||||
@@ -93,7 +110,7 @@ def _ok(response: Any) -> Any:
|
||||
return data
|
||||
|
||||
|
||||
def _to_dict(pairs: Any) -> dict[str, Any]:
|
||||
def _to_dict(pairs: object) -> dict[str, object]:
|
||||
"""Convert a list of ``(key, value)`` pairs to a plain dict.
|
||||
|
||||
Args:
|
||||
@@ -104,7 +121,7 @@ def _to_dict(pairs: Any) -> dict[str, Any]:
|
||||
"""
|
||||
if not isinstance(pairs, (list, tuple)):
|
||||
return {}
|
||||
result: dict[str, Any] = {}
|
||||
result: dict[str, object] = {}
|
||||
for item in pairs:
|
||||
try:
|
||||
k, v = item
|
||||
@@ -114,7 +131,7 @@ def _to_dict(pairs: Any) -> dict[str, Any]:
|
||||
return result
|
||||
|
||||
|
||||
def _ensure_list(value: Any) -> list[str]:
|
||||
def _ensure_list(value: object | None) -> list[str]:
|
||||
"""Coerce a fail2ban response value to a list of strings.
|
||||
|
||||
Some fail2ban ``get`` responses return ``None`` or a single string
|
||||
@@ -163,9 +180,9 @@ def _is_not_found_error(exc: Exception) -> bool:
|
||||
|
||||
async def _safe_get(
|
||||
client: Fail2BanClient,
|
||||
command: list[Any],
|
||||
default: Any = None,
|
||||
) -> Any:
|
||||
command: Fail2BanCommand,
|
||||
default: object | None = None,
|
||||
) -> object | None:
|
||||
"""Send a ``get`` command and return ``default`` on error.
|
||||
|
||||
Errors during optional detail queries (logpath, regex, etc.) should
|
||||
@@ -180,11 +197,57 @@ async def _safe_get(
|
||||
The response payload, or *default* on any error.
|
||||
"""
|
||||
try:
|
||||
return _ok(await client.send(command))
|
||||
response = await client.send(command)
|
||||
return _ok(cast("Fail2BanResponse", response))
|
||||
except (ValueError, TypeError, Exception):
|
||||
return default
|
||||
|
||||
|
||||
async def _check_backend_cmd_supported(
|
||||
client: Fail2BanClient,
|
||||
jail_name: str,
|
||||
) -> bool:
|
||||
"""Detect whether the fail2ban daemon supports optional ``get ... backend`` command.
|
||||
|
||||
Some fail2ban versions (e.g. LinuxServer.io container) do not implement the
|
||||
optional ``get <jail> backend`` and ``get <jail> idle`` transmitter sub-commands.
|
||||
This helper probes the daemon once and caches the result to avoid repeated
|
||||
"Invalid command" errors in the fail2ban log.
|
||||
|
||||
Uses double-check locking to minimize lock contention in concurrent polls.
|
||||
|
||||
Args:
|
||||
client: The :class:`~app.utils.fail2ban_client.Fail2BanClient` to use.
|
||||
jail_name: Name of any jail to use for the probe command.
|
||||
|
||||
Returns:
|
||||
``True`` if the command is supported, ``False`` otherwise.
|
||||
Once determined, the result is cached and reused for all jails.
|
||||
"""
|
||||
global _backend_cmd_supported
|
||||
|
||||
# Fast path: return cached result if already determined.
|
||||
if _backend_cmd_supported is not None:
|
||||
return _backend_cmd_supported
|
||||
|
||||
# Slow path: acquire lock and probe the command once.
|
||||
async with _backend_cmd_lock:
|
||||
# Double-check idiom: another coroutine may have probed while we waited.
|
||||
if _backend_cmd_supported is not None:
|
||||
return _backend_cmd_supported
|
||||
|
||||
# Probe: send the command and catch any exception.
|
||||
try:
|
||||
_ok(await client.send(["get", jail_name, "backend"]))
|
||||
_backend_cmd_supported = True
|
||||
log.debug("backend_cmd_supported_detected")
|
||||
except Exception:
|
||||
_backend_cmd_supported = False
|
||||
log.debug("backend_cmd_unsupported_detected")
|
||||
|
||||
return _backend_cmd_supported
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public API — Jail listing & detail
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -238,7 +301,11 @@ async def _fetch_jail_summary(
|
||||
"""Fetch and build a :class:`~app.models.jail.JailSummary` for one jail.
|
||||
|
||||
Sends the ``status``, ``get ... bantime``, ``findtime``, ``maxretry``,
|
||||
``backend``, and ``idle`` commands in parallel.
|
||||
``backend``, and ``idle`` commands in parallel (if supported).
|
||||
|
||||
The ``backend`` and ``idle`` commands are optional and not supported in
|
||||
all fail2ban versions. If not supported, this function will not send them
|
||||
to avoid spamming the fail2ban log with "Invalid command" errors.
|
||||
|
||||
Args:
|
||||
client: Shared :class:`~app.utils.fail2ban_client.Fail2BanClient`.
|
||||
@@ -247,21 +314,42 @@ async def _fetch_jail_summary(
|
||||
Returns:
|
||||
A :class:`~app.models.jail.JailSummary` populated from the responses.
|
||||
"""
|
||||
_r = await asyncio.gather(
|
||||
# Check whether optional backend/idle commands are supported.
|
||||
# This probe happens once per session and is cached to avoid repeated
|
||||
# "Invalid command" errors in the fail2ban log.
|
||||
backend_cmd_is_supported = await _check_backend_cmd_supported(client, name)
|
||||
|
||||
# Build the gather list based on command support.
|
||||
gather_list: list[Awaitable[object]] = [
|
||||
client.send(["status", name, "short"]),
|
||||
client.send(["get", name, "bantime"]),
|
||||
client.send(["get", name, "findtime"]),
|
||||
client.send(["get", name, "maxretry"]),
|
||||
client.send(["get", name, "backend"]),
|
||||
client.send(["get", name, "idle"]),
|
||||
return_exceptions=True,
|
||||
)
|
||||
status_raw: Any = _r[0]
|
||||
bantime_raw: Any = _r[1]
|
||||
findtime_raw: Any = _r[2]
|
||||
maxretry_raw: Any = _r[3]
|
||||
backend_raw: Any = _r[4]
|
||||
idle_raw: Any = _r[5]
|
||||
]
|
||||
|
||||
if backend_cmd_is_supported:
|
||||
# Commands are supported; send them for real values.
|
||||
gather_list.extend([
|
||||
client.send(["get", name, "backend"]),
|
||||
client.send(["get", name, "idle"]),
|
||||
])
|
||||
else:
|
||||
# Commands not supported; return default values without sending.
|
||||
async def _return_default(value: object | None) -> Fail2BanResponse:
|
||||
return (0, value)
|
||||
|
||||
gather_list.extend([
|
||||
_return_default("polling"), # backend default
|
||||
_return_default(False), # idle default
|
||||
])
|
||||
|
||||
_r = await asyncio.gather(*gather_list, return_exceptions=True)
|
||||
status_raw: object | Exception = _r[0]
|
||||
bantime_raw: object | Exception = _r[1]
|
||||
findtime_raw: object | Exception = _r[2]
|
||||
maxretry_raw: object | Exception = _r[3]
|
||||
backend_raw: object | Exception = _r[4]
|
||||
idle_raw: object | Exception = _r[5]
|
||||
|
||||
# Parse jail status (filter + actions).
|
||||
jail_status: JailStatus | None = None
|
||||
@@ -271,35 +359,35 @@ async def _fetch_jail_summary(
|
||||
filter_stats = _to_dict(raw.get("Filter") or [])
|
||||
action_stats = _to_dict(raw.get("Actions") or [])
|
||||
jail_status = JailStatus(
|
||||
currently_banned=int(action_stats.get("Currently banned", 0) or 0),
|
||||
total_banned=int(action_stats.get("Total banned", 0) or 0),
|
||||
currently_failed=int(filter_stats.get("Currently failed", 0) or 0),
|
||||
total_failed=int(filter_stats.get("Total failed", 0) or 0),
|
||||
currently_banned=int(str(action_stats.get("Currently banned", 0) or 0)),
|
||||
total_banned=int(str(action_stats.get("Total banned", 0) or 0)),
|
||||
currently_failed=int(str(filter_stats.get("Currently failed", 0) or 0)),
|
||||
total_failed=int(str(filter_stats.get("Total failed", 0) or 0)),
|
||||
)
|
||||
except (ValueError, TypeError) as exc:
|
||||
log.warning("jail_status_parse_error", jail=name, error=str(exc))
|
||||
|
||||
def _safe_int(raw: Any, fallback: int) -> int:
|
||||
def _safe_int(raw: object | Exception, fallback: int) -> int:
|
||||
if isinstance(raw, Exception):
|
||||
return fallback
|
||||
try:
|
||||
return int(_ok(raw))
|
||||
return int(str(_ok(cast("Fail2BanResponse", raw))))
|
||||
except (ValueError, TypeError):
|
||||
return fallback
|
||||
|
||||
def _safe_str(raw: Any, fallback: str) -> str:
|
||||
def _safe_str(raw: object | Exception, fallback: str) -> str:
|
||||
if isinstance(raw, Exception):
|
||||
return fallback
|
||||
try:
|
||||
return str(_ok(raw))
|
||||
return str(_ok(cast("Fail2BanResponse", raw)))
|
||||
except (ValueError, TypeError):
|
||||
return fallback
|
||||
|
||||
def _safe_bool(raw: Any, fallback: bool = False) -> bool:
|
||||
def _safe_bool(raw: object | Exception, fallback: bool = False) -> bool:
|
||||
if isinstance(raw, Exception):
|
||||
return fallback
|
||||
try:
|
||||
return bool(_ok(raw))
|
||||
return bool(_ok(cast("Fail2BanResponse", raw)))
|
||||
except (ValueError, TypeError):
|
||||
return fallback
|
||||
|
||||
@@ -349,10 +437,10 @@ async def get_jail(socket_path: str, name: str) -> JailDetailResponse:
|
||||
action_stats = _to_dict(raw.get("Actions") or [])
|
||||
|
||||
jail_status = JailStatus(
|
||||
currently_banned=int(action_stats.get("Currently banned", 0) or 0),
|
||||
total_banned=int(action_stats.get("Total banned", 0) or 0),
|
||||
currently_failed=int(filter_stats.get("Currently failed", 0) or 0),
|
||||
total_failed=int(filter_stats.get("Total failed", 0) or 0),
|
||||
currently_banned=int(str(action_stats.get("Currently banned", 0) or 0)),
|
||||
total_banned=int(str(action_stats.get("Total banned", 0) or 0)),
|
||||
currently_failed=int(str(filter_stats.get("Currently failed", 0) or 0)),
|
||||
total_failed=int(str(filter_stats.get("Total failed", 0) or 0)),
|
||||
)
|
||||
|
||||
# Fetch all detail fields in parallel.
|
||||
@@ -401,11 +489,11 @@ async def get_jail(socket_path: str, name: str) -> JailDetailResponse:
|
||||
bt_increment: bool = bool(bt_increment_raw)
|
||||
bantime_escalation = BantimeEscalation(
|
||||
increment=bt_increment,
|
||||
factor=float(bt_factor_raw) if bt_factor_raw is not None else None,
|
||||
factor=float(str(bt_factor_raw)) if bt_factor_raw is not None else None,
|
||||
formula=str(bt_formula_raw) if bt_formula_raw else None,
|
||||
multipliers=str(bt_multipliers_raw) if bt_multipliers_raw else None,
|
||||
max_time=int(bt_maxtime_raw) if bt_maxtime_raw is not None else None,
|
||||
rnd_time=int(bt_rndtime_raw) if bt_rndtime_raw is not None else None,
|
||||
max_time=int(str(bt_maxtime_raw)) if bt_maxtime_raw is not None else None,
|
||||
rnd_time=int(str(bt_rndtime_raw)) if bt_rndtime_raw is not None else None,
|
||||
overall_jails=bool(bt_overalljails_raw),
|
||||
)
|
||||
|
||||
@@ -421,9 +509,9 @@ async def get_jail(socket_path: str, name: str) -> JailDetailResponse:
|
||||
ignore_ips=_ensure_list(ignoreip_raw),
|
||||
date_pattern=str(datepattern_raw) if datepattern_raw else None,
|
||||
log_encoding=str(logencoding_raw or "UTF-8"),
|
||||
find_time=int(findtime_raw or 600),
|
||||
ban_time=int(bantime_raw or 600),
|
||||
max_retry=int(maxretry_raw or 5),
|
||||
find_time=int(str(findtime_raw or 600)),
|
||||
ban_time=int(str(bantime_raw or 600)),
|
||||
max_retry=int(str(maxretry_raw or 5)),
|
||||
bantime_escalation=bantime_escalation,
|
||||
status=jail_status,
|
||||
actions=_ensure_list(actions_raw),
|
||||
@@ -569,7 +657,10 @@ async def reload_all(
|
||||
exclude_jails: Jail names to remove from the start stream.
|
||||
|
||||
Raises:
|
||||
JailOperationError: If fail2ban reports the operation failed.
|
||||
JailNotFoundError: If a jail in *include_jails* does not exist or
|
||||
its configuration is invalid (e.g. missing logpath).
|
||||
JailOperationError: If fail2ban reports the operation failed for
|
||||
a different reason.
|
||||
~app.utils.fail2ban_client.Fail2BanConnectionError: If the socket
|
||||
cannot be reached.
|
||||
"""
|
||||
@@ -589,13 +680,47 @@ async def reload_all(
|
||||
if exclude_jails:
|
||||
names_set -= set(exclude_jails)
|
||||
|
||||
stream: list[list[str]] = [["start", n] for n in sorted(names_set)]
|
||||
_ok(await client.send(["reload", "--all", [], stream]))
|
||||
stream: list[list[object]] = [["start", n] for n in sorted(names_set)]
|
||||
_ok(await client.send(["reload", "--all", [], cast("Fail2BanToken", stream)]))
|
||||
log.info("all_jails_reloaded")
|
||||
except ValueError as exc:
|
||||
# Detect UnknownJailException (missing or invalid jail configuration)
|
||||
# and re-raise as JailNotFoundError for better error specificity.
|
||||
if _is_not_found_error(exc):
|
||||
# Extract the jail name from include_jails if available.
|
||||
jail_name = include_jails[0] if include_jails else "unknown"
|
||||
raise JailNotFoundError(jail_name) from exc
|
||||
raise JailOperationError(str(exc)) from exc
|
||||
|
||||
|
||||
async def restart(socket_path: str) -> None:
|
||||
"""Stop the fail2ban daemon via the Unix socket.
|
||||
|
||||
Sends ``["stop"]`` to the fail2ban daemon, which calls ``server.quit()``
|
||||
on the daemon side and tears down all jails. The caller is responsible
|
||||
for starting the daemon again (e.g. via ``fail2ban-client start``).
|
||||
|
||||
Note:
|
||||
``["restart"]`` is a *client-side* orchestration command that is not
|
||||
handled by the fail2ban server transmitter — sending it to the socket
|
||||
raises ``"Invalid command"`` in the daemon.
|
||||
|
||||
Args:
|
||||
socket_path: Path to the fail2ban Unix domain socket.
|
||||
|
||||
Raises:
|
||||
JailOperationError: If fail2ban reports the stop command failed.
|
||||
~app.utils.fail2ban_client.Fail2BanConnectionError: If the socket
|
||||
cannot be reached.
|
||||
"""
|
||||
client = Fail2BanClient(socket_path=socket_path, timeout=_SOCKET_TIMEOUT)
|
||||
try:
|
||||
_ok(await client.send(["stop"]))
|
||||
log.info("fail2ban_stopped_for_restart")
|
||||
except ValueError as exc:
|
||||
raise JailOperationError(str(exc)) from exc
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public API — Ban / Unban
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -679,9 +804,10 @@ async def unban_ip(
|
||||
|
||||
async def get_active_bans(
|
||||
socket_path: str,
|
||||
geo_enricher: Any | None = None,
|
||||
http_session: Any | None = None,
|
||||
app_db: Any | None = None,
|
||||
geo_batch_lookup: GeoBatchLookup | None = None,
|
||||
geo_enricher: GeoEnricher | None = None,
|
||||
http_session: aiohttp.ClientSession | None = None,
|
||||
app_db: aiosqlite.Connection | None = None,
|
||||
) -> ActiveBanListResponse:
|
||||
"""Return all currently banned IPs across every jail.
|
||||
|
||||
@@ -716,7 +842,6 @@ async def get_active_bans(
|
||||
~app.utils.fail2ban_client.Fail2BanConnectionError: If the socket
|
||||
cannot be reached.
|
||||
"""
|
||||
from app.services import geo_service # noqa: PLC0415
|
||||
|
||||
client = Fail2BanClient(socket_path=socket_path, timeout=_SOCKET_TIMEOUT)
|
||||
|
||||
@@ -733,7 +858,7 @@ async def get_active_bans(
|
||||
return ActiveBanListResponse(bans=[], total=0)
|
||||
|
||||
# For each jail, fetch the ban list with time info in parallel.
|
||||
results: list[Any] = await asyncio.gather(
|
||||
results: list[object | Exception] = await asyncio.gather(
|
||||
*[client.send(["get", jn, "banip", "--with-time"]) for jn in jail_names],
|
||||
return_exceptions=True,
|
||||
)
|
||||
@@ -749,7 +874,7 @@ async def get_active_bans(
|
||||
continue
|
||||
|
||||
try:
|
||||
ban_list: list[str] = _ok(raw_result) or []
|
||||
ban_list: list[str] = cast("list[str]", _ok(raw_result)) or []
|
||||
except (TypeError, ValueError) as exc:
|
||||
log.warning(
|
||||
"active_bans_parse_error",
|
||||
@@ -764,10 +889,10 @@ async def get_active_bans(
|
||||
bans.append(ban)
|
||||
|
||||
# Enrich with geo data — prefer batch lookup over per-IP enricher.
|
||||
if http_session is not None and bans:
|
||||
if http_session is not None and bans and geo_batch_lookup is not None:
|
||||
all_ips: list[str] = [ban.ip for ban in bans]
|
||||
try:
|
||||
geo_map = await geo_service.lookup_batch(all_ips, http_session, db=app_db)
|
||||
geo_map = await geo_batch_lookup(all_ips, http_session, db=app_db)
|
||||
except Exception: # noqa: BLE001
|
||||
log.warning("active_bans_batch_geo_failed")
|
||||
geo_map = {}
|
||||
@@ -862,9 +987,122 @@ def _parse_ban_entry(entry: str, jail: str) -> ActiveBan | None:
|
||||
return None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public API — Jail-specific paginated bans
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
#: Maximum allowed page size for :func:`get_jail_banned_ips`.
|
||||
_MAX_PAGE_SIZE: int = 100
|
||||
|
||||
|
||||
async def get_jail_banned_ips(
|
||||
socket_path: str,
|
||||
jail_name: str,
|
||||
page: int = 1,
|
||||
page_size: int = 25,
|
||||
search: str | None = None,
|
||||
geo_batch_lookup: GeoBatchLookup | None = None,
|
||||
http_session: aiohttp.ClientSession | None = None,
|
||||
app_db: aiosqlite.Connection | None = None,
|
||||
) -> JailBannedIpsResponse:
|
||||
"""Return a paginated list of currently banned IPs for a single jail.
|
||||
|
||||
Fetches the full ban list from the fail2ban socket, applies an optional
|
||||
substring search filter on the IP, paginates server-side, and geo-enriches
|
||||
**only** the current page slice to stay within rate limits.
|
||||
|
||||
Args:
|
||||
socket_path: Path to the fail2ban Unix domain socket.
|
||||
jail_name: Name of the jail to query.
|
||||
page: 1-based page number (default 1).
|
||||
page_size: Items per page; clamped to :data:`_MAX_PAGE_SIZE` (default 25).
|
||||
search: Optional case-insensitive substring filter applied to IP addresses.
|
||||
http_session: Optional shared :class:`aiohttp.ClientSession` for geo
|
||||
enrichment via :func:`~app.services.geo_service.lookup_batch`.
|
||||
app_db: Optional BanGUI application database for persistent geo cache.
|
||||
|
||||
Returns:
|
||||
:class:`~app.models.ban.JailBannedIpsResponse` with the paginated bans.
|
||||
|
||||
Raises:
|
||||
JailNotFoundError: If *jail_name* is not a known active jail.
|
||||
~app.utils.fail2ban_client.Fail2BanConnectionError: If the socket is
|
||||
unreachable.
|
||||
"""
|
||||
# Clamp page_size to the allowed maximum.
|
||||
page_size = min(page_size, _MAX_PAGE_SIZE)
|
||||
|
||||
client = Fail2BanClient(socket_path=socket_path, timeout=_SOCKET_TIMEOUT)
|
||||
|
||||
# Verify the jail exists.
|
||||
try:
|
||||
_ok(await client.send(["status", jail_name, "short"]))
|
||||
except ValueError as exc:
|
||||
if _is_not_found_error(exc):
|
||||
raise JailNotFoundError(jail_name) from exc
|
||||
raise
|
||||
|
||||
# Fetch the full ban list for this jail.
|
||||
try:
|
||||
raw_result = _ok(await client.send(["get", jail_name, "banip", "--with-time"]))
|
||||
except (ValueError, TypeError):
|
||||
raw_result = []
|
||||
|
||||
ban_list: list[str] = cast("list[str]", raw_result) or []
|
||||
|
||||
# Parse all entries.
|
||||
all_bans: list[ActiveBan] = []
|
||||
for entry in ban_list:
|
||||
ban = _parse_ban_entry(str(entry), jail_name)
|
||||
if ban is not None:
|
||||
all_bans.append(ban)
|
||||
|
||||
# Apply optional substring search filter (case-insensitive).
|
||||
if search:
|
||||
search_lower = search.lower()
|
||||
all_bans = [b for b in all_bans if search_lower in b.ip.lower()]
|
||||
|
||||
total = len(all_bans)
|
||||
|
||||
# Slice the requested page.
|
||||
start = (page - 1) * page_size
|
||||
page_bans = all_bans[start : start + page_size]
|
||||
|
||||
# Geo-enrich only the page slice.
|
||||
if http_session is not None and page_bans and geo_batch_lookup is not None:
|
||||
page_ips = [b.ip for b in page_bans]
|
||||
try:
|
||||
geo_map = await geo_batch_lookup(page_ips, http_session, db=app_db)
|
||||
except Exception: # noqa: BLE001
|
||||
log.warning("jail_banned_ips_geo_failed", jail=jail_name)
|
||||
geo_map = {}
|
||||
enriched_page: list[ActiveBan] = []
|
||||
for ban in page_bans:
|
||||
geo = geo_map.get(ban.ip)
|
||||
if geo is not None:
|
||||
enriched_page.append(ban.model_copy(update={"country": geo.country_code}))
|
||||
else:
|
||||
enriched_page.append(ban)
|
||||
page_bans = enriched_page
|
||||
|
||||
log.info(
|
||||
"jail_banned_ips_fetched",
|
||||
jail=jail_name,
|
||||
total=total,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
return JailBannedIpsResponse(
|
||||
items=page_bans,
|
||||
total=total,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
|
||||
|
||||
async def _enrich_bans(
|
||||
bans: list[ActiveBan],
|
||||
geo_enricher: Any,
|
||||
geo_enricher: GeoEnricher,
|
||||
) -> list[ActiveBan]:
|
||||
"""Enrich ban records with geo data asynchronously.
|
||||
|
||||
@@ -875,14 +1113,15 @@ async def _enrich_bans(
|
||||
Returns:
|
||||
The same list with ``country`` fields populated where lookup succeeded.
|
||||
"""
|
||||
geo_results: list[Any] = await asyncio.gather(
|
||||
*[geo_enricher(ban.ip) for ban in bans],
|
||||
geo_results: list[object | Exception] = await asyncio.gather(
|
||||
*[cast("Awaitable[object]", geo_enricher(ban.ip)) for ban in bans],
|
||||
return_exceptions=True,
|
||||
)
|
||||
enriched: list[ActiveBan] = []
|
||||
for ban, geo in zip(bans, geo_results, strict=False):
|
||||
if geo is not None and not isinstance(geo, Exception):
|
||||
enriched.append(ban.model_copy(update={"country": geo.country_code}))
|
||||
geo_info = cast("GeoInfo", geo)
|
||||
enriched.append(ban.model_copy(update={"country": geo_info.country_code}))
|
||||
else:
|
||||
enriched.append(ban)
|
||||
return enriched
|
||||
@@ -1030,8 +1269,8 @@ async def set_ignore_self(socket_path: str, name: str, *, on: bool) -> None:
|
||||
async def lookup_ip(
|
||||
socket_path: str,
|
||||
ip: str,
|
||||
geo_enricher: Any | None = None,
|
||||
) -> dict[str, Any]:
|
||||
geo_enricher: GeoEnricher | None = None,
|
||||
) -> IpLookupResult:
|
||||
"""Return ban status and history for a single IP address.
|
||||
|
||||
Checks every running jail for whether the IP is currently banned.
|
||||
@@ -1074,7 +1313,7 @@ async def lookup_ip(
|
||||
)
|
||||
|
||||
# Check ban status per jail in parallel.
|
||||
ban_results: list[Any] = await asyncio.gather(
|
||||
ban_results: list[object | Exception] = await asyncio.gather(
|
||||
*[client.send(["get", jn, "banip"]) for jn in jail_names],
|
||||
return_exceptions=True,
|
||||
)
|
||||
@@ -1084,7 +1323,7 @@ async def lookup_ip(
|
||||
if isinstance(result, Exception):
|
||||
continue
|
||||
try:
|
||||
ban_list: list[str] = _ok(result) or []
|
||||
ban_list: list[str] = cast("list[str]", _ok(result)) or []
|
||||
if ip in ban_list:
|
||||
currently_banned_in.append(jail_name)
|
||||
except (ValueError, TypeError):
|
||||
@@ -1121,6 +1360,6 @@ async def unban_all_ips(socket_path: str) -> int:
|
||||
cannot be reached.
|
||||
"""
|
||||
client = Fail2BanClient(socket_path=socket_path, timeout=_SOCKET_TIMEOUT)
|
||||
count: int = int(_ok(await client.send(["unban", "--all"])))
|
||||
count: int = int(str(_ok(await client.send(["unban", "--all"])) or 0))
|
||||
log.info("all_ips_unbanned", count=count)
|
||||
return count
|
||||
|
||||
128
backend/app/services/log_service.py
Normal file
128
backend/app/services/log_service.py
Normal file
@@ -0,0 +1,128 @@
|
||||
"""Log helper service.
|
||||
|
||||
Contains regex test and log preview helpers that are independent of
|
||||
fail2ban socket operations.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
from app.models.config import (
|
||||
LogPreviewLine,
|
||||
LogPreviewRequest,
|
||||
LogPreviewResponse,
|
||||
RegexTestRequest,
|
||||
RegexTestResponse,
|
||||
)
|
||||
|
||||
|
||||
def test_regex(request: RegexTestRequest) -> RegexTestResponse:
|
||||
"""Test a regex pattern against a sample log line.
|
||||
|
||||
Args:
|
||||
request: The regex test payload.
|
||||
|
||||
Returns:
|
||||
RegexTestResponse with match result, groups and optional error.
|
||||
"""
|
||||
try:
|
||||
compiled = re.compile(request.fail_regex)
|
||||
except re.error as exc:
|
||||
return RegexTestResponse(matched=False, groups=[], error=str(exc))
|
||||
|
||||
match = compiled.search(request.log_line)
|
||||
if match is None:
|
||||
return RegexTestResponse(matched=False)
|
||||
|
||||
groups: list[str] = list(match.groups() or [])
|
||||
return RegexTestResponse(matched=True, groups=[str(g) for g in groups if g is not None])
|
||||
|
||||
|
||||
async def preview_log(req: LogPreviewRequest) -> LogPreviewResponse:
|
||||
"""Inspect the last lines of a log file and evaluate regex matches.
|
||||
|
||||
Args:
|
||||
req: Log preview request.
|
||||
|
||||
Returns:
|
||||
LogPreviewResponse with lines, total_lines and matched_count, or error.
|
||||
"""
|
||||
try:
|
||||
compiled = re.compile(req.fail_regex)
|
||||
except re.error as exc:
|
||||
return LogPreviewResponse(
|
||||
lines=[],
|
||||
total_lines=0,
|
||||
matched_count=0,
|
||||
regex_error=str(exc),
|
||||
)
|
||||
|
||||
path = Path(req.log_path)
|
||||
if not path.is_file():
|
||||
return LogPreviewResponse(
|
||||
lines=[],
|
||||
total_lines=0,
|
||||
matched_count=0,
|
||||
regex_error=f"File not found: {req.log_path!r}",
|
||||
)
|
||||
|
||||
try:
|
||||
raw_lines = await asyncio.get_event_loop().run_in_executor(
|
||||
None,
|
||||
_read_tail_lines,
|
||||
str(path),
|
||||
req.num_lines,
|
||||
)
|
||||
except OSError as exc:
|
||||
return LogPreviewResponse(
|
||||
lines=[],
|
||||
total_lines=0,
|
||||
matched_count=0,
|
||||
regex_error=f"Cannot read file: {exc}",
|
||||
)
|
||||
|
||||
result_lines: list[LogPreviewLine] = []
|
||||
matched_count = 0
|
||||
for line in raw_lines:
|
||||
m = compiled.search(line)
|
||||
groups = [str(g) for g in (m.groups() or []) if g is not None] if m else []
|
||||
result_lines.append(
|
||||
LogPreviewLine(line=line, matched=(m is not None), groups=groups),
|
||||
)
|
||||
if m:
|
||||
matched_count += 1
|
||||
|
||||
return LogPreviewResponse(
|
||||
lines=result_lines,
|
||||
total_lines=len(result_lines),
|
||||
matched_count=matched_count,
|
||||
)
|
||||
|
||||
|
||||
def _read_tail_lines(file_path: str, num_lines: int) -> list[str]:
|
||||
"""Read the last *num_lines* from *file_path* in a memory-efficient way."""
|
||||
chunk_size = 8192
|
||||
raw_lines: list[bytes] = []
|
||||
with open(file_path, "rb") as fh:
|
||||
fh.seek(0, 2)
|
||||
end_pos = fh.tell()
|
||||
if end_pos == 0:
|
||||
return []
|
||||
|
||||
buf = b""
|
||||
pos = end_pos
|
||||
while len(raw_lines) <= num_lines and pos > 0:
|
||||
read_size = min(chunk_size, pos)
|
||||
pos -= read_size
|
||||
fh.seek(pos)
|
||||
chunk = fh.read(read_size)
|
||||
buf = chunk + buf
|
||||
raw_lines = buf.split(b"\n")
|
||||
|
||||
if pos > 0 and len(raw_lines) > 1:
|
||||
raw_lines = raw_lines[1:]
|
||||
|
||||
return [ln.decode("utf-8", errors="replace").rstrip() for ln in raw_lines[-num_lines:] if ln.strip()]
|
||||
@@ -817,7 +817,7 @@ async def get_parsed_filter_file(config_dir: str, name: str) -> FilterConfig:
|
||||
"""Parse a filter definition file and return its structured representation.
|
||||
|
||||
Reads the raw ``.conf``/``.local`` file from ``filter.d/``, parses it with
|
||||
:func:`~app.services.conffile_parser.parse_filter_file`, and returns the
|
||||
:func:`~app.utils.conffile_parser.parse_filter_file`, and returns the
|
||||
result.
|
||||
|
||||
Args:
|
||||
@@ -831,7 +831,7 @@ async def get_parsed_filter_file(config_dir: str, name: str) -> FilterConfig:
|
||||
ConfigFileNotFoundError: If no matching file is found.
|
||||
ConfigDirError: If *config_dir* does not exist.
|
||||
"""
|
||||
from app.services.conffile_parser import parse_filter_file # avoid circular imports
|
||||
from app.utils.conffile_parser import parse_filter_file # avoid circular imports
|
||||
|
||||
def _do() -> FilterConfig:
|
||||
filter_d = _resolve_subdir(config_dir, "filter.d")
|
||||
@@ -863,7 +863,7 @@ async def update_parsed_filter_file(
|
||||
ConfigFileWriteError: If the file cannot be written.
|
||||
ConfigDirError: If *config_dir* does not exist.
|
||||
"""
|
||||
from app.services.conffile_parser import ( # avoid circular imports
|
||||
from app.utils.conffile_parser import ( # avoid circular imports
|
||||
merge_filter_update,
|
||||
parse_filter_file,
|
||||
serialize_filter_config,
|
||||
@@ -901,7 +901,7 @@ async def get_parsed_action_file(config_dir: str, name: str) -> ActionConfig:
|
||||
ConfigFileNotFoundError: If no matching file is found.
|
||||
ConfigDirError: If *config_dir* does not exist.
|
||||
"""
|
||||
from app.services.conffile_parser import parse_action_file # avoid circular imports
|
||||
from app.utils.conffile_parser import parse_action_file # avoid circular imports
|
||||
|
||||
def _do() -> ActionConfig:
|
||||
action_d = _resolve_subdir(config_dir, "action.d")
|
||||
@@ -930,7 +930,7 @@ async def update_parsed_action_file(
|
||||
ConfigFileWriteError: If the file cannot be written.
|
||||
ConfigDirError: If *config_dir* does not exist.
|
||||
"""
|
||||
from app.services.conffile_parser import ( # avoid circular imports
|
||||
from app.utils.conffile_parser import ( # avoid circular imports
|
||||
merge_action_update,
|
||||
parse_action_file,
|
||||
serialize_action_config,
|
||||
@@ -963,7 +963,7 @@ async def get_parsed_jail_file(config_dir: str, filename: str) -> JailFileConfig
|
||||
ConfigFileNotFoundError: If no matching file is found.
|
||||
ConfigDirError: If *config_dir* does not exist.
|
||||
"""
|
||||
from app.services.conffile_parser import parse_jail_file # avoid circular imports
|
||||
from app.utils.conffile_parser import parse_jail_file # avoid circular imports
|
||||
|
||||
def _do() -> JailFileConfig:
|
||||
jail_d = _resolve_subdir(config_dir, "jail.d")
|
||||
@@ -992,7 +992,7 @@ async def update_parsed_jail_file(
|
||||
ConfigFileWriteError: If the file cannot be written.
|
||||
ConfigDirError: If *config_dir* does not exist.
|
||||
"""
|
||||
from app.services.conffile_parser import ( # avoid circular imports
|
||||
from app.utils.conffile_parser import ( # avoid circular imports
|
||||
merge_jail_file_update,
|
||||
parse_jail_file,
|
||||
serialize_jail_file_config,
|
||||
@@ -10,25 +10,50 @@ HTTP/FastAPI concerns.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
import structlog
|
||||
|
||||
from app.exceptions import ServerOperationError
|
||||
from app.exceptions import ServerOperationError
|
||||
from app.models.server import ServerSettings, ServerSettingsResponse, ServerSettingsUpdate
|
||||
from app.utils.fail2ban_client import Fail2BanClient
|
||||
from app.utils.fail2ban_client import Fail2BanClient, Fail2BanCommand, Fail2BanResponse
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Types
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
type Fail2BanSettingValue = str | int | bool
|
||||
"""Allowed values for server settings commands."""
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
|
||||
_SOCKET_TIMEOUT: float = 10.0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Custom exceptions
|
||||
# ---------------------------------------------------------------------------
|
||||
def _to_int(value: object | None, default: int) -> int:
|
||||
"""Convert a raw value to an int, falling back to a default.
|
||||
|
||||
The fail2ban control socket can return either int or str values for some
|
||||
settings, so we normalise them here in a type-safe way.
|
||||
"""
|
||||
if isinstance(value, int):
|
||||
return value
|
||||
if isinstance(value, float):
|
||||
return int(value)
|
||||
if isinstance(value, str):
|
||||
try:
|
||||
return int(value)
|
||||
except ValueError:
|
||||
return default
|
||||
return default
|
||||
|
||||
|
||||
class ServerOperationError(Exception):
|
||||
"""Raised when a server-level set command fails."""
|
||||
def _to_str(value: object | None, default: str) -> str:
|
||||
"""Convert a raw value to a string, falling back to a default."""
|
||||
if value is None:
|
||||
return default
|
||||
return str(value)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -36,7 +61,7 @@ class ServerOperationError(Exception):
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _ok(response: Any) -> Any:
|
||||
def _ok(response: Fail2BanResponse) -> object:
|
||||
"""Extract payload from a fail2ban ``(code, data)`` response.
|
||||
|
||||
Args:
|
||||
@@ -59,9 +84,9 @@ def _ok(response: Any) -> Any:
|
||||
|
||||
async def _safe_get(
|
||||
client: Fail2BanClient,
|
||||
command: list[Any],
|
||||
default: Any = None,
|
||||
) -> Any:
|
||||
command: Fail2BanCommand,
|
||||
default: object | None = None,
|
||||
) -> object | None:
|
||||
"""Send a command and silently return *default* on any error.
|
||||
|
||||
Args:
|
||||
@@ -73,7 +98,8 @@ async def _safe_get(
|
||||
The successful response, or *default*.
|
||||
"""
|
||||
try:
|
||||
return _ok(await client.send(command))
|
||||
response = await client.send(command)
|
||||
return _ok(cast("Fail2BanResponse", response))
|
||||
except Exception:
|
||||
return default
|
||||
|
||||
@@ -118,17 +144,28 @@ async def get_settings(socket_path: str) -> ServerSettingsResponse:
|
||||
_safe_get(client, ["get", "dbmaxmatches"], 10),
|
||||
)
|
||||
|
||||
log_level = _to_str(log_level_raw, "INFO").upper()
|
||||
log_target = _to_str(log_target_raw, "STDOUT")
|
||||
syslog_socket = _to_str(syslog_socket_raw, "") or None
|
||||
db_path = _to_str(db_path_raw, "/var/lib/fail2ban/fail2ban.sqlite3")
|
||||
db_purge_age = _to_int(db_purge_age_raw, 86400)
|
||||
db_max_matches = _to_int(db_max_matches_raw, 10)
|
||||
|
||||
settings = ServerSettings(
|
||||
log_level=str(log_level_raw or "INFO").upper(),
|
||||
log_target=str(log_target_raw or "STDOUT"),
|
||||
syslog_socket=str(syslog_socket_raw) if syslog_socket_raw else None,
|
||||
db_path=str(db_path_raw or "/var/lib/fail2ban/fail2ban.sqlite3"),
|
||||
db_purge_age=int(db_purge_age_raw or 86400),
|
||||
db_max_matches=int(db_max_matches_raw or 10),
|
||||
log_level=log_level,
|
||||
log_target=log_target,
|
||||
syslog_socket=syslog_socket,
|
||||
db_path=db_path,
|
||||
db_purge_age=db_purge_age,
|
||||
db_max_matches=db_max_matches,
|
||||
)
|
||||
|
||||
log.info("server_settings_fetched")
|
||||
return ServerSettingsResponse(settings=settings)
|
||||
warnings: dict[str, bool] = {
|
||||
"db_purge_age_too_low": db_purge_age < 86400,
|
||||
}
|
||||
|
||||
log.info("server_settings_fetched", db_purge_age=db_purge_age, warnings=warnings)
|
||||
return ServerSettingsResponse(settings=settings, warnings=warnings)
|
||||
|
||||
|
||||
async def update_settings(socket_path: str, update: ServerSettingsUpdate) -> None:
|
||||
@@ -146,9 +183,10 @@ async def update_settings(socket_path: str, update: ServerSettingsUpdate) -> Non
|
||||
"""
|
||||
client = Fail2BanClient(socket_path=socket_path, timeout=_SOCKET_TIMEOUT)
|
||||
|
||||
async def _set(key: str, value: Any) -> None:
|
||||
async def _set(key: str, value: Fail2BanSettingValue) -> None:
|
||||
try:
|
||||
_ok(await client.send(["set", key, value]))
|
||||
response = await client.send(["set", key, value])
|
||||
_ok(cast("Fail2BanResponse", response))
|
||||
except ValueError as exc:
|
||||
raise ServerOperationError(f"Failed to set {key!r} = {value!r}: {exc}") from exc
|
||||
|
||||
@@ -182,7 +220,8 @@ async def flush_logs(socket_path: str) -> str:
|
||||
"""
|
||||
client = Fail2BanClient(socket_path=socket_path, timeout=_SOCKET_TIMEOUT)
|
||||
try:
|
||||
result = _ok(await client.send(["flushlogs"]))
|
||||
response = await client.send(["flushlogs"])
|
||||
result = _ok(cast("Fail2BanResponse", response))
|
||||
log.info("logs_flushed", result=result)
|
||||
return str(result)
|
||||
except ValueError as exc:
|
||||
|
||||
@@ -102,30 +102,20 @@ async def run_setup(
|
||||
log.info("bangui_setup_completed")
|
||||
|
||||
|
||||
from app.utils.setup_utils import (
|
||||
get_map_color_thresholds as util_get_map_color_thresholds,
|
||||
get_password_hash as util_get_password_hash,
|
||||
set_map_color_thresholds as util_set_map_color_thresholds,
|
||||
)
|
||||
|
||||
|
||||
async def get_password_hash(db: aiosqlite.Connection) -> str | None:
|
||||
"""Return the stored bcrypt password hash, or ``None`` if not set.
|
||||
|
||||
Args:
|
||||
db: Active aiosqlite connection.
|
||||
|
||||
Returns:
|
||||
The bcrypt hash string, or ``None``.
|
||||
"""
|
||||
return await settings_repo.get_setting(db, _KEY_PASSWORD_HASH)
|
||||
"""Return the stored bcrypt password hash, or ``None`` if not set."""
|
||||
return await util_get_password_hash(db)
|
||||
|
||||
|
||||
async def get_timezone(db: aiosqlite.Connection) -> str:
|
||||
"""Return the configured IANA timezone string.
|
||||
|
||||
Falls back to ``"UTC"`` when no timezone has been stored (e.g. before
|
||||
setup completes or for legacy databases).
|
||||
|
||||
Args:
|
||||
db: Active aiosqlite connection.
|
||||
|
||||
Returns:
|
||||
An IANA timezone identifier such as ``"Europe/Berlin"`` or ``"UTC"``.
|
||||
"""
|
||||
"""Return the configured IANA timezone string."""
|
||||
tz = await settings_repo.get_setting(db, _KEY_TIMEZONE)
|
||||
return tz if tz else "UTC"
|
||||
|
||||
@@ -133,31 +123,8 @@ async def get_timezone(db: aiosqlite.Connection) -> str:
|
||||
async def get_map_color_thresholds(
|
||||
db: aiosqlite.Connection,
|
||||
) -> tuple[int, int, int]:
|
||||
"""Return the configured map color thresholds (high, medium, low).
|
||||
|
||||
Falls back to default values (100, 50, 20) if not set.
|
||||
|
||||
Args:
|
||||
db: Active aiosqlite connection.
|
||||
|
||||
Returns:
|
||||
A tuple of (threshold_high, threshold_medium, threshold_low).
|
||||
"""
|
||||
high = await settings_repo.get_setting(
|
||||
db, _KEY_MAP_COLOR_THRESHOLD_HIGH
|
||||
)
|
||||
medium = await settings_repo.get_setting(
|
||||
db, _KEY_MAP_COLOR_THRESHOLD_MEDIUM
|
||||
)
|
||||
low = await settings_repo.get_setting(
|
||||
db, _KEY_MAP_COLOR_THRESHOLD_LOW
|
||||
)
|
||||
|
||||
return (
|
||||
int(high) if high else 100,
|
||||
int(medium) if medium else 50,
|
||||
int(low) if low else 20,
|
||||
)
|
||||
"""Return the configured map color thresholds (high, medium, low)."""
|
||||
return await util_get_map_color_thresholds(db)
|
||||
|
||||
|
||||
async def set_map_color_thresholds(
|
||||
@@ -167,31 +134,12 @@ async def set_map_color_thresholds(
|
||||
threshold_medium: int,
|
||||
threshold_low: int,
|
||||
) -> None:
|
||||
"""Update the map color threshold configuration.
|
||||
|
||||
Args:
|
||||
db: Active aiosqlite connection.
|
||||
threshold_high: Ban count for red coloring.
|
||||
threshold_medium: Ban count for yellow coloring.
|
||||
threshold_low: Ban count for green coloring.
|
||||
|
||||
Raises:
|
||||
ValueError: If thresholds are not positive integers or if
|
||||
high <= medium <= low.
|
||||
"""
|
||||
if threshold_high <= 0 or threshold_medium <= 0 or threshold_low <= 0:
|
||||
raise ValueError("All thresholds must be positive integers.")
|
||||
if not (threshold_high > threshold_medium > threshold_low):
|
||||
raise ValueError("Thresholds must satisfy: high > medium > low.")
|
||||
|
||||
await settings_repo.set_setting(
|
||||
db, _KEY_MAP_COLOR_THRESHOLD_HIGH, str(threshold_high)
|
||||
)
|
||||
await settings_repo.set_setting(
|
||||
db, _KEY_MAP_COLOR_THRESHOLD_MEDIUM, str(threshold_medium)
|
||||
)
|
||||
await settings_repo.set_setting(
|
||||
db, _KEY_MAP_COLOR_THRESHOLD_LOW, str(threshold_low)
|
||||
"""Update the map color threshold configuration."""
|
||||
await util_set_map_color_thresholds(
|
||||
db,
|
||||
threshold_high=threshold_high,
|
||||
threshold_medium=threshold_medium,
|
||||
threshold_low=threshold_low,
|
||||
)
|
||||
log.info(
|
||||
"map_color_thresholds_updated",
|
||||
|
||||
@@ -43,9 +43,15 @@ async def _run_import(app: Any) -> None:
|
||||
http_session = app.state.http_session
|
||||
socket_path: str = app.state.settings.fail2ban_socket
|
||||
|
||||
from app.services import jail_service
|
||||
|
||||
log.info("blocklist_import_starting")
|
||||
try:
|
||||
result = await blocklist_service.import_all(db, http_session, socket_path)
|
||||
result = await blocklist_service.import_all(
|
||||
db,
|
||||
http_session,
|
||||
socket_path,
|
||||
)
|
||||
log.info(
|
||||
"blocklist_import_finished",
|
||||
total_imported=result.total_imported,
|
||||
|
||||
@@ -17,7 +17,7 @@ The task runs every 10 minutes. On each invocation it:
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import structlog
|
||||
|
||||
@@ -35,7 +35,7 @@ GEO_RE_RESOLVE_INTERVAL: int = 600
|
||||
JOB_ID: str = "geo_re_resolve"
|
||||
|
||||
|
||||
async def _run_re_resolve(app: Any) -> None:
|
||||
async def _run_re_resolve(app: FastAPI) -> None:
|
||||
"""Query NULL-country IPs from the database and re-resolve them.
|
||||
|
||||
Reads shared resources from ``app.state`` and delegates to
|
||||
@@ -49,12 +49,7 @@ async def _run_re_resolve(app: Any) -> None:
|
||||
http_session = app.state.http_session
|
||||
|
||||
# Fetch all IPs with NULL country_code from the persistent cache.
|
||||
unresolved_ips: list[str] = []
|
||||
async with db.execute(
|
||||
"SELECT ip FROM geo_cache WHERE country_code IS NULL"
|
||||
) as cursor:
|
||||
async for row in cursor:
|
||||
unresolved_ips.append(str(row[0]))
|
||||
unresolved_ips = await geo_service.get_unresolved_ips(db)
|
||||
|
||||
if not unresolved_ips:
|
||||
log.debug("geo_re_resolve_skip", reason="no_unresolved_ips")
|
||||
|
||||
@@ -4,14 +4,25 @@ Registers an APScheduler job that probes the fail2ban socket every 30 seconds
|
||||
and stores the result on ``app.state.server_status``. The dashboard endpoint
|
||||
reads from this cache, keeping HTTP responses fast and the daemon connection
|
||||
decoupled from user-facing requests.
|
||||
|
||||
Crash detection (Task 3)
|
||||
------------------------
|
||||
When a jail activation is performed, the router stores a timestamp on
|
||||
``app.state.last_activation`` (a ``dict`` with ``jail_name`` and ``at``
|
||||
keys). If the health probe subsequently detects an online→offline transition
|
||||
within 60 seconds of that activation, a
|
||||
:class:`~app.models.config.PendingRecovery` record is written to
|
||||
``app.state.pending_recovery`` so the UI can offer a one-click rollback.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any
|
||||
import datetime
|
||||
from typing import TYPE_CHECKING, TypedDict
|
||||
|
||||
import structlog
|
||||
|
||||
from app.models.config import PendingRecovery
|
||||
from app.models.server import ServerStatus
|
||||
from app.services import health_service
|
||||
|
||||
@@ -20,13 +31,30 @@ if TYPE_CHECKING: # pragma: no cover
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
|
||||
|
||||
class ActivationRecord(TypedDict):
|
||||
"""Stored timestamp data for a jail activation event."""
|
||||
|
||||
jail_name: str
|
||||
at: datetime.datetime
|
||||
|
||||
|
||||
#: How often the probe fires (seconds).
|
||||
HEALTH_CHECK_INTERVAL: int = 30
|
||||
|
||||
#: Maximum seconds since an activation for a subsequent crash to be attributed
|
||||
#: to that activation.
|
||||
_ACTIVATION_CRASH_WINDOW: int = 60
|
||||
|
||||
async def _run_probe(app: Any) -> None:
|
||||
|
||||
async def _run_probe(app: FastAPI) -> None:
|
||||
"""Probe fail2ban and cache the result on *app.state*.
|
||||
|
||||
Detects online/offline state transitions. When fail2ban goes offline
|
||||
within :data:`_ACTIVATION_CRASH_WINDOW` seconds of the last jail
|
||||
activation, writes a :class:`~app.models.config.PendingRecovery` record to
|
||||
``app.state.pending_recovery``.
|
||||
|
||||
This is the APScheduler job callback. It reads ``fail2ban_socket`` from
|
||||
``app.state.settings``, runs the health probe, and writes the result to
|
||||
``app.state.server_status``.
|
||||
@@ -42,11 +70,54 @@ async def _run_probe(app: Any) -> None:
|
||||
status: ServerStatus = await health_service.probe(socket_path)
|
||||
app.state.server_status = status
|
||||
|
||||
now = datetime.datetime.now(tz=datetime.UTC)
|
||||
|
||||
# Log transitions between online and offline states.
|
||||
if status.online and not prev_status.online:
|
||||
log.info("fail2ban_came_online", version=status.version)
|
||||
# Clear any pending recovery once fail2ban is back online.
|
||||
existing: PendingRecovery | None = getattr(
|
||||
app.state, "pending_recovery", None
|
||||
)
|
||||
if existing is not None and not existing.recovered:
|
||||
app.state.pending_recovery = PendingRecovery(
|
||||
jail_name=existing.jail_name,
|
||||
activated_at=existing.activated_at,
|
||||
detected_at=existing.detected_at,
|
||||
recovered=True,
|
||||
)
|
||||
log.info(
|
||||
"pending_recovery_resolved",
|
||||
jail=existing.jail_name,
|
||||
)
|
||||
|
||||
elif not status.online and prev_status.online:
|
||||
log.warning("fail2ban_went_offline")
|
||||
# Check whether this crash happened shortly after a jail activation.
|
||||
last_activation: ActivationRecord | None = getattr(
|
||||
app.state, "last_activation", None
|
||||
)
|
||||
if last_activation is not None:
|
||||
activated_at: datetime.datetime = last_activation["at"]
|
||||
seconds_since = (now - activated_at).total_seconds()
|
||||
if seconds_since <= _ACTIVATION_CRASH_WINDOW:
|
||||
jail_name: str = last_activation["jail_name"]
|
||||
# Only create a new record when there is not already an
|
||||
# unresolved one for the same jail.
|
||||
current: PendingRecovery | None = getattr(
|
||||
app.state, "pending_recovery", None
|
||||
)
|
||||
if current is None or current.recovered:
|
||||
app.state.pending_recovery = PendingRecovery(
|
||||
jail_name=jail_name,
|
||||
activated_at=activated_at,
|
||||
detected_at=now,
|
||||
)
|
||||
log.warning(
|
||||
"activation_crash_detected",
|
||||
jail=jail_name,
|
||||
seconds_since_activation=seconds_since,
|
||||
)
|
||||
|
||||
log.debug(
|
||||
"health_check_complete",
|
||||
@@ -71,6 +142,10 @@ def register(app: FastAPI) -> None:
|
||||
# first probe fires.
|
||||
app.state.server_status = ServerStatus(online=False)
|
||||
|
||||
# Initialise activation tracking state.
|
||||
app.state.last_activation = None
|
||||
app.state.pending_recovery = None
|
||||
|
||||
app.state.scheduler.add_job(
|
||||
_run_probe,
|
||||
trigger="interval",
|
||||
|
||||
109
backend/app/tasks/history_sync.py
Normal file
109
backend/app/tasks/history_sync.py
Normal file
@@ -0,0 +1,109 @@
|
||||
"""History sync background task.
|
||||
|
||||
Periodically copies new records from the fail2ban sqlite database into the
|
||||
BanGUI application archive table to prevent gaps when fail2ban purges old rows.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import structlog
|
||||
|
||||
from app.repositories import fail2ban_db_repo
|
||||
from app.utils.fail2ban_db_utils import get_fail2ban_db_path
|
||||
|
||||
if TYPE_CHECKING: # pragma: no cover
|
||||
from fastapi import FastAPI
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
|
||||
#: Stable APScheduler job id.
|
||||
JOB_ID: str = "history_sync"
|
||||
|
||||
#: Interval in seconds between sync runs.
|
||||
HISTORY_SYNC_INTERVAL: int = 300
|
||||
|
||||
#: Backfill window when archive is empty (seconds).
|
||||
BACKFILL_WINDOW: int = 648000
|
||||
|
||||
|
||||
async def _get_last_archive_ts(db) -> int | None:
|
||||
async with db.execute("SELECT MAX(timeofban) FROM history_archive") as cur:
|
||||
row = await cur.fetchone()
|
||||
if row is None or row[0] is None:
|
||||
return None
|
||||
return int(row[0])
|
||||
|
||||
|
||||
async def _run_sync(app: FastAPI) -> None:
|
||||
db = app.state.db
|
||||
socket_path: str = app.state.settings.fail2ban_socket
|
||||
|
||||
try:
|
||||
last_ts = await _get_last_archive_ts(db)
|
||||
now_ts = int(datetime.datetime.now(datetime.UTC).timestamp())
|
||||
|
||||
if last_ts is None:
|
||||
last_ts = now_ts - BACKFILL_WINDOW
|
||||
log.info("history_sync_backfill", window_seconds=BACKFILL_WINDOW)
|
||||
|
||||
per_page = 500
|
||||
next_since = last_ts + 1
|
||||
total_synced = 0
|
||||
|
||||
while True:
|
||||
fail2ban_db_path = await get_fail2ban_db_path(socket_path)
|
||||
rows, total = await fail2ban_db_repo.get_history_page(
|
||||
db_path=fail2ban_db_path,
|
||||
since=next_since,
|
||||
page=1,
|
||||
page_size=per_page,
|
||||
)
|
||||
|
||||
if not rows:
|
||||
break
|
||||
|
||||
from app.repositories.history_archive_repo import archive_ban_event
|
||||
|
||||
for row in rows:
|
||||
await archive_ban_event(
|
||||
db=db,
|
||||
jail=row.jail,
|
||||
ip=row.ip,
|
||||
timeofban=row.timeofban,
|
||||
bancount=row.bancount,
|
||||
data=row.data,
|
||||
action="ban",
|
||||
)
|
||||
total_synced += 1
|
||||
|
||||
# Continue where we left off by max timeofban + 1.
|
||||
max_time = max(row.timeofban for row in rows)
|
||||
next_since = max_time + 1
|
||||
|
||||
if len(rows) < per_page:
|
||||
break
|
||||
|
||||
log.info("history_sync_complete", synced=total_synced)
|
||||
|
||||
except Exception:
|
||||
log.exception("history_sync_failed")
|
||||
|
||||
|
||||
def register(app: FastAPI) -> None:
|
||||
"""Register the history sync periodic job.
|
||||
|
||||
Should be called after scheduler startup, from the lifespan handler.
|
||||
"""
|
||||
app.state.scheduler.add_job(
|
||||
_run_sync,
|
||||
trigger="interval",
|
||||
seconds=HISTORY_SYNC_INTERVAL,
|
||||
kwargs={"app": app},
|
||||
id=JOB_ID,
|
||||
replace_existing=True,
|
||||
next_run_time=datetime.datetime.now(tz=datetime.UTC),
|
||||
)
|
||||
log.info("history_sync_scheduled", interval_seconds=HISTORY_SYNC_INTERVAL)
|
||||
21
backend/app/utils/config_file_utils.py
Normal file
21
backend/app/utils/config_file_utils.py
Normal file
@@ -0,0 +1,21 @@
|
||||
"""Utilities re-exported from config_file_service for cross-module usage."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from app.services.config_file_service import (
|
||||
_build_inactive_jail,
|
||||
_get_active_jail_names,
|
||||
_ordered_config_files,
|
||||
_parse_jails_sync,
|
||||
_validate_jail_config_sync,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"_ordered_config_files",
|
||||
"_parse_jails_sync",
|
||||
"_build_inactive_jail",
|
||||
"_get_active_jail_names",
|
||||
"_validate_jail_config_sync",
|
||||
]
|
||||
@@ -21,14 +21,52 @@ import contextlib
|
||||
import errno
|
||||
import socket
|
||||
import time
|
||||
from collections.abc import Mapping, Sequence, Set
|
||||
from pickle import HIGHEST_PROTOCOL, dumps, loads
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import structlog
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Types
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Use covariant container types so callers can pass ``list[int]`` / ``dict[str, str]``
|
||||
# without needing to cast. At runtime we only accept the basic built-in
|
||||
# containers supported by fail2ban's protocol (list/dict/set) and stringify
|
||||
# anything else.
|
||||
#
|
||||
# NOTE: ``Sequence`` will also accept tuples, but tuples are stringified at
|
||||
# runtime because fail2ban only understands lists.
|
||||
|
||||
type Fail2BanToken = (
|
||||
str
|
||||
| int
|
||||
| float
|
||||
| bool
|
||||
| None
|
||||
| Mapping[str, object]
|
||||
| Sequence[object]
|
||||
| Set[object]
|
||||
)
|
||||
"""A single token in a fail2ban command.
|
||||
|
||||
Fail2ban accepts simple types (str/int/float/bool) plus compound types
|
||||
(list/dict/set). Complex objects are stringified before being sent.
|
||||
"""
|
||||
|
||||
type Fail2BanCommand = Sequence[Fail2BanToken]
|
||||
"""A command sent to fail2ban over the socket.
|
||||
|
||||
Commands are pickle serialised sequences of tokens.
|
||||
"""
|
||||
|
||||
type Fail2BanResponse = tuple[int, object]
|
||||
"""A typical fail2ban response containing a status code and payload."""
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from types import TracebackType
|
||||
|
||||
import structlog
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
|
||||
# fail2ban protocol constants — inline to avoid a hard import dependency
|
||||
@@ -81,9 +119,9 @@ class Fail2BanProtocolError(Exception):
|
||||
|
||||
def _send_command_sync(
|
||||
socket_path: str,
|
||||
command: list[Any],
|
||||
command: Fail2BanCommand,
|
||||
timeout: float,
|
||||
) -> Any:
|
||||
) -> object:
|
||||
"""Send a command to fail2ban and return the parsed response.
|
||||
|
||||
This is a **synchronous** function intended to be called from within
|
||||
@@ -180,7 +218,7 @@ def _send_command_sync(
|
||||
) from last_oserror
|
||||
|
||||
|
||||
def _coerce_command_token(token: Any) -> Any:
|
||||
def _coerce_command_token(token: object) -> Fail2BanToken:
|
||||
"""Coerce a command token to a type that fail2ban understands.
|
||||
|
||||
fail2ban's ``CSocket.convert`` accepts ``str``, ``bool``, ``int``,
|
||||
@@ -229,7 +267,7 @@ class Fail2BanClient:
|
||||
self.socket_path: str = socket_path
|
||||
self.timeout: float = timeout
|
||||
|
||||
async def send(self, command: list[Any]) -> Any:
|
||||
async def send(self, command: Fail2BanCommand) -> object:
|
||||
"""Send a command to fail2ban and return the response.
|
||||
|
||||
Acquires the module-level concurrency semaphore before dispatching
|
||||
@@ -267,13 +305,13 @@ class Fail2BanClient:
|
||||
log.debug("fail2ban_sending_command", command=command)
|
||||
loop: asyncio.AbstractEventLoop = asyncio.get_event_loop()
|
||||
try:
|
||||
response: Any = await loop.run_in_executor(
|
||||
None,
|
||||
_send_command_sync,
|
||||
self.socket_path,
|
||||
command,
|
||||
self.timeout,
|
||||
)
|
||||
response: object = await loop.run_in_executor(
|
||||
None,
|
||||
_send_command_sync,
|
||||
self.socket_path,
|
||||
command,
|
||||
self.timeout,
|
||||
)
|
||||
except Fail2BanConnectionError:
|
||||
log.warning(
|
||||
"fail2ban_connection_error",
|
||||
@@ -300,7 +338,7 @@ class Fail2BanClient:
|
||||
``True`` when the daemon responds correctly, ``False`` otherwise.
|
||||
"""
|
||||
try:
|
||||
response: Any = await self.send(["ping"])
|
||||
response: object = await self.send(["ping"])
|
||||
return bool(response == 1) # fail2ban returns 1 on successful ping
|
||||
except (Fail2BanConnectionError, Fail2BanProtocolError):
|
||||
return False
|
||||
|
||||
63
backend/app/utils/fail2ban_db_utils.py
Normal file
63
backend/app/utils/fail2ban_db_utils.py
Normal file
@@ -0,0 +1,63 @@
|
||||
"""Utilities shared by fail2ban-related services."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from datetime import UTC, datetime
|
||||
|
||||
|
||||
def ts_to_iso(unix_ts: int) -> str:
|
||||
"""Convert a Unix timestamp to an ISO 8601 UTC string."""
|
||||
return datetime.fromtimestamp(unix_ts, tz=UTC).isoformat()
|
||||
|
||||
|
||||
async def get_fail2ban_db_path(socket_path: str) -> str:
|
||||
"""Query fail2ban for the path to its SQLite database file."""
|
||||
from app.utils.fail2ban_client import Fail2BanClient # pragma: no cover
|
||||
|
||||
socket_timeout: float = 5.0
|
||||
|
||||
async with Fail2BanClient(socket_path, timeout=socket_timeout) as client:
|
||||
response = await client.send(["get", "dbfile"])
|
||||
|
||||
if not isinstance(response, tuple) or len(response) != 2:
|
||||
raise RuntimeError(f"Unexpected response from fail2ban: {response!r}")
|
||||
|
||||
code, data = response
|
||||
if code != 0:
|
||||
raise RuntimeError(f"fail2ban error code {code}: {data!r}")
|
||||
|
||||
if data is None:
|
||||
raise RuntimeError("fail2ban has no database configured (dbfile is None)")
|
||||
|
||||
return str(data)
|
||||
|
||||
|
||||
def parse_data_json(raw: object) -> tuple[list[str], int]:
|
||||
"""Extract matches and failure count from the fail2ban bans.data value."""
|
||||
if raw is None:
|
||||
return [], 0
|
||||
|
||||
obj: dict[str, object] = {}
|
||||
if isinstance(raw, str):
|
||||
try:
|
||||
parsed = json.loads(raw)
|
||||
if isinstance(parsed, dict):
|
||||
obj = parsed
|
||||
except json.JSONDecodeError:
|
||||
return [], 0
|
||||
elif isinstance(raw, dict):
|
||||
obj = raw
|
||||
|
||||
raw_matches = obj.get("matches")
|
||||
matches = [str(m) for m in raw_matches] if isinstance(raw_matches, list) else []
|
||||
|
||||
raw_failures = obj.get("failures")
|
||||
failures = 0
|
||||
if isinstance(raw_failures, (int, float, str)):
|
||||
try:
|
||||
failures = int(raw_failures)
|
||||
except (ValueError, TypeError):
|
||||
failures = 0
|
||||
|
||||
return matches, failures
|
||||
93
backend/app/utils/jail_config.py
Normal file
93
backend/app/utils/jail_config.py
Normal file
@@ -0,0 +1,93 @@
|
||||
"""Utilities for ensuring required fail2ban jail configuration files exist.
|
||||
|
||||
BanGUI requires two custom jails — ``manual-Jail`` and ``blocklist-import``
|
||||
— to be present in the fail2ban ``jail.d`` directory. This module provides
|
||||
:func:`ensure_jail_configs` which checks each of the four files
|
||||
(``*.conf`` template + ``*.local`` override) and creates any that are missing
|
||||
with the correct default content.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import structlog
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pathlib import Path
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Default file contents
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_MANUAL_JAIL_CONF = """\
|
||||
[manual-Jail]
|
||||
|
||||
enabled = false
|
||||
filter = manual-Jail
|
||||
logpath = /remotelogs/bangui/auth.log
|
||||
backend = polling
|
||||
maxretry = 3
|
||||
findtime = 120
|
||||
bantime = 60
|
||||
ignoreip = 127.0.0.0/8 ::1 172.16.0.0/12
|
||||
"""
|
||||
|
||||
_MANUAL_JAIL_LOCAL = """\
|
||||
[manual-Jail]
|
||||
enabled = true
|
||||
"""
|
||||
|
||||
_BLOCKLIST_IMPORT_CONF = """\
|
||||
[blocklist-import]
|
||||
|
||||
enabled = false
|
||||
filter =
|
||||
logpath = /dev/null
|
||||
backend = auto
|
||||
maxretry = 1
|
||||
findtime = 1d
|
||||
bantime = 86400
|
||||
ignoreip = 127.0.0.0/8 ::1 172.16.0.0/12
|
||||
"""
|
||||
|
||||
_BLOCKLIST_IMPORT_LOCAL = """\
|
||||
[blocklist-import]
|
||||
enabled = true
|
||||
"""
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# File registry: (filename, default_content)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_JAIL_FILES: list[tuple[str, str]] = [
|
||||
("manual-Jail.conf", _MANUAL_JAIL_CONF),
|
||||
("manual-Jail.local", _MANUAL_JAIL_LOCAL),
|
||||
("blocklist-import.conf", _BLOCKLIST_IMPORT_CONF),
|
||||
("blocklist-import.local", _BLOCKLIST_IMPORT_LOCAL),
|
||||
]
|
||||
|
||||
|
||||
def ensure_jail_configs(jail_d_path: Path) -> None:
|
||||
"""Ensure the required fail2ban jail configuration files exist.
|
||||
|
||||
Checks for ``manual-Jail.conf``, ``manual-Jail.local``,
|
||||
``blocklist-import.conf``, and ``blocklist-import.local`` inside
|
||||
*jail_d_path*. Any file that is missing is created with its default
|
||||
content. Existing files are **never** overwritten.
|
||||
|
||||
Args:
|
||||
jail_d_path: Path to the fail2ban ``jail.d`` directory. Will be
|
||||
created (including all parents) if it does not already exist.
|
||||
"""
|
||||
jail_d_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
for filename, default_content in _JAIL_FILES:
|
||||
file_path = jail_d_path / filename
|
||||
if file_path.exists():
|
||||
log.debug("jail_config_already_exists", path=str(file_path))
|
||||
else:
|
||||
file_path.write_text(default_content, encoding="utf-8")
|
||||
log.info("jail_config_created", path=str(file_path))
|
||||
20
backend/app/utils/jail_utils.py
Normal file
20
backend/app/utils/jail_utils.py
Normal file
@@ -0,0 +1,20 @@
|
||||
"""Jail helpers to decouple service layer dependencies."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
from app.services.jail_service import reload_all
|
||||
|
||||
|
||||
async def reload_jails(
|
||||
socket_path: str,
|
||||
include_jails: Sequence[str] | None = None,
|
||||
exclude_jails: Sequence[str] | None = None,
|
||||
) -> None:
|
||||
"""Reload fail2ban jails using shared jail service helper."""
|
||||
await reload_all(
|
||||
socket_path,
|
||||
include_jails=list(include_jails) if include_jails is not None else None,
|
||||
exclude_jails=list(exclude_jails) if exclude_jails is not None else None,
|
||||
)
|
||||
14
backend/app/utils/log_utils.py
Normal file
14
backend/app/utils/log_utils.py
Normal file
@@ -0,0 +1,14 @@
|
||||
"""Log-related helpers to avoid direct service-to-service imports."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.models.config import LogPreviewRequest, LogPreviewResponse, RegexTestRequest, RegexTestResponse
|
||||
from app.services.log_service import preview_log as _preview_log, test_regex as _test_regex
|
||||
|
||||
|
||||
async def preview_log(req: LogPreviewRequest) -> LogPreviewResponse:
|
||||
return await _preview_log(req)
|
||||
|
||||
|
||||
def test_regex(req: RegexTestRequest) -> RegexTestResponse:
|
||||
return _test_regex(req)
|
||||
47
backend/app/utils/setup_utils.py
Normal file
47
backend/app/utils/setup_utils.py
Normal file
@@ -0,0 +1,47 @@
|
||||
"""Setup-related utilities shared by multiple services."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.repositories import settings_repo
|
||||
|
||||
_KEY_PASSWORD_HASH = "master_password_hash"
|
||||
_KEY_SETUP_DONE = "setup_completed"
|
||||
_KEY_MAP_COLOR_THRESHOLD_HIGH = "map_color_threshold_high"
|
||||
_KEY_MAP_COLOR_THRESHOLD_MEDIUM = "map_color_threshold_medium"
|
||||
_KEY_MAP_COLOR_THRESHOLD_LOW = "map_color_threshold_low"
|
||||
|
||||
|
||||
async def get_password_hash(db):
|
||||
"""Return the stored master password hash or None."""
|
||||
return await settings_repo.get_setting(db, _KEY_PASSWORD_HASH)
|
||||
|
||||
|
||||
async def get_map_color_thresholds(db):
|
||||
"""Return map color thresholds as tuple (high, medium, low)."""
|
||||
high = await settings_repo.get_setting(db, _KEY_MAP_COLOR_THRESHOLD_HIGH)
|
||||
medium = await settings_repo.get_setting(db, _KEY_MAP_COLOR_THRESHOLD_MEDIUM)
|
||||
low = await settings_repo.get_setting(db, _KEY_MAP_COLOR_THRESHOLD_LOW)
|
||||
|
||||
return (
|
||||
int(high) if high else 100,
|
||||
int(medium) if medium else 50,
|
||||
int(low) if low else 20,
|
||||
)
|
||||
|
||||
|
||||
async def set_map_color_thresholds(
|
||||
db,
|
||||
*,
|
||||
threshold_high: int,
|
||||
threshold_medium: int,
|
||||
threshold_low: int,
|
||||
) -> None:
|
||||
"""Persist map color thresholds after validating values."""
|
||||
if threshold_high <= 0 or threshold_medium <= 0 or threshold_low <= 0:
|
||||
raise ValueError("All thresholds must be positive integers.")
|
||||
if not (threshold_high > threshold_medium > threshold_low):
|
||||
raise ValueError("Thresholds must satisfy: high > medium > low.")
|
||||
|
||||
await settings_repo.set_setting(db, _KEY_MAP_COLOR_THRESHOLD_HIGH, str(threshold_high))
|
||||
await settings_repo.set_setting(db, _KEY_MAP_COLOR_THRESHOLD_MEDIUM, str(threshold_medium))
|
||||
await settings_repo.set_setting(db, _KEY_MAP_COLOR_THRESHOLD_LOW, str(threshold_low))
|
||||
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
|
||||
|
||||
[project]
|
||||
name = "bangui-backend"
|
||||
version = "0.1.0"
|
||||
version = "0.9.18"
|
||||
description = "BanGUI backend — fail2ban web management interface"
|
||||
requires-python = ">=3.12"
|
||||
dependencies = [
|
||||
@@ -60,4 +60,5 @@ plugins = ["pydantic.mypy"]
|
||||
asyncio_mode = "auto"
|
||||
pythonpath = [".", "../fail2ban-master"]
|
||||
testpaths = ["tests"]
|
||||
addopts = "--cov=app --cov-report=term-missing"
|
||||
addopts = "--asyncio-mode=auto --cov=app --cov-report=term-missing"
|
||||
filterwarnings = ["ignore::pytest.PytestRemovedIn9Warning"]
|
||||
|
||||
@@ -37,9 +37,15 @@ def test_settings(tmp_path: Path) -> Settings:
|
||||
Returns:
|
||||
A :class:`~app.config.Settings` instance with overridden paths.
|
||||
"""
|
||||
config_dir = tmp_path / "fail2ban"
|
||||
(config_dir / "jail.d").mkdir(parents=True)
|
||||
(config_dir / "filter.d").mkdir(parents=True)
|
||||
(config_dir / "action.d").mkdir(parents=True)
|
||||
|
||||
return Settings(
|
||||
database_path=str(tmp_path / "test_bangui.db"),
|
||||
fail2ban_socket="/tmp/fake_fail2ban.sock",
|
||||
fail2ban_config_dir=str(config_dir),
|
||||
session_secret="test-secret-key-do-not-use-in-production",
|
||||
session_duration_minutes=60,
|
||||
timezone="UTC",
|
||||
|
||||
276
backend/tests/test_regression_500s.py
Normal file
276
backend/tests/test_regression_500s.py
Normal file
@@ -0,0 +1,276 @@
|
||||
"""Regression tests for the four 500-error bugs discovered on 2026-03-22.
|
||||
|
||||
Each test targets the exact code path that caused a 500 Internal Server Error.
|
||||
These tests call the **real** service/repository functions (not the router)
|
||||
so they fail even if the route layer is mocked in router-level tests.
|
||||
|
||||
Bugs covered:
|
||||
1. ``list_history`` rejected the ``origin`` keyword argument (TypeError).
|
||||
2. ``jail_config_service`` used ``_get_active_jail_names`` without importing it.
|
||||
3. ``filter_config_service`` used ``_parse_jails_sync`` / ``_get_active_jail_names``
|
||||
without importing them.
|
||||
4. ``config_service.get_service_status`` omitted the required ``bangui_version``
|
||||
field from the ``ServiceStatusResponse`` constructor (Pydantic ValidationError).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
import json
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import aiosqlite
|
||||
import pytest
|
||||
|
||||
# ── Bug 1 ─────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestHistoryOriginParameter:
|
||||
"""Bug 1: ``origin`` parameter must be threaded through service → repo."""
|
||||
|
||||
# -- Service layer --
|
||||
|
||||
async def test_list_history_accepts_origin_kwarg(self) -> None:
|
||||
"""``history_service.list_history()`` must accept an ``origin`` keyword."""
|
||||
from app.services import history_service
|
||||
|
||||
sig = inspect.signature(history_service.list_history)
|
||||
assert "origin" in sig.parameters, (
|
||||
"list_history() is missing the 'origin' parameter — "
|
||||
"the router passes origin=… which would cause a TypeError"
|
||||
)
|
||||
|
||||
async def test_list_history_forwards_origin_to_repo(
|
||||
self, tmp_path: Path
|
||||
) -> None:
|
||||
"""``list_history(origin='blocklist')`` must forward origin to the DB repo."""
|
||||
from app.services import history_service
|
||||
|
||||
db_path = str(tmp_path / "f2b.db")
|
||||
async with aiosqlite.connect(db_path) as db:
|
||||
await db.execute(
|
||||
"CREATE TABLE jails (name TEXT, enabled INTEGER DEFAULT 1)"
|
||||
)
|
||||
await db.execute(
|
||||
"CREATE TABLE bans "
|
||||
"(jail TEXT, ip TEXT, timeofban INTEGER, bantime INTEGER, "
|
||||
"bancount INTEGER DEFAULT 1, data JSON)"
|
||||
)
|
||||
await db.execute(
|
||||
"INSERT INTO bans VALUES (?, ?, ?, ?, ?, ?)",
|
||||
("blocklist-import", "10.0.0.1", int(time.time()), 3600, 1, "{}"),
|
||||
)
|
||||
await db.execute(
|
||||
"INSERT INTO bans VALUES (?, ?, ?, ?, ?, ?)",
|
||||
("sshd", "10.0.0.2", int(time.time()), 3600, 1, "{}"),
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
with patch(
|
||||
"app.services.history_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=db_path),
|
||||
):
|
||||
result = await history_service.list_history(
|
||||
"fake_socket", origin="blocklist"
|
||||
)
|
||||
|
||||
assert all(
|
||||
item.jail == "blocklist-import" for item in result.items
|
||||
), "origin='blocklist' must filter to blocklist-import jail only"
|
||||
|
||||
# -- Repository layer --
|
||||
|
||||
async def test_get_history_page_accepts_origin_kwarg(self) -> None:
|
||||
"""``fail2ban_db_repo.get_history_page()`` must accept ``origin``."""
|
||||
from app.repositories import fail2ban_db_repo
|
||||
|
||||
sig = inspect.signature(fail2ban_db_repo.get_history_page)
|
||||
assert "origin" in sig.parameters, (
|
||||
"get_history_page() is missing the 'origin' parameter"
|
||||
)
|
||||
|
||||
async def test_get_history_page_filters_by_origin(
|
||||
self, tmp_path: Path
|
||||
) -> None:
|
||||
"""``get_history_page(origin='selfblock')`` excludes blocklist-import."""
|
||||
from app.repositories import fail2ban_db_repo
|
||||
|
||||
db_path = str(tmp_path / "f2b.db")
|
||||
async with aiosqlite.connect(db_path) as db:
|
||||
await db.execute(
|
||||
"CREATE TABLE bans "
|
||||
"(jail TEXT, ip TEXT, timeofban INTEGER, bancount INTEGER, data TEXT)"
|
||||
)
|
||||
await db.executemany(
|
||||
"INSERT INTO bans VALUES (?, ?, ?, ?, ?)",
|
||||
[
|
||||
("blocklist-import", "10.0.0.1", 100, 1, "{}"),
|
||||
("sshd", "10.0.0.2", 200, 1, "{}"),
|
||||
("sshd", "10.0.0.3", 300, 1, "{}"),
|
||||
],
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
rows, total = await fail2ban_db_repo.get_history_page(
|
||||
db_path=db_path, origin="selfblock"
|
||||
)
|
||||
|
||||
assert total == 2
|
||||
assert all(r.jail != "blocklist-import" for r in rows)
|
||||
|
||||
|
||||
# ── Bug 2 ─────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestJailConfigImports:
|
||||
"""Bug 2: ``jail_config_service`` must import ``_get_active_jail_names``."""
|
||||
|
||||
async def test_get_active_jail_names_is_importable(self) -> None:
|
||||
"""The module must successfully import ``_get_active_jail_names``."""
|
||||
import app.services.jail_config_service as mod
|
||||
|
||||
assert hasattr(mod, "_get_active_jail_names") or callable(
|
||||
getattr(mod, "_get_active_jail_names", None)
|
||||
), (
|
||||
"_get_active_jail_names is not available in jail_config_service — "
|
||||
"any call site will raise NameError → 500"
|
||||
)
|
||||
|
||||
async def test_list_inactive_jails_does_not_raise_name_error(
|
||||
self, tmp_path: Path
|
||||
) -> None:
|
||||
"""``list_inactive_jails`` must not crash with NameError."""
|
||||
from app.services import jail_config_service
|
||||
|
||||
config_dir = str(tmp_path / "fail2ban")
|
||||
Path(config_dir).mkdir()
|
||||
(Path(config_dir) / "jail.conf").write_text("[DEFAULT]\n")
|
||||
|
||||
with patch(
|
||||
"app.services.jail_config_service._get_active_jail_names",
|
||||
new=AsyncMock(return_value=set()),
|
||||
):
|
||||
result = await jail_config_service.list_inactive_jails(
|
||||
config_dir, "/fake/socket"
|
||||
)
|
||||
|
||||
assert result.total >= 0
|
||||
|
||||
|
||||
# ── Bug 3 ─────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestFilterConfigImports:
|
||||
"""Bug 3: ``filter_config_service`` must import ``_parse_jails_sync``
|
||||
and ``_get_active_jail_names``."""
|
||||
|
||||
async def test_parse_jails_sync_is_available(self) -> None:
|
||||
"""``_parse_jails_sync`` must be resolvable at module scope."""
|
||||
import app.services.filter_config_service as mod
|
||||
|
||||
assert hasattr(mod, "_parse_jails_sync"), (
|
||||
"_parse_jails_sync is not available in filter_config_service — "
|
||||
"list_filters() will raise NameError → 500"
|
||||
)
|
||||
|
||||
async def test_get_active_jail_names_is_available(self) -> None:
|
||||
"""``_get_active_jail_names`` must be resolvable at module scope."""
|
||||
import app.services.filter_config_service as mod
|
||||
|
||||
assert hasattr(mod, "_get_active_jail_names"), (
|
||||
"_get_active_jail_names is not available in filter_config_service — "
|
||||
"list_filters() will raise NameError → 500"
|
||||
)
|
||||
|
||||
async def test_list_filters_does_not_raise_name_error(
|
||||
self, tmp_path: Path
|
||||
) -> None:
|
||||
"""``list_filters`` must not crash with NameError."""
|
||||
from app.services import filter_config_service
|
||||
|
||||
config_dir = str(tmp_path / "fail2ban")
|
||||
filter_d = Path(config_dir) / "filter.d"
|
||||
filter_d.mkdir(parents=True)
|
||||
|
||||
# Create a minimal filter file so _parse_filters_sync has something to scan.
|
||||
(filter_d / "sshd.conf").write_text(
|
||||
"[Definition]\nfailregex = ^Failed password\n"
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"app.services.filter_config_service._parse_jails_sync",
|
||||
return_value=({}, {}),
|
||||
),
|
||||
patch(
|
||||
"app.services.filter_config_service._get_active_jail_names",
|
||||
new=AsyncMock(return_value=set()),
|
||||
),
|
||||
):
|
||||
result = await filter_config_service.list_filters(
|
||||
config_dir, "/fake/socket"
|
||||
)
|
||||
|
||||
assert result.total >= 0
|
||||
|
||||
|
||||
# ── Bug 4 ─────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestServiceStatusBanguiVersion:
|
||||
"""Bug 4: ``get_service_status`` must include application version
|
||||
in the ``version`` field of the ``ServiceStatusResponse``."""
|
||||
|
||||
async def test_online_response_contains_bangui_version(self) -> None:
|
||||
"""The returned model must contain the ``bangui_version`` field."""
|
||||
from app.models.server import ServerStatus
|
||||
from app.services import config_service
|
||||
import app
|
||||
|
||||
online_status = ServerStatus(
|
||||
online=True,
|
||||
version="1.0.0",
|
||||
active_jails=2,
|
||||
total_bans=5,
|
||||
total_failures=3,
|
||||
)
|
||||
|
||||
async def _send(command: list[Any]) -> Any:
|
||||
key = "|".join(str(c) for c in command)
|
||||
if key == "get|loglevel":
|
||||
return (0, "INFO")
|
||||
if key == "get|logtarget":
|
||||
return (0, "/var/log/fail2ban.log")
|
||||
return (0, None)
|
||||
|
||||
class _FakeClient:
|
||||
def __init__(self, **_kw: Any) -> None:
|
||||
self.send = AsyncMock(side_effect=_send)
|
||||
|
||||
with patch("app.services.config_service.Fail2BanClient", _FakeClient):
|
||||
result = await config_service.get_service_status(
|
||||
"/fake/socket",
|
||||
probe_fn=AsyncMock(return_value=online_status),
|
||||
)
|
||||
|
||||
assert result.version == app.__version__, (
|
||||
"ServiceStatusResponse must expose BanGUI version in version field"
|
||||
)
|
||||
|
||||
async def test_offline_response_contains_bangui_version(self) -> None:
|
||||
"""Even when fail2ban is offline, ``bangui_version`` must be present."""
|
||||
from app.models.server import ServerStatus
|
||||
from app.services import config_service
|
||||
import app
|
||||
|
||||
offline_status = ServerStatus(online=False)
|
||||
|
||||
result = await config_service.get_service_status(
|
||||
"/fake/socket",
|
||||
probe_fn=AsyncMock(return_value=offline_status),
|
||||
)
|
||||
|
||||
assert result.version == app.__version__
|
||||
167
backend/tests/test_repositories/test_fail2ban_db_repo.py
Normal file
167
backend/tests/test_repositories/test_fail2ban_db_repo.py
Normal file
@@ -0,0 +1,167 @@
|
||||
"""Tests for the fail2ban_db repository.
|
||||
|
||||
These tests use an in-memory sqlite file created under pytest's tmp_path and
|
||||
exercise the core query functions used by the services.
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import aiosqlite
|
||||
import pytest
|
||||
|
||||
from app.repositories import fail2ban_db_repo
|
||||
|
||||
|
||||
async def _create_bans_table(db: aiosqlite.Connection) -> None:
|
||||
await db.execute(
|
||||
"""
|
||||
CREATE TABLE bans (
|
||||
jail TEXT,
|
||||
ip TEXT,
|
||||
timeofban INTEGER,
|
||||
bancount INTEGER,
|
||||
data TEXT
|
||||
)
|
||||
"""
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_db_nonempty_returns_false_when_table_is_empty(tmp_path: Path) -> None:
|
||||
db_path = str(tmp_path / "fail2ban.db")
|
||||
async with aiosqlite.connect(db_path) as db:
|
||||
await _create_bans_table(db)
|
||||
|
||||
assert await fail2ban_db_repo.check_db_nonempty(db_path) is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_db_nonempty_returns_true_when_row_exists(tmp_path: Path) -> None:
|
||||
db_path = str(tmp_path / "fail2ban.db")
|
||||
async with aiosqlite.connect(db_path) as db:
|
||||
await _create_bans_table(db)
|
||||
await db.execute(
|
||||
"INSERT INTO bans (jail, ip, timeofban, bancount, data) VALUES (?, ?, ?, ?, ?)",
|
||||
("jail1", "1.2.3.4", 123, 1, "{}"),
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
assert await fail2ban_db_repo.check_db_nonempty(db_path) is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_currently_banned_filters_and_pagination(tmp_path: Path) -> None:
|
||||
db_path = str(tmp_path / "fail2ban.db")
|
||||
async with aiosqlite.connect(db_path) as db:
|
||||
await _create_bans_table(db)
|
||||
# Three bans; one is from the blocklist-import jail.
|
||||
await db.executemany(
|
||||
"INSERT INTO bans (jail, ip, timeofban, bancount, data) VALUES (?, ?, ?, ?, ?)",
|
||||
[
|
||||
("jail1", "1.1.1.1", 10, 1, "{}"),
|
||||
("blocklist-import", "2.2.2.2", 20, 2, "{}"),
|
||||
("jail1", "3.3.3.3", 30, 3, "{}"),
|
||||
],
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
records, total = await fail2ban_db_repo.get_currently_banned(
|
||||
db_path=db_path,
|
||||
since=15,
|
||||
origin="selfblock",
|
||||
limit=10,
|
||||
offset=0,
|
||||
)
|
||||
|
||||
# Only the non-blocklist row with timeofban >= 15 should remain.
|
||||
assert total == 1
|
||||
assert len(records) == 1
|
||||
assert records[0].ip == "3.3.3.3"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_ban_counts_by_bucket_ignores_out_of_range_buckets(tmp_path: Path) -> None:
|
||||
db_path = str(tmp_path / "fail2ban.db")
|
||||
async with aiosqlite.connect(db_path) as db:
|
||||
await _create_bans_table(db)
|
||||
await db.executemany(
|
||||
"INSERT INTO bans (jail, ip, timeofban, bancount, data) VALUES (?, ?, ?, ?, ?)",
|
||||
[
|
||||
("jail1", "1.1.1.1", 5, 1, "{}"),
|
||||
("jail1", "2.2.2.2", 15, 1, "{}"),
|
||||
("jail1", "3.3.3.3", 35, 1, "{}"),
|
||||
],
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
counts = await fail2ban_db_repo.get_ban_counts_by_bucket(
|
||||
db_path=db_path,
|
||||
since=0,
|
||||
bucket_secs=10,
|
||||
num_buckets=3,
|
||||
)
|
||||
|
||||
assert counts == [1, 1, 0]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_history_page_and_for_ip(tmp_path: Path) -> None:
|
||||
db_path = str(tmp_path / "fail2ban.db")
|
||||
async with aiosqlite.connect(db_path) as db:
|
||||
await _create_bans_table(db)
|
||||
await db.executemany(
|
||||
"INSERT INTO bans (jail, ip, timeofban, bancount, data) VALUES (?, ?, ?, ?, ?)",
|
||||
[
|
||||
("jail1", "1.1.1.1", 100, 1, "{}"),
|
||||
("jail1", "1.1.1.1", 200, 2, "{}"),
|
||||
("jail1", "2.2.2.2", 300, 3, "{}"),
|
||||
],
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
page, total = await fail2ban_db_repo.get_history_page(
|
||||
db_path=db_path,
|
||||
since=None,
|
||||
jail="jail1",
|
||||
ip_filter="1.1.1",
|
||||
page=1,
|
||||
page_size=10,
|
||||
)
|
||||
|
||||
assert total == 2
|
||||
assert len(page) == 2
|
||||
assert page[0].ip == "1.1.1.1"
|
||||
|
||||
history_for_ip = await fail2ban_db_repo.get_history_for_ip(db_path=db_path, ip="2.2.2.2")
|
||||
assert len(history_for_ip) == 1
|
||||
assert history_for_ip[0].ip == "2.2.2.2"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_history_page_origin_filter(tmp_path: Path) -> None:
|
||||
db_path = str(tmp_path / "fail2ban.db")
|
||||
async with aiosqlite.connect(db_path) as db:
|
||||
await _create_bans_table(db)
|
||||
await db.executemany(
|
||||
"INSERT INTO bans (jail, ip, timeofban, bancount, data) VALUES (?, ?, ?, ?, ?)",
|
||||
[
|
||||
("jail1", "1.1.1.1", 100, 1, "{}"),
|
||||
("blocklist-import", "2.2.2.2", 200, 1, "{}"),
|
||||
],
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
page, total = await fail2ban_db_repo.get_history_page(
|
||||
db_path=db_path,
|
||||
since=None,
|
||||
jail=None,
|
||||
ip_filter=None,
|
||||
origin="selfblock",
|
||||
page=1,
|
||||
page_size=10,
|
||||
)
|
||||
|
||||
assert total == 1
|
||||
assert len(page) == 1
|
||||
assert page[0].ip == "1.1.1.1"
|
||||
140
backend/tests/test_repositories/test_geo_cache_repo.py
Normal file
140
backend/tests/test_repositories/test_geo_cache_repo.py
Normal file
@@ -0,0 +1,140 @@
|
||||
"""Tests for the geo cache repository."""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import aiosqlite
|
||||
import pytest
|
||||
|
||||
from app.repositories import geo_cache_repo
|
||||
|
||||
|
||||
async def _create_geo_cache_table(db: aiosqlite.Connection) -> None:
|
||||
await db.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS geo_cache (
|
||||
ip TEXT PRIMARY KEY,
|
||||
country_code TEXT,
|
||||
country_name TEXT,
|
||||
asn TEXT,
|
||||
org TEXT,
|
||||
cached_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%fZ', 'now'))
|
||||
)
|
||||
"""
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_unresolved_ips_returns_empty_when_none_exist(tmp_path: Path) -> None:
|
||||
db_path = str(tmp_path / "geo_cache.db")
|
||||
async with aiosqlite.connect(db_path) as db:
|
||||
await _create_geo_cache_table(db)
|
||||
await db.execute(
|
||||
"INSERT INTO geo_cache (ip, country_code, country_name, asn, org) VALUES (?, ?, ?, ?, ?)",
|
||||
("1.1.1.1", "DE", "Germany", "AS123", "Test"),
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
async with aiosqlite.connect(db_path) as db:
|
||||
ips = await geo_cache_repo.get_unresolved_ips(db)
|
||||
|
||||
assert ips == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_unresolved_ips_returns_pending_ips(tmp_path: Path) -> None:
|
||||
db_path = str(tmp_path / "geo_cache.db")
|
||||
async with aiosqlite.connect(db_path) as db:
|
||||
await _create_geo_cache_table(db)
|
||||
await db.executemany(
|
||||
"INSERT INTO geo_cache (ip, country_code) VALUES (?, ?)",
|
||||
[
|
||||
("2.2.2.2", None),
|
||||
("3.3.3.3", None),
|
||||
("4.4.4.4", "US"),
|
||||
],
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
async with aiosqlite.connect(db_path) as db:
|
||||
ips = await geo_cache_repo.get_unresolved_ips(db)
|
||||
|
||||
assert sorted(ips) == ["2.2.2.2", "3.3.3.3"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_all_and_count_unresolved(tmp_path: Path) -> None:
|
||||
db_path = str(tmp_path / "geo_cache.db")
|
||||
async with aiosqlite.connect(db_path) as db:
|
||||
await _create_geo_cache_table(db)
|
||||
await db.executemany(
|
||||
"INSERT INTO geo_cache (ip, country_code, country_name, asn, org) VALUES (?, ?, ?, ?, ?)",
|
||||
[
|
||||
("5.5.5.5", None, None, None, None),
|
||||
("6.6.6.6", "FR", "France", "AS456", "TestOrg"),
|
||||
],
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
async with aiosqlite.connect(db_path) as db:
|
||||
rows = await geo_cache_repo.load_all(db)
|
||||
unresolved = await geo_cache_repo.count_unresolved(db)
|
||||
|
||||
assert unresolved == 1
|
||||
assert any(row["ip"] == "6.6.6.6" and row["country_code"] == "FR" for row in rows)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_upsert_entry_and_neg_entry(tmp_path: Path) -> None:
|
||||
db_path = str(tmp_path / "geo_cache.db")
|
||||
async with aiosqlite.connect(db_path) as db:
|
||||
await _create_geo_cache_table(db)
|
||||
|
||||
await geo_cache_repo.upsert_entry(
|
||||
db,
|
||||
"7.7.7.7",
|
||||
"GB",
|
||||
"United Kingdom",
|
||||
"AS789",
|
||||
"TestOrg",
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
await geo_cache_repo.upsert_neg_entry(db, "8.8.8.8")
|
||||
await db.commit()
|
||||
|
||||
# Ensure positive entry is present.
|
||||
async with db.execute("SELECT country_code FROM geo_cache WHERE ip = ?", ("7.7.7.7",)) as cur:
|
||||
row = await cur.fetchone()
|
||||
assert row is not None
|
||||
assert row[0] == "GB"
|
||||
|
||||
# Ensure negative entry exists with NULL country_code.
|
||||
async with db.execute("SELECT country_code FROM geo_cache WHERE ip = ?", ("8.8.8.8",)) as cur:
|
||||
row = await cur.fetchone()
|
||||
assert row is not None
|
||||
assert row[0] is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bulk_upsert_entries_and_neg_entries(tmp_path: Path) -> None:
|
||||
db_path = str(tmp_path / "geo_cache.db")
|
||||
async with aiosqlite.connect(db_path) as db:
|
||||
await _create_geo_cache_table(db)
|
||||
|
||||
rows = [
|
||||
("9.9.9.9", "NL", "Netherlands", "AS101", "Test"),
|
||||
("10.10.10.10", "JP", "Japan", "AS102", "Test"),
|
||||
]
|
||||
count = await geo_cache_repo.bulk_upsert_entries(db, rows)
|
||||
assert count == 2
|
||||
|
||||
neg_count = await geo_cache_repo.bulk_upsert_neg_entries(db, ["11.11.11.11", "12.12.12.12"])
|
||||
assert neg_count == 2
|
||||
|
||||
await db.commit()
|
||||
|
||||
async with db.execute("SELECT COUNT(*) FROM geo_cache") as cur:
|
||||
row = await cur.fetchone()
|
||||
assert row is not None
|
||||
assert int(row[0]) == 4
|
||||
60
backend/tests/test_repositories/test_history_archive_repo.py
Normal file
60
backend/tests/test_repositories/test_history_archive_repo.py
Normal file
@@ -0,0 +1,60 @@
|
||||
"""Tests for history_archive_repo."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import aiosqlite
|
||||
import pytest
|
||||
|
||||
from app.db import init_db
|
||||
from app.repositories.history_archive_repo import archive_ban_event, get_archived_history, purge_archived_history
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def app_db(tmp_path: Path) -> str:
|
||||
path = str(tmp_path / "app.db")
|
||||
async with aiosqlite.connect(path) as db:
|
||||
db.row_factory = aiosqlite.Row
|
||||
await init_db(db)
|
||||
return path
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_archive_ban_event_deduplication(app_db: str) -> None:
|
||||
async with aiosqlite.connect(app_db) as db:
|
||||
# first insert should add
|
||||
inserted = await archive_ban_event(db, "sshd", "1.1.1.1", 1000, 1, "{}", "ban")
|
||||
assert inserted
|
||||
|
||||
# duplicate event is ignored
|
||||
inserted = await archive_ban_event(db, "sshd", "1.1.1.1", 1000, 1, "{}", "ban")
|
||||
assert not inserted
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_archived_history_filtering_and_pagination(app_db: str) -> None:
|
||||
async with aiosqlite.connect(app_db) as db:
|
||||
await archive_ban_event(db, "sshd", "1.1.1.1", 1000, 1, "{}", "ban")
|
||||
await archive_ban_event(db, "nginx", "2.2.2.2", 2000, 1, "{}", "ban")
|
||||
|
||||
rows, total = await get_archived_history(db, jail="sshd")
|
||||
assert total == 1
|
||||
assert rows[0]["ip"] == "1.1.1.1"
|
||||
|
||||
rows, total = await get_archived_history(db, page=1, page_size=1)
|
||||
assert total == 2
|
||||
assert len(rows) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_purge_archived_history(app_db: str) -> None:
|
||||
now = int(time.time())
|
||||
async with aiosqlite.connect(app_db) as db:
|
||||
await archive_ban_event(db, "sshd", "1.1.1.1", now - 3000, 1, "{}", "ban")
|
||||
await archive_ban_event(db, "sshd", "1.1.1.2", now - 1000, 1, "{}", "ban")
|
||||
deleted = await purge_archived_history(db, age_seconds=2000)
|
||||
assert deleted == 1
|
||||
rows, total = await get_archived_history(db)
|
||||
assert total == 1
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Generator
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
@@ -157,12 +158,12 @@ class TestRequireAuthSessionCache:
|
||||
"""In-memory session token cache inside ``require_auth``."""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_cache(self) -> None: # type: ignore[misc]
|
||||
def reset_cache(self) -> Generator[None, None, None]:
|
||||
"""Flush the session cache before and after every test in this class."""
|
||||
from app import dependencies
|
||||
|
||||
dependencies.clear_session_cache()
|
||||
yield # type: ignore[misc]
|
||||
yield
|
||||
dependencies.clear_session_cache()
|
||||
|
||||
async def test_second_request_skips_db(self, client: AsyncClient) -> None:
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -9,6 +9,8 @@ import aiosqlite
|
||||
import pytest
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
|
||||
import app
|
||||
|
||||
from app.config import Settings
|
||||
from app.db import init_db
|
||||
from app.main import create_app
|
||||
@@ -151,6 +153,7 @@ class TestDashboardStatus:
|
||||
body = response.json()
|
||||
|
||||
assert "status" in body
|
||||
|
||||
status = body["status"]
|
||||
assert "online" in status
|
||||
assert "version" in status
|
||||
@@ -163,10 +166,11 @@ class TestDashboardStatus:
|
||||
) -> None:
|
||||
"""Endpoint returns the exact values from ``app.state.server_status``."""
|
||||
response = await dashboard_client.get("/api/dashboard/status")
|
||||
status = response.json()["status"]
|
||||
body = response.json()
|
||||
status = body["status"]
|
||||
|
||||
assert status["online"] is True
|
||||
assert status["version"] == "1.0.2"
|
||||
assert status["version"] == app.__version__
|
||||
assert status["active_jails"] == 2
|
||||
assert status["total_bans"] == 10
|
||||
assert status["total_failures"] == 5
|
||||
@@ -177,10 +181,11 @@ class TestDashboardStatus:
|
||||
"""Endpoint returns online=False when the cache holds an offline snapshot."""
|
||||
response = await offline_dashboard_client.get("/api/dashboard/status")
|
||||
assert response.status_code == 200
|
||||
status = response.json()["status"]
|
||||
body = response.json()
|
||||
status = body["status"]
|
||||
|
||||
assert status["online"] is False
|
||||
assert status["version"] is None
|
||||
assert status["version"] == app.__version__
|
||||
assert status["active_jails"] == 0
|
||||
assert status["total_bans"] == 0
|
||||
assert status["total_failures"] == 0
|
||||
@@ -285,6 +290,17 @@ class TestDashboardBans:
|
||||
called_range = mock_list.call_args[0][1]
|
||||
assert called_range == "7d"
|
||||
|
||||
async def test_accepts_source_param(
|
||||
self, dashboard_client: AsyncClient
|
||||
) -> None:
|
||||
"""The ``source`` query parameter is forwarded to ban_service."""
|
||||
mock_list = AsyncMock(return_value=_make_ban_list_response())
|
||||
with patch("app.routers.dashboard.ban_service.list_bans", new=mock_list):
|
||||
await dashboard_client.get("/api/dashboard/bans?source=archive")
|
||||
|
||||
called_source = mock_list.call_args[1]["source"]
|
||||
assert called_source == "archive"
|
||||
|
||||
async def test_empty_ban_list_returns_zero_total(
|
||||
self, dashboard_client: AsyncClient
|
||||
) -> None:
|
||||
@@ -412,6 +428,15 @@ class TestBansByCountry:
|
||||
called_range = mock_fn.call_args[0][1]
|
||||
assert called_range == "7d"
|
||||
|
||||
async def test_invalid_source_returns_422(
|
||||
self, dashboard_client: AsyncClient
|
||||
) -> None:
|
||||
"""An invalid source value returns HTTP 422."""
|
||||
response = await dashboard_client.get(
|
||||
"/api/dashboard/bans/by-country?source=invalid"
|
||||
)
|
||||
assert response.status_code == 422
|
||||
|
||||
async def test_empty_window_returns_empty_response(
|
||||
self, dashboard_client: AsyncClient
|
||||
) -> None:
|
||||
@@ -487,6 +512,16 @@ class TestDashboardBansOriginField:
|
||||
origins = {ban["origin"] for ban in bans}
|
||||
assert origins == {"blocklist", "selfblock"}
|
||||
|
||||
async def test_bans_by_country_source_param_forwarded(
|
||||
self, dashboard_client: AsyncClient
|
||||
) -> None:
|
||||
"""The ``source`` query parameter is forwarded to bans_by_country."""
|
||||
mock_fn = AsyncMock(return_value=_make_bans_by_country_response())
|
||||
with patch("app.routers.dashboard.ban_service.bans_by_country", new=mock_fn):
|
||||
await dashboard_client.get("/api/dashboard/bans/by-country?source=archive")
|
||||
|
||||
assert mock_fn.call_args[1]["source"] == "archive"
|
||||
|
||||
async def test_blocklist_origin_serialised_correctly(
|
||||
self, dashboard_client: AsyncClient
|
||||
) -> None:
|
||||
@@ -696,6 +731,15 @@ class TestBanTrend:
|
||||
)
|
||||
assert response.status_code == 422
|
||||
|
||||
async def test_invalid_source_returns_422(
|
||||
self, dashboard_client: AsyncClient
|
||||
) -> None:
|
||||
"""An invalid source value returns HTTP 422."""
|
||||
response = await dashboard_client.get(
|
||||
"/api/dashboard/bans/trend?source=invalid"
|
||||
)
|
||||
assert response.status_code == 422
|
||||
|
||||
async def test_empty_buckets_response(self, dashboard_client: AsyncClient) -> None:
|
||||
"""Empty bucket list is serialised correctly."""
|
||||
from app.models.ban import BanTrendResponse
|
||||
@@ -831,6 +875,15 @@ class TestBansByJail:
|
||||
)
|
||||
assert response.status_code == 422
|
||||
|
||||
async def test_invalid_source_returns_422(
|
||||
self, dashboard_client: AsyncClient
|
||||
) -> None:
|
||||
"""An invalid source value returns HTTP 422."""
|
||||
response = await dashboard_client.get(
|
||||
"/api/dashboard/bans/by-jail?source=invalid"
|
||||
)
|
||||
assert response.status_code == 422
|
||||
|
||||
async def test_empty_jails_response(self, dashboard_client: AsyncClient) -> None:
|
||||
"""Empty jails list is serialised correctly."""
|
||||
from app.models.ban import BansByJailResponse
|
||||
|
||||
@@ -26,7 +26,7 @@ from app.models.file_config import (
|
||||
JailConfigFileContent,
|
||||
JailConfigFilesResponse,
|
||||
)
|
||||
from app.services.file_config_service import (
|
||||
from app.services.raw_config_io_service import (
|
||||
ConfigDirError,
|
||||
ConfigFileExistsError,
|
||||
ConfigFileNameError,
|
||||
@@ -112,7 +112,7 @@ class TestListJailConfigFiles:
|
||||
self, file_config_client: AsyncClient
|
||||
) -> None:
|
||||
with patch(
|
||||
"app.routers.file_config.file_config_service.list_jail_config_files",
|
||||
"app.routers.file_config.raw_config_io_service.list_jail_config_files",
|
||||
AsyncMock(return_value=_jail_files_resp()),
|
||||
):
|
||||
resp = await file_config_client.get("/api/config/jail-files")
|
||||
@@ -126,7 +126,7 @@ class TestListJailConfigFiles:
|
||||
self, file_config_client: AsyncClient
|
||||
) -> None:
|
||||
with patch(
|
||||
"app.routers.file_config.file_config_service.list_jail_config_files",
|
||||
"app.routers.file_config.raw_config_io_service.list_jail_config_files",
|
||||
AsyncMock(side_effect=ConfigDirError("not found")),
|
||||
):
|
||||
resp = await file_config_client.get("/api/config/jail-files")
|
||||
@@ -157,7 +157,7 @@ class TestGetJailConfigFile:
|
||||
content="[sshd]\nenabled = true\n",
|
||||
)
|
||||
with patch(
|
||||
"app.routers.file_config.file_config_service.get_jail_config_file",
|
||||
"app.routers.file_config.raw_config_io_service.get_jail_config_file",
|
||||
AsyncMock(return_value=content),
|
||||
):
|
||||
resp = await file_config_client.get("/api/config/jail-files/sshd.conf")
|
||||
@@ -167,7 +167,7 @@ class TestGetJailConfigFile:
|
||||
|
||||
async def test_404_not_found(self, file_config_client: AsyncClient) -> None:
|
||||
with patch(
|
||||
"app.routers.file_config.file_config_service.get_jail_config_file",
|
||||
"app.routers.file_config.raw_config_io_service.get_jail_config_file",
|
||||
AsyncMock(side_effect=ConfigFileNotFoundError("missing.conf")),
|
||||
):
|
||||
resp = await file_config_client.get("/api/config/jail-files/missing.conf")
|
||||
@@ -178,7 +178,7 @@ class TestGetJailConfigFile:
|
||||
self, file_config_client: AsyncClient
|
||||
) -> None:
|
||||
with patch(
|
||||
"app.routers.file_config.file_config_service.get_jail_config_file",
|
||||
"app.routers.file_config.raw_config_io_service.get_jail_config_file",
|
||||
AsyncMock(side_effect=ConfigFileNameError("bad name")),
|
||||
):
|
||||
resp = await file_config_client.get("/api/config/jail-files/bad.txt")
|
||||
@@ -194,7 +194,7 @@ class TestGetJailConfigFile:
|
||||
class TestSetJailConfigEnabled:
|
||||
async def test_204_on_success(self, file_config_client: AsyncClient) -> None:
|
||||
with patch(
|
||||
"app.routers.file_config.file_config_service.set_jail_config_enabled",
|
||||
"app.routers.file_config.raw_config_io_service.set_jail_config_enabled",
|
||||
AsyncMock(return_value=None),
|
||||
):
|
||||
resp = await file_config_client.put(
|
||||
@@ -206,7 +206,7 @@ class TestSetJailConfigEnabled:
|
||||
|
||||
async def test_404_file_not_found(self, file_config_client: AsyncClient) -> None:
|
||||
with patch(
|
||||
"app.routers.file_config.file_config_service.set_jail_config_enabled",
|
||||
"app.routers.file_config.raw_config_io_service.set_jail_config_enabled",
|
||||
AsyncMock(side_effect=ConfigFileNotFoundError("missing.conf")),
|
||||
):
|
||||
resp = await file_config_client.put(
|
||||
@@ -232,7 +232,7 @@ class TestGetFilterFileRaw:
|
||||
|
||||
async def test_200_returns_content(self, file_config_client: AsyncClient) -> None:
|
||||
with patch(
|
||||
"app.routers.file_config.file_config_service.get_filter_file",
|
||||
"app.routers.file_config.raw_config_io_service.get_filter_file",
|
||||
AsyncMock(return_value=_conf_file_content("nginx")),
|
||||
):
|
||||
resp = await file_config_client.get("/api/config/filters/nginx/raw")
|
||||
@@ -242,7 +242,7 @@ class TestGetFilterFileRaw:
|
||||
|
||||
async def test_404_not_found(self, file_config_client: AsyncClient) -> None:
|
||||
with patch(
|
||||
"app.routers.file_config.file_config_service.get_filter_file",
|
||||
"app.routers.file_config.raw_config_io_service.get_filter_file",
|
||||
AsyncMock(side_effect=ConfigFileNotFoundError("missing")),
|
||||
):
|
||||
resp = await file_config_client.get("/api/config/filters/missing/raw")
|
||||
@@ -258,7 +258,7 @@ class TestGetFilterFileRaw:
|
||||
class TestUpdateFilterFile:
|
||||
async def test_204_on_success(self, file_config_client: AsyncClient) -> None:
|
||||
with patch(
|
||||
"app.routers.file_config.file_config_service.write_filter_file",
|
||||
"app.routers.file_config.raw_config_io_service.write_filter_file",
|
||||
AsyncMock(return_value=None),
|
||||
):
|
||||
resp = await file_config_client.put(
|
||||
@@ -270,7 +270,7 @@ class TestUpdateFilterFile:
|
||||
|
||||
async def test_400_write_error(self, file_config_client: AsyncClient) -> None:
|
||||
with patch(
|
||||
"app.routers.file_config.file_config_service.write_filter_file",
|
||||
"app.routers.file_config.raw_config_io_service.write_filter_file",
|
||||
AsyncMock(side_effect=ConfigFileWriteError("disk full")),
|
||||
):
|
||||
resp = await file_config_client.put(
|
||||
@@ -289,7 +289,7 @@ class TestUpdateFilterFile:
|
||||
class TestCreateFilterFile:
|
||||
async def test_201_creates_file(self, file_config_client: AsyncClient) -> None:
|
||||
with patch(
|
||||
"app.routers.file_config.file_config_service.create_filter_file",
|
||||
"app.routers.file_config.raw_config_io_service.create_filter_file",
|
||||
AsyncMock(return_value="myfilter.conf"),
|
||||
):
|
||||
resp = await file_config_client.post(
|
||||
@@ -302,7 +302,7 @@ class TestCreateFilterFile:
|
||||
|
||||
async def test_409_conflict(self, file_config_client: AsyncClient) -> None:
|
||||
with patch(
|
||||
"app.routers.file_config.file_config_service.create_filter_file",
|
||||
"app.routers.file_config.raw_config_io_service.create_filter_file",
|
||||
AsyncMock(side_effect=ConfigFileExistsError("myfilter.conf")),
|
||||
):
|
||||
resp = await file_config_client.post(
|
||||
@@ -314,7 +314,7 @@ class TestCreateFilterFile:
|
||||
|
||||
async def test_400_invalid_name(self, file_config_client: AsyncClient) -> None:
|
||||
with patch(
|
||||
"app.routers.file_config.file_config_service.create_filter_file",
|
||||
"app.routers.file_config.raw_config_io_service.create_filter_file",
|
||||
AsyncMock(side_effect=ConfigFileNameError("bad/../name")),
|
||||
):
|
||||
resp = await file_config_client.post(
|
||||
@@ -327,41 +327,150 @@ class TestCreateFilterFile:
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /api/config/actions (smoke test — same logic as filters)
|
||||
# Note: GET /api/config/actions is handled by config.router (registered first);
|
||||
# file_config.router's "/actions" endpoint is shadowed by it.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestListActionFiles:
|
||||
async def test_200_returns_files(self, file_config_client: AsyncClient) -> None:
|
||||
action_entry = ConfFileEntry(name="iptables", filename="iptables.conf")
|
||||
resp_data = ConfFilesResponse(files=[action_entry], total=1)
|
||||
from app.models.config import ActionListResponse
|
||||
|
||||
mock_action = ActionConfig(
|
||||
name="iptables",
|
||||
filename="iptables.conf",
|
||||
)
|
||||
resp_data = ActionListResponse(actions=[mock_action], total=1)
|
||||
with patch(
|
||||
"app.routers.file_config.file_config_service.list_action_files",
|
||||
"app.routers.config.action_config_service.list_actions",
|
||||
AsyncMock(return_value=resp_data),
|
||||
):
|
||||
resp = await file_config_client.get("/api/config/actions")
|
||||
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["files"][0]["filename"] == "iptables.conf"
|
||||
assert resp.json()["actions"][0]["name"] == "iptables"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# POST /api/config/actions
|
||||
# Note: POST /api/config/actions is also handled by config.router.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCreateActionFile:
|
||||
async def test_201_creates_file(self, file_config_client: AsyncClient) -> None:
|
||||
created = ActionConfig(
|
||||
name="myaction",
|
||||
filename="myaction.local",
|
||||
actionban="echo ban <ip>",
|
||||
)
|
||||
with patch(
|
||||
"app.routers.file_config.file_config_service.create_action_file",
|
||||
AsyncMock(return_value="myaction.conf"),
|
||||
"app.routers.config.action_config_service.create_action",
|
||||
AsyncMock(return_value=created),
|
||||
):
|
||||
resp = await file_config_client.post(
|
||||
"/api/config/actions",
|
||||
json={"name": "myaction", "content": "[Definition]\n"},
|
||||
json={"name": "myaction", "actionban": "echo ban <ip>"},
|
||||
)
|
||||
|
||||
assert resp.status_code == 201
|
||||
assert resp.json()["filename"] == "myaction.conf"
|
||||
assert resp.json()["name"] == "myaction"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /api/config/actions/{name}/raw
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestGetActionFileRaw:
|
||||
"""Tests for ``GET /api/config/actions/{name}/raw``."""
|
||||
|
||||
async def test_200_returns_content(self, file_config_client: AsyncClient) -> None:
|
||||
with patch(
|
||||
"app.routers.file_config.raw_config_io_service.get_action_file",
|
||||
AsyncMock(return_value=_conf_file_content("iptables")),
|
||||
):
|
||||
resp = await file_config_client.get("/api/config/actions/iptables/raw")
|
||||
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["name"] == "iptables"
|
||||
|
||||
async def test_404_not_found(self, file_config_client: AsyncClient) -> None:
|
||||
with patch(
|
||||
"app.routers.file_config.raw_config_io_service.get_action_file",
|
||||
AsyncMock(side_effect=ConfigFileNotFoundError("missing")),
|
||||
):
|
||||
resp = await file_config_client.get("/api/config/actions/missing/raw")
|
||||
|
||||
assert resp.status_code == 404
|
||||
|
||||
async def test_503_on_config_dir_error(
|
||||
self, file_config_client: AsyncClient
|
||||
) -> None:
|
||||
with patch(
|
||||
"app.routers.file_config.raw_config_io_service.get_action_file",
|
||||
AsyncMock(side_effect=ConfigDirError("no dir")),
|
||||
):
|
||||
resp = await file_config_client.get("/api/config/actions/iptables/raw")
|
||||
|
||||
assert resp.status_code == 503
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# PUT /api/config/actions/{name}/raw
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestUpdateActionFileRaw:
|
||||
"""Tests for ``PUT /api/config/actions/{name}/raw``."""
|
||||
|
||||
async def test_204_on_success(self, file_config_client: AsyncClient) -> None:
|
||||
with patch(
|
||||
"app.routers.file_config.raw_config_io_service.write_action_file",
|
||||
AsyncMock(return_value=None),
|
||||
):
|
||||
resp = await file_config_client.put(
|
||||
"/api/config/actions/iptables/raw",
|
||||
json={"content": "[Definition]\nactionban = iptables -I INPUT -s <ip> -j DROP\n"},
|
||||
)
|
||||
|
||||
assert resp.status_code == 204
|
||||
|
||||
async def test_400_write_error(self, file_config_client: AsyncClient) -> None:
|
||||
with patch(
|
||||
"app.routers.file_config.raw_config_io_service.write_action_file",
|
||||
AsyncMock(side_effect=ConfigFileWriteError("disk full")),
|
||||
):
|
||||
resp = await file_config_client.put(
|
||||
"/api/config/actions/iptables/raw",
|
||||
json={"content": "x"},
|
||||
)
|
||||
|
||||
assert resp.status_code == 400
|
||||
|
||||
async def test_404_not_found(self, file_config_client: AsyncClient) -> None:
|
||||
with patch(
|
||||
"app.routers.file_config.raw_config_io_service.write_action_file",
|
||||
AsyncMock(side_effect=ConfigFileNotFoundError("missing")),
|
||||
):
|
||||
resp = await file_config_client.put(
|
||||
"/api/config/actions/missing/raw",
|
||||
json={"content": "x"},
|
||||
)
|
||||
|
||||
assert resp.status_code == 404
|
||||
|
||||
async def test_400_invalid_name(self, file_config_client: AsyncClient) -> None:
|
||||
with patch(
|
||||
"app.routers.file_config.raw_config_io_service.write_action_file",
|
||||
AsyncMock(side_effect=ConfigFileNameError("bad/../name")),
|
||||
):
|
||||
resp = await file_config_client.put(
|
||||
"/api/config/actions/escape/raw",
|
||||
json={"content": "x"},
|
||||
)
|
||||
|
||||
assert resp.status_code == 400
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -372,7 +481,7 @@ class TestCreateActionFile:
|
||||
class TestCreateJailConfigFile:
|
||||
async def test_201_creates_file(self, file_config_client: AsyncClient) -> None:
|
||||
with patch(
|
||||
"app.routers.file_config.file_config_service.create_jail_config_file",
|
||||
"app.routers.file_config.raw_config_io_service.create_jail_config_file",
|
||||
AsyncMock(return_value="myjail.conf"),
|
||||
):
|
||||
resp = await file_config_client.post(
|
||||
@@ -385,7 +494,7 @@ class TestCreateJailConfigFile:
|
||||
|
||||
async def test_409_conflict(self, file_config_client: AsyncClient) -> None:
|
||||
with patch(
|
||||
"app.routers.file_config.file_config_service.create_jail_config_file",
|
||||
"app.routers.file_config.raw_config_io_service.create_jail_config_file",
|
||||
AsyncMock(side_effect=ConfigFileExistsError("myjail.conf")),
|
||||
):
|
||||
resp = await file_config_client.post(
|
||||
@@ -397,7 +506,7 @@ class TestCreateJailConfigFile:
|
||||
|
||||
async def test_400_invalid_name(self, file_config_client: AsyncClient) -> None:
|
||||
with patch(
|
||||
"app.routers.file_config.file_config_service.create_jail_config_file",
|
||||
"app.routers.file_config.raw_config_io_service.create_jail_config_file",
|
||||
AsyncMock(side_effect=ConfigFileNameError("bad/../name")),
|
||||
):
|
||||
resp = await file_config_client.post(
|
||||
@@ -411,7 +520,7 @@ class TestCreateJailConfigFile:
|
||||
self, file_config_client: AsyncClient
|
||||
) -> None:
|
||||
with patch(
|
||||
"app.routers.file_config.file_config_service.create_jail_config_file",
|
||||
"app.routers.file_config.raw_config_io_service.create_jail_config_file",
|
||||
AsyncMock(side_effect=ConfigDirError("no dir")),
|
||||
):
|
||||
resp = await file_config_client.post(
|
||||
@@ -433,7 +542,7 @@ class TestGetParsedFilter:
|
||||
) -> None:
|
||||
cfg = FilterConfig(name="nginx", filename="nginx.conf")
|
||||
with patch(
|
||||
"app.routers.file_config.file_config_service.get_parsed_filter_file",
|
||||
"app.routers.file_config.raw_config_io_service.get_parsed_filter_file",
|
||||
AsyncMock(return_value=cfg),
|
||||
):
|
||||
resp = await file_config_client.get("/api/config/filters/nginx/parsed")
|
||||
@@ -445,7 +554,7 @@ class TestGetParsedFilter:
|
||||
|
||||
async def test_404_not_found(self, file_config_client: AsyncClient) -> None:
|
||||
with patch(
|
||||
"app.routers.file_config.file_config_service.get_parsed_filter_file",
|
||||
"app.routers.file_config.raw_config_io_service.get_parsed_filter_file",
|
||||
AsyncMock(side_effect=ConfigFileNotFoundError("missing")),
|
||||
):
|
||||
resp = await file_config_client.get(
|
||||
@@ -458,7 +567,7 @@ class TestGetParsedFilter:
|
||||
self, file_config_client: AsyncClient
|
||||
) -> None:
|
||||
with patch(
|
||||
"app.routers.file_config.file_config_service.get_parsed_filter_file",
|
||||
"app.routers.file_config.raw_config_io_service.get_parsed_filter_file",
|
||||
AsyncMock(side_effect=ConfigDirError("no dir")),
|
||||
):
|
||||
resp = await file_config_client.get("/api/config/filters/nginx/parsed")
|
||||
@@ -474,7 +583,7 @@ class TestGetParsedFilter:
|
||||
class TestUpdateParsedFilter:
|
||||
async def test_204_on_success(self, file_config_client: AsyncClient) -> None:
|
||||
with patch(
|
||||
"app.routers.file_config.file_config_service.update_parsed_filter_file",
|
||||
"app.routers.file_config.raw_config_io_service.update_parsed_filter_file",
|
||||
AsyncMock(return_value=None),
|
||||
):
|
||||
resp = await file_config_client.put(
|
||||
@@ -486,7 +595,7 @@ class TestUpdateParsedFilter:
|
||||
|
||||
async def test_404_not_found(self, file_config_client: AsyncClient) -> None:
|
||||
with patch(
|
||||
"app.routers.file_config.file_config_service.update_parsed_filter_file",
|
||||
"app.routers.file_config.raw_config_io_service.update_parsed_filter_file",
|
||||
AsyncMock(side_effect=ConfigFileNotFoundError("missing")),
|
||||
):
|
||||
resp = await file_config_client.put(
|
||||
@@ -498,7 +607,7 @@ class TestUpdateParsedFilter:
|
||||
|
||||
async def test_400_write_error(self, file_config_client: AsyncClient) -> None:
|
||||
with patch(
|
||||
"app.routers.file_config.file_config_service.update_parsed_filter_file",
|
||||
"app.routers.file_config.raw_config_io_service.update_parsed_filter_file",
|
||||
AsyncMock(side_effect=ConfigFileWriteError("disk full")),
|
||||
):
|
||||
resp = await file_config_client.put(
|
||||
@@ -520,7 +629,7 @@ class TestGetParsedAction:
|
||||
) -> None:
|
||||
cfg = ActionConfig(name="iptables", filename="iptables.conf")
|
||||
with patch(
|
||||
"app.routers.file_config.file_config_service.get_parsed_action_file",
|
||||
"app.routers.file_config.raw_config_io_service.get_parsed_action_file",
|
||||
AsyncMock(return_value=cfg),
|
||||
):
|
||||
resp = await file_config_client.get(
|
||||
@@ -534,7 +643,7 @@ class TestGetParsedAction:
|
||||
|
||||
async def test_404_not_found(self, file_config_client: AsyncClient) -> None:
|
||||
with patch(
|
||||
"app.routers.file_config.file_config_service.get_parsed_action_file",
|
||||
"app.routers.file_config.raw_config_io_service.get_parsed_action_file",
|
||||
AsyncMock(side_effect=ConfigFileNotFoundError("missing")),
|
||||
):
|
||||
resp = await file_config_client.get(
|
||||
@@ -547,7 +656,7 @@ class TestGetParsedAction:
|
||||
self, file_config_client: AsyncClient
|
||||
) -> None:
|
||||
with patch(
|
||||
"app.routers.file_config.file_config_service.get_parsed_action_file",
|
||||
"app.routers.file_config.raw_config_io_service.get_parsed_action_file",
|
||||
AsyncMock(side_effect=ConfigDirError("no dir")),
|
||||
):
|
||||
resp = await file_config_client.get(
|
||||
@@ -565,7 +674,7 @@ class TestGetParsedAction:
|
||||
class TestUpdateParsedAction:
|
||||
async def test_204_on_success(self, file_config_client: AsyncClient) -> None:
|
||||
with patch(
|
||||
"app.routers.file_config.file_config_service.update_parsed_action_file",
|
||||
"app.routers.file_config.raw_config_io_service.update_parsed_action_file",
|
||||
AsyncMock(return_value=None),
|
||||
):
|
||||
resp = await file_config_client.put(
|
||||
@@ -577,7 +686,7 @@ class TestUpdateParsedAction:
|
||||
|
||||
async def test_404_not_found(self, file_config_client: AsyncClient) -> None:
|
||||
with patch(
|
||||
"app.routers.file_config.file_config_service.update_parsed_action_file",
|
||||
"app.routers.file_config.raw_config_io_service.update_parsed_action_file",
|
||||
AsyncMock(side_effect=ConfigFileNotFoundError("missing")),
|
||||
):
|
||||
resp = await file_config_client.put(
|
||||
@@ -589,7 +698,7 @@ class TestUpdateParsedAction:
|
||||
|
||||
async def test_400_write_error(self, file_config_client: AsyncClient) -> None:
|
||||
with patch(
|
||||
"app.routers.file_config.file_config_service.update_parsed_action_file",
|
||||
"app.routers.file_config.raw_config_io_service.update_parsed_action_file",
|
||||
AsyncMock(side_effect=ConfigFileWriteError("disk full")),
|
||||
):
|
||||
resp = await file_config_client.put(
|
||||
@@ -612,7 +721,7 @@ class TestGetParsedJailFile:
|
||||
section = JailSectionConfig(enabled=True, port="ssh")
|
||||
cfg = JailFileConfig(filename="sshd.conf", jails={"sshd": section})
|
||||
with patch(
|
||||
"app.routers.file_config.file_config_service.get_parsed_jail_file",
|
||||
"app.routers.file_config.raw_config_io_service.get_parsed_jail_file",
|
||||
AsyncMock(return_value=cfg),
|
||||
):
|
||||
resp = await file_config_client.get(
|
||||
@@ -626,7 +735,7 @@ class TestGetParsedJailFile:
|
||||
|
||||
async def test_404_not_found(self, file_config_client: AsyncClient) -> None:
|
||||
with patch(
|
||||
"app.routers.file_config.file_config_service.get_parsed_jail_file",
|
||||
"app.routers.file_config.raw_config_io_service.get_parsed_jail_file",
|
||||
AsyncMock(side_effect=ConfigFileNotFoundError("missing.conf")),
|
||||
):
|
||||
resp = await file_config_client.get(
|
||||
@@ -639,7 +748,7 @@ class TestGetParsedJailFile:
|
||||
self, file_config_client: AsyncClient
|
||||
) -> None:
|
||||
with patch(
|
||||
"app.routers.file_config.file_config_service.get_parsed_jail_file",
|
||||
"app.routers.file_config.raw_config_io_service.get_parsed_jail_file",
|
||||
AsyncMock(side_effect=ConfigDirError("no dir")),
|
||||
):
|
||||
resp = await file_config_client.get(
|
||||
@@ -657,7 +766,7 @@ class TestGetParsedJailFile:
|
||||
class TestUpdateParsedJailFile:
|
||||
async def test_204_on_success(self, file_config_client: AsyncClient) -> None:
|
||||
with patch(
|
||||
"app.routers.file_config.file_config_service.update_parsed_jail_file",
|
||||
"app.routers.file_config.raw_config_io_service.update_parsed_jail_file",
|
||||
AsyncMock(return_value=None),
|
||||
):
|
||||
resp = await file_config_client.put(
|
||||
@@ -669,7 +778,7 @@ class TestUpdateParsedJailFile:
|
||||
|
||||
async def test_404_not_found(self, file_config_client: AsyncClient) -> None:
|
||||
with patch(
|
||||
"app.routers.file_config.file_config_service.update_parsed_jail_file",
|
||||
"app.routers.file_config.raw_config_io_service.update_parsed_jail_file",
|
||||
AsyncMock(side_effect=ConfigFileNotFoundError("missing.conf")),
|
||||
):
|
||||
resp = await file_config_client.put(
|
||||
@@ -681,7 +790,7 @@ class TestUpdateParsedJailFile:
|
||||
|
||||
async def test_400_write_error(self, file_config_client: AsyncClient) -> None:
|
||||
with patch(
|
||||
"app.routers.file_config.file_config_service.update_parsed_jail_file",
|
||||
"app.routers.file_config.raw_config_io_service.update_parsed_jail_file",
|
||||
AsyncMock(side_effect=ConfigFileWriteError("disk full")),
|
||||
):
|
||||
resp = await file_config_client.put(
|
||||
|
||||
@@ -12,7 +12,7 @@ from httpx import ASGITransport, AsyncClient
|
||||
from app.config import Settings
|
||||
from app.db import init_db
|
||||
from app.main import create_app
|
||||
from app.services.geo_service import GeoInfo
|
||||
from app.models.geo import GeoInfo
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
@@ -70,7 +70,7 @@ class TestGeoLookup:
|
||||
async def test_200_with_geo_info(self, geo_client: AsyncClient) -> None:
|
||||
"""GET /api/geo/lookup/{ip} returns 200 with enriched result."""
|
||||
geo = GeoInfo(country_code="DE", country_name="Germany", asn="12345", org="Acme")
|
||||
result = {
|
||||
result: dict[str, object] = {
|
||||
"ip": "1.2.3.4",
|
||||
"currently_banned_in": ["sshd"],
|
||||
"geo": geo,
|
||||
@@ -92,7 +92,7 @@ class TestGeoLookup:
|
||||
|
||||
async def test_200_when_not_banned(self, geo_client: AsyncClient) -> None:
|
||||
"""GET /api/geo/lookup/{ip} returns empty list when IP is not banned anywhere."""
|
||||
result = {
|
||||
result: dict[str, object] = {
|
||||
"ip": "8.8.8.8",
|
||||
"currently_banned_in": [],
|
||||
"geo": GeoInfo(country_code="US", country_name="United States", asn=None, org=None),
|
||||
@@ -108,7 +108,7 @@ class TestGeoLookup:
|
||||
|
||||
async def test_200_with_no_geo(self, geo_client: AsyncClient) -> None:
|
||||
"""GET /api/geo/lookup/{ip} returns null geo when enricher fails."""
|
||||
result = {
|
||||
result: dict[str, object] = {
|
||||
"ip": "1.2.3.4",
|
||||
"currently_banned_in": [],
|
||||
"geo": None,
|
||||
@@ -144,7 +144,7 @@ class TestGeoLookup:
|
||||
|
||||
async def test_ipv6_address(self, geo_client: AsyncClient) -> None:
|
||||
"""GET /api/geo/lookup/{ip} handles IPv6 addresses."""
|
||||
result = {
|
||||
result: dict[str, object] = {
|
||||
"ip": "2001:db8::1",
|
||||
"currently_banned_in": [],
|
||||
"geo": None,
|
||||
|
||||
@@ -13,10 +13,11 @@ async def test_health_check_returns_200(client: AsyncClient) -> None:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_health_check_returns_ok_status(client: AsyncClient) -> None:
|
||||
"""``GET /api/health`` must return ``{"status": "ok"}``."""
|
||||
"""``GET /api/health`` must contain ``status: ok`` and a ``fail2ban`` field."""
|
||||
response = await client.get("/api/health")
|
||||
data: dict[str, str] = response.json()
|
||||
assert data == {"status": "ok"}
|
||||
assert data["status"] == "ok"
|
||||
assert data["fail2ban"] in ("online", "offline")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
@@ -213,6 +213,44 @@ class TestHistoryList:
|
||||
_args, kwargs = mock_fn.call_args
|
||||
assert kwargs.get("range_") == "7d"
|
||||
|
||||
async def test_forwards_origin_filter(self, history_client: AsyncClient) -> None:
|
||||
"""The ``origin`` query parameter is forwarded to the service."""
|
||||
mock_fn = AsyncMock(return_value=_make_history_list(n=0))
|
||||
with patch(
|
||||
"app.routers.history.history_service.list_history",
|
||||
new=mock_fn,
|
||||
):
|
||||
await history_client.get("/api/history?origin=blocklist")
|
||||
|
||||
_args, kwargs = mock_fn.call_args
|
||||
assert kwargs.get("origin") == "blocklist"
|
||||
|
||||
async def test_forwards_source_filter(self, history_client: AsyncClient) -> None:
|
||||
"""The ``source`` query parameter is forwarded to the service."""
|
||||
mock_fn = AsyncMock(return_value=_make_history_list(n=0))
|
||||
with patch(
|
||||
"app.routers.history.history_service.list_history",
|
||||
new=mock_fn,
|
||||
):
|
||||
await history_client.get("/api/history?source=archive")
|
||||
|
||||
_args, kwargs = mock_fn.call_args
|
||||
assert kwargs.get("source") == "archive"
|
||||
|
||||
async def test_archive_route_forces_source_archive(
|
||||
self, history_client: AsyncClient
|
||||
) -> None:
|
||||
"""GET /api/history/archive should call list_history with source='archive'."""
|
||||
mock_fn = AsyncMock(return_value=_make_history_list(n=0))
|
||||
with patch(
|
||||
"app.routers.history.history_service.list_history",
|
||||
new=mock_fn,
|
||||
):
|
||||
await history_client.get("/api/history/archive")
|
||||
|
||||
_args, kwargs = mock_fn.call_args
|
||||
assert kwargs.get("source") == "archive"
|
||||
|
||||
async def test_empty_result(self, history_client: AsyncClient) -> None:
|
||||
"""An empty history returns items=[] and total=0."""
|
||||
with patch(
|
||||
|
||||
@@ -12,6 +12,7 @@ from httpx import ASGITransport, AsyncClient
|
||||
from app.config import Settings
|
||||
from app.db import init_db
|
||||
from app.main import create_app
|
||||
from app.models.ban import JailBannedIpsResponse
|
||||
from app.models.jail import Jail, JailDetailResponse, JailListResponse, JailStatus, JailSummary
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -788,3 +789,146 @@ class TestFail2BanConnectionErrors:
|
||||
resp = await jails_client.post("/api/jails/sshd/reload")
|
||||
|
||||
assert resp.status_code == 502
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /api/jails/{name}/banned
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestGetJailBannedIps:
|
||||
"""Tests for ``GET /api/jails/{name}/banned``."""
|
||||
|
||||
def _mock_response(
|
||||
self,
|
||||
*,
|
||||
items: list[dict[str, str | None]] | None = None,
|
||||
total: int = 2,
|
||||
page: int = 1,
|
||||
page_size: int = 25,
|
||||
) -> JailBannedIpsResponse:
|
||||
from app.models.ban import ActiveBan, JailBannedIpsResponse
|
||||
|
||||
ban_items = (
|
||||
[
|
||||
ActiveBan(
|
||||
ip=item.get("ip") or "1.2.3.4",
|
||||
jail="sshd",
|
||||
banned_at=item.get("banned_at", "2025-01-01T10:00:00+00:00"),
|
||||
expires_at=item.get("expires_at", "2025-01-01T10:10:00+00:00"),
|
||||
ban_count=1,
|
||||
country=item.get("country", None),
|
||||
)
|
||||
for item in (items or [{"ip": "1.2.3.4"}, {"ip": "5.6.7.8"}])
|
||||
]
|
||||
)
|
||||
return JailBannedIpsResponse(
|
||||
items=ban_items, total=total, page=page, page_size=page_size
|
||||
)
|
||||
|
||||
async def test_200_returns_paginated_bans(self, jails_client: AsyncClient) -> None:
|
||||
"""GET /api/jails/sshd/banned returns 200 with a JailBannedIpsResponse."""
|
||||
with patch(
|
||||
"app.routers.jails.jail_service.get_jail_banned_ips",
|
||||
AsyncMock(return_value=self._mock_response()),
|
||||
):
|
||||
resp = await jails_client.get("/api/jails/sshd/banned")
|
||||
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert "items" in data
|
||||
assert "total" in data
|
||||
assert "page" in data
|
||||
assert "page_size" in data
|
||||
assert data["total"] == 2
|
||||
|
||||
async def test_200_with_search_parameter(self, jails_client: AsyncClient) -> None:
|
||||
"""GET /api/jails/sshd/banned?search=1.2.3 passes search to service."""
|
||||
mock_fn = AsyncMock(return_value=self._mock_response(items=[{"ip": "1.2.3.4"}], total=1))
|
||||
with patch("app.routers.jails.jail_service.get_jail_banned_ips", mock_fn):
|
||||
resp = await jails_client.get("/api/jails/sshd/banned?search=1.2.3")
|
||||
|
||||
assert resp.status_code == 200
|
||||
_args, call_kwargs = mock_fn.call_args
|
||||
assert call_kwargs.get("search") == "1.2.3"
|
||||
|
||||
async def test_200_with_page_and_page_size(self, jails_client: AsyncClient) -> None:
|
||||
"""GET /api/jails/sshd/banned?page=2&page_size=10 passes params to service."""
|
||||
mock_fn = AsyncMock(
|
||||
return_value=self._mock_response(page=2, page_size=10, total=0, items=[])
|
||||
)
|
||||
with patch("app.routers.jails.jail_service.get_jail_banned_ips", mock_fn):
|
||||
resp = await jails_client.get("/api/jails/sshd/banned?page=2&page_size=10")
|
||||
|
||||
assert resp.status_code == 200
|
||||
_args, call_kwargs = mock_fn.call_args
|
||||
assert call_kwargs.get("page") == 2
|
||||
assert call_kwargs.get("page_size") == 10
|
||||
|
||||
async def test_400_when_page_is_zero(self, jails_client: AsyncClient) -> None:
|
||||
"""GET /api/jails/sshd/banned?page=0 returns 400."""
|
||||
resp = await jails_client.get("/api/jails/sshd/banned?page=0")
|
||||
assert resp.status_code == 400
|
||||
|
||||
async def test_400_when_page_size_exceeds_max(self, jails_client: AsyncClient) -> None:
|
||||
"""GET /api/jails/sshd/banned?page_size=200 returns 400."""
|
||||
resp = await jails_client.get("/api/jails/sshd/banned?page_size=200")
|
||||
assert resp.status_code == 400
|
||||
|
||||
async def test_400_when_page_size_is_zero(self, jails_client: AsyncClient) -> None:
|
||||
"""GET /api/jails/sshd/banned?page_size=0 returns 400."""
|
||||
resp = await jails_client.get("/api/jails/sshd/banned?page_size=0")
|
||||
assert resp.status_code == 400
|
||||
|
||||
async def test_404_for_unknown_jail(self, jails_client: AsyncClient) -> None:
|
||||
"""GET /api/jails/ghost/banned returns 404 when jail is unknown."""
|
||||
from app.services.jail_service import JailNotFoundError
|
||||
|
||||
with patch(
|
||||
"app.routers.jails.jail_service.get_jail_banned_ips",
|
||||
AsyncMock(side_effect=JailNotFoundError("ghost")),
|
||||
):
|
||||
resp = await jails_client.get("/api/jails/ghost/banned")
|
||||
|
||||
assert resp.status_code == 404
|
||||
|
||||
async def test_502_when_fail2ban_unreachable(self, jails_client: AsyncClient) -> None:
|
||||
"""GET /api/jails/sshd/banned returns 502 when fail2ban is unreachable."""
|
||||
from app.utils.fail2ban_client import Fail2BanConnectionError
|
||||
|
||||
with patch(
|
||||
"app.routers.jails.jail_service.get_jail_banned_ips",
|
||||
AsyncMock(
|
||||
side_effect=Fail2BanConnectionError("socket dead", "/tmp/fake.sock")
|
||||
),
|
||||
):
|
||||
resp = await jails_client.get("/api/jails/sshd/banned")
|
||||
|
||||
assert resp.status_code == 502
|
||||
|
||||
async def test_response_items_have_expected_fields(
|
||||
self, jails_client: AsyncClient
|
||||
) -> None:
|
||||
"""Response items contain ip, jail, banned_at, expires_at, ban_count, country."""
|
||||
with patch(
|
||||
"app.routers.jails.jail_service.get_jail_banned_ips",
|
||||
AsyncMock(return_value=self._mock_response()),
|
||||
):
|
||||
resp = await jails_client.get("/api/jails/sshd/banned")
|
||||
|
||||
item = resp.json()["items"][0]
|
||||
assert "ip" in item
|
||||
assert "jail" in item
|
||||
assert "banned_at" in item
|
||||
assert "expires_at" in item
|
||||
assert "ban_count" in item
|
||||
assert "country" in item
|
||||
|
||||
async def test_401_when_unauthenticated(self, jails_client: AsyncClient) -> None:
|
||||
"""GET /api/jails/sshd/banned returns 401 without a session cookie."""
|
||||
resp = await AsyncClient(
|
||||
transport=ASGITransport(app=jails_client._transport.app), # type: ignore[attr-defined]
|
||||
base_url="http://test",
|
||||
).get("/api/jails/sshd/banned")
|
||||
assert resp.status_code == 401
|
||||
|
||||
|
||||
@@ -68,7 +68,8 @@ def _make_settings() -> ServerSettingsResponse:
|
||||
db_path="/var/lib/fail2ban/fail2ban.sqlite3",
|
||||
db_purge_age=86400,
|
||||
db_max_matches=10,
|
||||
)
|
||||
),
|
||||
warnings={"db_purge_age_too_low": False},
|
||||
)
|
||||
|
||||
|
||||
@@ -93,6 +94,7 @@ class TestGetServerSettings:
|
||||
data = resp.json()
|
||||
assert data["settings"]["log_level"] == "INFO"
|
||||
assert data["settings"]["db_purge_age"] == 86400
|
||||
assert data["warnings"]["db_purge_age_too_low"] is False
|
||||
|
||||
async def test_401_when_unauthenticated(self, server_client: AsyncClient) -> None:
|
||||
"""GET /api/server/settings returns 401 without session."""
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import aiosqlite
|
||||
import pytest
|
||||
@@ -11,7 +11,7 @@ from httpx import ASGITransport, AsyncClient
|
||||
|
||||
from app.config import Settings
|
||||
from app.db import init_db
|
||||
from app.main import create_app
|
||||
from app.main import _lifespan, create_app
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Shared setup payload
|
||||
@@ -247,9 +247,9 @@ class TestSetupCompleteCaching:
|
||||
assert not getattr(app.state, "_setup_complete_cached", False)
|
||||
|
||||
# First non-exempt request — middleware queries DB and sets the flag.
|
||||
await client.post("/api/auth/login", json={"password": _SETUP_PAYLOAD["master_password"]}) # type: ignore[call-overload]
|
||||
await client.post("/api/auth/login", json={"password": _SETUP_PAYLOAD["master_password"]})
|
||||
|
||||
assert app.state._setup_complete_cached is True # type: ignore[attr-defined]
|
||||
assert app.state._setup_complete_cached is True
|
||||
|
||||
async def test_cached_path_skips_is_setup_complete(
|
||||
self,
|
||||
@@ -267,12 +267,12 @@ class TestSetupCompleteCaching:
|
||||
|
||||
# Do setup and warm the cache.
|
||||
await client.post("/api/setup", json=_SETUP_PAYLOAD)
|
||||
await client.post("/api/auth/login", json={"password": _SETUP_PAYLOAD["master_password"]}) # type: ignore[call-overload]
|
||||
assert app.state._setup_complete_cached is True # type: ignore[attr-defined]
|
||||
await client.post("/api/auth/login", json={"password": _SETUP_PAYLOAD["master_password"]})
|
||||
assert app.state._setup_complete_cached is True
|
||||
|
||||
call_count = 0
|
||||
|
||||
async def _counting(db): # type: ignore[no-untyped-def]
|
||||
async def _counting(db: aiosqlite.Connection) -> bool:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return True
|
||||
@@ -286,3 +286,151 @@ class TestSetupCompleteCaching:
|
||||
# Cache was warm — is_setup_complete must not have been called.
|
||||
assert call_count == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Task 0.1 — Lifespan creates the database parent directory (Task 0.1)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestLifespanDatabaseDirectoryCreation:
|
||||
"""App lifespan creates the database parent directory when it does not exist."""
|
||||
|
||||
async def test_creates_nested_database_directory(self, tmp_path: Path) -> None:
|
||||
"""Lifespan creates intermediate directories for the database path.
|
||||
|
||||
Verifies that a deeply-nested database path is handled correctly —
|
||||
the parent directories are created before ``aiosqlite.connect`` is
|
||||
called so the app does not crash on a fresh volume.
|
||||
"""
|
||||
nested_db = tmp_path / "deep" / "nested" / "bangui.db"
|
||||
assert not nested_db.parent.exists()
|
||||
|
||||
settings = Settings(
|
||||
database_path=str(nested_db),
|
||||
fail2ban_socket="/tmp/fake.sock",
|
||||
session_secret="test-lifespan-mkdir-secret",
|
||||
session_duration_minutes=60,
|
||||
timezone="UTC",
|
||||
log_level="debug",
|
||||
)
|
||||
app = create_app(settings=settings)
|
||||
|
||||
mock_scheduler = MagicMock()
|
||||
mock_scheduler.start = MagicMock()
|
||||
mock_scheduler.shutdown = MagicMock()
|
||||
|
||||
with (
|
||||
patch("app.services.geo_service.init_geoip"),
|
||||
patch(
|
||||
"app.services.geo_service.load_cache_from_db",
|
||||
new=AsyncMock(return_value=None),
|
||||
),
|
||||
patch("app.tasks.health_check.register"),
|
||||
patch("app.tasks.blocklist_import.register"),
|
||||
patch("app.tasks.geo_cache_flush.register"),
|
||||
patch("app.tasks.geo_re_resolve.register"),
|
||||
patch("app.main.AsyncIOScheduler", return_value=mock_scheduler),
|
||||
patch("app.main.ensure_jail_configs"),
|
||||
):
|
||||
async with _lifespan(app):
|
||||
assert nested_db.parent.exists(), (
|
||||
"Expected lifespan to create database parent directory"
|
||||
)
|
||||
|
||||
async def test_existing_database_directory_is_not_an_error(
|
||||
self, tmp_path: Path
|
||||
) -> None:
|
||||
"""Lifespan does not raise when the database directory already exists.
|
||||
|
||||
``mkdir(exist_ok=True)`` must be used so that re-starts on an existing
|
||||
volume do not fail.
|
||||
"""
|
||||
db_path = tmp_path / "bangui.db"
|
||||
# tmp_path already exists — this simulates a pre-existing volume.
|
||||
|
||||
settings = Settings(
|
||||
database_path=str(db_path),
|
||||
fail2ban_socket="/tmp/fake.sock",
|
||||
session_secret="test-lifespan-exist-ok-secret",
|
||||
session_duration_minutes=60,
|
||||
timezone="UTC",
|
||||
log_level="debug",
|
||||
)
|
||||
app = create_app(settings=settings)
|
||||
|
||||
mock_scheduler = MagicMock()
|
||||
mock_scheduler.start = MagicMock()
|
||||
mock_scheduler.shutdown = MagicMock()
|
||||
|
||||
with (
|
||||
patch("app.services.geo_service.init_geoip"),
|
||||
patch(
|
||||
"app.services.geo_service.load_cache_from_db",
|
||||
new=AsyncMock(return_value=None),
|
||||
),
|
||||
patch("app.tasks.health_check.register"),
|
||||
patch("app.tasks.blocklist_import.register"),
|
||||
patch("app.tasks.geo_cache_flush.register"),
|
||||
patch("app.tasks.geo_re_resolve.register"),
|
||||
patch("app.main.AsyncIOScheduler", return_value=mock_scheduler),
|
||||
patch("app.main.ensure_jail_configs"),
|
||||
):
|
||||
# Should not raise FileExistsError or similar.
|
||||
async with _lifespan(app):
|
||||
assert tmp_path.exists()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Task 0.2 — Middleware redirects when app.state.db is None
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSetupRedirectMiddlewareDbNone:
|
||||
"""SetupRedirectMiddleware redirects when the database is not yet available."""
|
||||
|
||||
async def test_redirects_to_setup_when_db_not_set(self, tmp_path: Path) -> None:
|
||||
"""A ``None`` db on app.state causes a 307 redirect to ``/api/setup``.
|
||||
|
||||
Simulates the race window where a request arrives before the lifespan
|
||||
has finished initialising the database connection.
|
||||
"""
|
||||
settings = Settings(
|
||||
database_path=str(tmp_path / "bangui.db"),
|
||||
fail2ban_socket="/tmp/fake_fail2ban.sock",
|
||||
session_secret="test-db-none-secret",
|
||||
session_duration_minutes=60,
|
||||
timezone="UTC",
|
||||
log_level="debug",
|
||||
)
|
||||
app = create_app(settings=settings)
|
||||
# Deliberately do NOT set app.state.db to simulate startup not complete.
|
||||
|
||||
transport = ASGITransport(app=app)
|
||||
async with AsyncClient(
|
||||
transport=transport, base_url="http://test"
|
||||
) as ac:
|
||||
response = await ac.get("/api/auth/login", follow_redirects=False)
|
||||
|
||||
assert response.status_code == 307
|
||||
assert response.headers["location"] == "/api/setup"
|
||||
|
||||
async def test_health_reachable_when_db_not_set(self, tmp_path: Path) -> None:
|
||||
"""Health endpoint is always reachable even when db is not initialised."""
|
||||
settings = Settings(
|
||||
database_path=str(tmp_path / "bangui.db"),
|
||||
fail2ban_socket="/tmp/fake_fail2ban.sock",
|
||||
session_secret="test-db-none-health-secret",
|
||||
session_duration_minutes=60,
|
||||
timezone="UTC",
|
||||
log_level="debug",
|
||||
)
|
||||
app = create_app(settings=settings)
|
||||
|
||||
transport = ASGITransport(app=app)
|
||||
async with AsyncClient(
|
||||
transport=transport, base_url="http://test"
|
||||
) as ac:
|
||||
response = await ac.get("/api/health")
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
@@ -73,7 +73,7 @@ class TestCheckPasswordAsync:
|
||||
auth_service._check_password("secret", hashed), # noqa: SLF001
|
||||
auth_service._check_password("wrong", hashed), # noqa: SLF001
|
||||
)
|
||||
assert results == [True, False]
|
||||
assert tuple(results) == (True, False)
|
||||
|
||||
|
||||
class TestLogin:
|
||||
|
||||
@@ -11,6 +11,7 @@ from unittest.mock import AsyncMock, patch
|
||||
import aiosqlite
|
||||
import pytest
|
||||
|
||||
from app.db import init_db
|
||||
from app.services import ban_service
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -64,7 +65,7 @@ async def _create_f2b_db(path: str, rows: list[dict[str, Any]]) -> None:
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def f2b_db_path(tmp_path: Path) -> str: # type: ignore[misc]
|
||||
async def f2b_db_path(tmp_path: Path) -> str:
|
||||
"""Return the path to a test fail2ban SQLite database with several bans."""
|
||||
path = str(tmp_path / "fail2ban_test.sqlite3")
|
||||
await _create_f2b_db(
|
||||
@@ -103,7 +104,7 @@ async def f2b_db_path(tmp_path: Path) -> str: # type: ignore[misc]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def mixed_origin_db_path(tmp_path: Path) -> str: # type: ignore[misc]
|
||||
async def mixed_origin_db_path(tmp_path: Path) -> str:
|
||||
"""Return a database with bans from both blocklist-import and organic jails."""
|
||||
path = str(tmp_path / "fail2ban_mixed_origin.sqlite3")
|
||||
await _create_f2b_db(
|
||||
@@ -136,13 +137,36 @@ async def mixed_origin_db_path(tmp_path: Path) -> str: # type: ignore[misc]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def empty_f2b_db_path(tmp_path: Path) -> str: # type: ignore[misc]
|
||||
async def empty_f2b_db_path(tmp_path: Path) -> str:
|
||||
"""Return the path to a fail2ban SQLite database with no ban records."""
|
||||
path = str(tmp_path / "fail2ban_empty.sqlite3")
|
||||
await _create_f2b_db(path, [])
|
||||
return path
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def app_db_with_archive(tmp_path: Path) -> aiosqlite.Connection:
|
||||
"""Return an app database connection pre-populated with archived ban rows."""
|
||||
db_path = str(tmp_path / "app_archive.db")
|
||||
db = await aiosqlite.connect(db_path)
|
||||
db.row_factory = aiosqlite.Row
|
||||
await init_db(db)
|
||||
|
||||
await db.execute(
|
||||
"INSERT INTO history_archive (jail, ip, timeofban, bancount, data, action) VALUES (?, ?, ?, ?, ?, ?)",
|
||||
("sshd", "1.2.3.4", _ONE_HOUR_AGO, 1, '{"matches": ["fail"], "failures": 1}', "ban"),
|
||||
)
|
||||
await db.execute(
|
||||
"INSERT INTO history_archive (jail, ip, timeofban, bancount, data, action) VALUES (?, ?, ?, ?, ?, ?)",
|
||||
("nginx", "5.6.7.8", _ONE_HOUR_AGO, 1, '{"matches": ["fail"], "failures": 2}', "ban"),
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
yield db
|
||||
|
||||
await db.close()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# list_bans — happy path
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -154,7 +178,7 @@ class TestListBansHappyPath:
|
||||
async def test_returns_bans_in_range(self, f2b_db_path: str) -> None:
|
||||
"""Only bans within the selected range are returned."""
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=f2b_db_path),
|
||||
):
|
||||
result = await ban_service.list_bans("/fake/sock", "24h")
|
||||
@@ -166,7 +190,7 @@ class TestListBansHappyPath:
|
||||
async def test_results_sorted_newest_first(self, f2b_db_path: str) -> None:
|
||||
"""Items are ordered by ``banned_at`` descending (newest first)."""
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=f2b_db_path),
|
||||
):
|
||||
result = await ban_service.list_bans("/fake/sock", "24h")
|
||||
@@ -177,7 +201,7 @@ class TestListBansHappyPath:
|
||||
async def test_ban_fields_present(self, f2b_db_path: str) -> None:
|
||||
"""Each item contains ip, jail, banned_at, ban_count."""
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=f2b_db_path),
|
||||
):
|
||||
result = await ban_service.list_bans("/fake/sock", "24h")
|
||||
@@ -191,7 +215,7 @@ class TestListBansHappyPath:
|
||||
async def test_service_extracted_from_first_match(self, f2b_db_path: str) -> None:
|
||||
"""``service`` field is the first element of ``data.matches``."""
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=f2b_db_path),
|
||||
):
|
||||
result = await ban_service.list_bans("/fake/sock", "24h")
|
||||
@@ -203,7 +227,7 @@ class TestListBansHappyPath:
|
||||
async def test_service_is_none_when_no_matches(self, f2b_db_path: str) -> None:
|
||||
"""``service`` is ``None`` when the ban has no stored matches."""
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=f2b_db_path),
|
||||
):
|
||||
# Use 7d to include the older ban with no matches.
|
||||
@@ -215,7 +239,7 @@ class TestListBansHappyPath:
|
||||
async def test_empty_db_returns_zero(self, empty_f2b_db_path: str) -> None:
|
||||
"""When no bans exist the result has total=0 and no items."""
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=empty_f2b_db_path),
|
||||
):
|
||||
result = await ban_service.list_bans("/fake/sock", "24h")
|
||||
@@ -226,13 +250,27 @@ class TestListBansHappyPath:
|
||||
async def test_365d_range_includes_old_bans(self, f2b_db_path: str) -> None:
|
||||
"""The ``365d`` range includes bans that are 2 days old."""
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=f2b_db_path),
|
||||
):
|
||||
result = await ban_service.list_bans("/fake/sock", "365d")
|
||||
|
||||
assert result.total == 3
|
||||
|
||||
async def test_source_archive_reads_from_archive(
|
||||
self, app_db_with_archive: aiosqlite.Connection
|
||||
) -> None:
|
||||
"""Using source='archive' reads from the BanGUI archive table."""
|
||||
result = await ban_service.list_bans(
|
||||
"/fake/sock",
|
||||
"24h",
|
||||
source="archive",
|
||||
app_db=app_db_with_archive,
|
||||
)
|
||||
|
||||
assert result.total == 2
|
||||
assert {item.ip for item in result.items} == {"1.2.3.4", "5.6.7.8"}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# list_bans — geo enrichment
|
||||
@@ -246,7 +284,7 @@ class TestListBansGeoEnrichment:
|
||||
self, f2b_db_path: str
|
||||
) -> None:
|
||||
"""Geo fields are populated when an enricher returns data."""
|
||||
from app.services.geo_service import GeoInfo
|
||||
from app.models.geo import GeoInfo
|
||||
|
||||
async def fake_enricher(ip: str) -> GeoInfo:
|
||||
return GeoInfo(
|
||||
@@ -257,7 +295,7 @@ class TestListBansGeoEnrichment:
|
||||
)
|
||||
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=f2b_db_path),
|
||||
):
|
||||
result = await ban_service.list_bans(
|
||||
@@ -278,7 +316,7 @@ class TestListBansGeoEnrichment:
|
||||
raise RuntimeError("geo service down")
|
||||
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=f2b_db_path),
|
||||
):
|
||||
result = await ban_service.list_bans(
|
||||
@@ -304,25 +342,27 @@ class TestListBansBatchGeoEnrichment:
|
||||
"""Geo fields are populated via lookup_batch when http_session is given."""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from app.services.geo_service import GeoInfo
|
||||
from app.models.geo import GeoInfo
|
||||
|
||||
fake_session = MagicMock()
|
||||
fake_geo_map = {
|
||||
"1.2.3.4": GeoInfo(country_code="DE", country_name="Germany", asn="AS3320", org="Deutsche Telekom"),
|
||||
"5.6.7.8": GeoInfo(country_code="US", country_name="United States", asn="AS15169", org="Google"),
|
||||
}
|
||||
fake_geo_batch = AsyncMock(return_value=fake_geo_map)
|
||||
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=f2b_db_path),
|
||||
), patch(
|
||||
"app.services.geo_service.lookup_batch",
|
||||
new=AsyncMock(return_value=fake_geo_map),
|
||||
):
|
||||
result = await ban_service.list_bans(
|
||||
"/fake/sock", "24h", http_session=fake_session
|
||||
"/fake/sock",
|
||||
"24h",
|
||||
http_session=fake_session,
|
||||
geo_batch_lookup=fake_geo_batch,
|
||||
)
|
||||
|
||||
fake_geo_batch.assert_awaited_once_with(["1.2.3.4", "5.6.7.8"], fake_session, db=None)
|
||||
assert result.total == 2
|
||||
de_item = next(i for i in result.items if i.ip == "1.2.3.4")
|
||||
us_item = next(i for i in result.items if i.ip == "5.6.7.8")
|
||||
@@ -339,15 +379,17 @@ class TestListBansBatchGeoEnrichment:
|
||||
|
||||
fake_session = MagicMock()
|
||||
|
||||
failing_geo_batch = AsyncMock(side_effect=RuntimeError("batch geo down"))
|
||||
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=f2b_db_path),
|
||||
), patch(
|
||||
"app.services.geo_service.lookup_batch",
|
||||
new=AsyncMock(side_effect=RuntimeError("batch geo down")),
|
||||
):
|
||||
result = await ban_service.list_bans(
|
||||
"/fake/sock", "24h", http_session=fake_session
|
||||
"/fake/sock",
|
||||
"24h",
|
||||
http_session=fake_session,
|
||||
geo_batch_lookup=failing_geo_batch,
|
||||
)
|
||||
|
||||
assert result.total == 2
|
||||
@@ -360,28 +402,27 @@ class TestListBansBatchGeoEnrichment:
|
||||
"""When both http_session and geo_enricher are provided, batch wins."""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from app.services.geo_service import GeoInfo
|
||||
from app.models.geo import GeoInfo
|
||||
|
||||
fake_session = MagicMock()
|
||||
fake_geo_map = {
|
||||
"1.2.3.4": GeoInfo(country_code="DE", country_name="Germany", asn=None, org=None),
|
||||
"5.6.7.8": GeoInfo(country_code="DE", country_name="Germany", asn=None, org=None),
|
||||
}
|
||||
fake_geo_batch = AsyncMock(return_value=fake_geo_map)
|
||||
|
||||
async def enricher_should_not_be_called(ip: str) -> GeoInfo:
|
||||
raise AssertionError(f"geo_enricher was called for {ip!r} — should not happen")
|
||||
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=f2b_db_path),
|
||||
), patch(
|
||||
"app.services.geo_service.lookup_batch",
|
||||
new=AsyncMock(return_value=fake_geo_map),
|
||||
):
|
||||
result = await ban_service.list_bans(
|
||||
"/fake/sock",
|
||||
"24h",
|
||||
http_session=fake_session,
|
||||
geo_batch_lookup=fake_geo_batch,
|
||||
geo_enricher=enricher_should_not_be_called,
|
||||
)
|
||||
|
||||
@@ -401,7 +442,7 @@ class TestListBansPagination:
|
||||
async def test_page_size_respected(self, f2b_db_path: str) -> None:
|
||||
"""``page_size=1`` returns at most one item."""
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=f2b_db_path),
|
||||
):
|
||||
result = await ban_service.list_bans("/fake/sock", "7d", page_size=1)
|
||||
@@ -412,7 +453,7 @@ class TestListBansPagination:
|
||||
async def test_page_2_returns_remaining_items(self, f2b_db_path: str) -> None:
|
||||
"""The second page returns items not on the first page."""
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=f2b_db_path),
|
||||
):
|
||||
page1 = await ban_service.list_bans("/fake/sock", "7d", page=1, page_size=1)
|
||||
@@ -426,7 +467,7 @@ class TestListBansPagination:
|
||||
) -> None:
|
||||
"""``total`` reports all matching records regardless of pagination."""
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=f2b_db_path),
|
||||
):
|
||||
result = await ban_service.list_bans("/fake/sock", "7d", page_size=1)
|
||||
@@ -447,7 +488,7 @@ class TestBanOriginDerivation:
|
||||
) -> None:
|
||||
"""Bans from ``blocklist-import`` jail carry ``origin == "blocklist"``."""
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=mixed_origin_db_path),
|
||||
):
|
||||
result = await ban_service.list_bans("/fake/sock", "24h")
|
||||
@@ -461,7 +502,7 @@ class TestBanOriginDerivation:
|
||||
) -> None:
|
||||
"""Bans from organic jails (sshd, nginx, …) carry ``origin == "selfblock"``."""
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=mixed_origin_db_path),
|
||||
):
|
||||
result = await ban_service.list_bans("/fake/sock", "24h")
|
||||
@@ -476,7 +517,7 @@ class TestBanOriginDerivation:
|
||||
) -> None:
|
||||
"""Every returned item has an ``origin`` field with a valid value."""
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=mixed_origin_db_path),
|
||||
):
|
||||
result = await ban_service.list_bans("/fake/sock", "24h")
|
||||
@@ -489,7 +530,7 @@ class TestBanOriginDerivation:
|
||||
) -> None:
|
||||
"""``bans_by_country`` also derives origin correctly for blocklist bans."""
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=mixed_origin_db_path),
|
||||
):
|
||||
result = await ban_service.bans_by_country("/fake/sock", "24h")
|
||||
@@ -503,7 +544,7 @@ class TestBanOriginDerivation:
|
||||
) -> None:
|
||||
"""``bans_by_country`` derives origin correctly for organic jails."""
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=mixed_origin_db_path),
|
||||
):
|
||||
result = await ban_service.bans_by_country("/fake/sock", "24h")
|
||||
@@ -527,7 +568,7 @@ class TestOriginFilter:
|
||||
) -> None:
|
||||
"""``origin='blocklist'`` returns only blocklist-import jail bans."""
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=mixed_origin_db_path),
|
||||
):
|
||||
result = await ban_service.list_bans(
|
||||
@@ -544,7 +585,7 @@ class TestOriginFilter:
|
||||
) -> None:
|
||||
"""``origin='selfblock'`` excludes the blocklist-import jail."""
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=mixed_origin_db_path),
|
||||
):
|
||||
result = await ban_service.list_bans(
|
||||
@@ -562,7 +603,7 @@ class TestOriginFilter:
|
||||
) -> None:
|
||||
"""``origin=None`` applies no jail restriction — all bans returned."""
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=mixed_origin_db_path),
|
||||
):
|
||||
result = await ban_service.list_bans("/fake/sock", "24h", origin=None)
|
||||
@@ -574,7 +615,7 @@ class TestOriginFilter:
|
||||
) -> None:
|
||||
"""``bans_by_country`` with ``origin='blocklist'`` counts only blocklist bans."""
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=mixed_origin_db_path),
|
||||
):
|
||||
result = await ban_service.bans_by_country(
|
||||
@@ -589,7 +630,7 @@ class TestOriginFilter:
|
||||
) -> None:
|
||||
"""``bans_by_country`` with ``origin='selfblock'`` excludes blocklist jails."""
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=mixed_origin_db_path),
|
||||
):
|
||||
result = await ban_service.bans_by_country(
|
||||
@@ -604,7 +645,7 @@ class TestOriginFilter:
|
||||
) -> None:
|
||||
"""``bans_by_country`` with ``origin=None`` returns all bans."""
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=mixed_origin_db_path),
|
||||
):
|
||||
result = await ban_service.bans_by_country(
|
||||
@@ -613,6 +654,20 @@ class TestOriginFilter:
|
||||
|
||||
assert result.total == 3
|
||||
|
||||
async def test_bans_by_country_source_archive_reads_archive(
|
||||
self, app_db_with_archive: aiosqlite.Connection
|
||||
) -> None:
|
||||
"""``bans_by_country`` accepts source='archive' and reads archived rows."""
|
||||
result = await ban_service.bans_by_country(
|
||||
"/fake/sock",
|
||||
"24h",
|
||||
source="archive",
|
||||
app_db=app_db_with_archive,
|
||||
)
|
||||
|
||||
assert result.total == 2
|
||||
assert len(result.bans) == 2
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# bans_by_country — background geo resolution (Task 3)
|
||||
@@ -632,19 +687,19 @@ class TestBansbyCountryBackground:
|
||||
from app.services import geo_service
|
||||
|
||||
# Pre-populate the cache for all three IPs in the fixture.
|
||||
geo_service._cache["10.0.0.1"] = geo_service.GeoInfo( # type: ignore[attr-defined]
|
||||
geo_service._cache["10.0.0.1"] = geo_service.GeoInfo(
|
||||
country_code="DE", country_name="Germany", asn=None, org=None
|
||||
)
|
||||
geo_service._cache["10.0.0.2"] = geo_service.GeoInfo( # type: ignore[attr-defined]
|
||||
geo_service._cache["10.0.0.2"] = geo_service.GeoInfo(
|
||||
country_code="US", country_name="United States", asn=None, org=None
|
||||
)
|
||||
geo_service._cache["10.0.0.3"] = geo_service.GeoInfo( # type: ignore[attr-defined]
|
||||
geo_service._cache["10.0.0.3"] = geo_service.GeoInfo(
|
||||
country_code="JP", country_name="Japan", asn=None, org=None
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=mixed_origin_db_path),
|
||||
),
|
||||
patch(
|
||||
@@ -652,8 +707,13 @@ class TestBansbyCountryBackground:
|
||||
) as mock_create_task,
|
||||
):
|
||||
mock_session = AsyncMock()
|
||||
mock_batch = AsyncMock(return_value={})
|
||||
result = await ban_service.bans_by_country(
|
||||
"/fake/sock", "24h", http_session=mock_session
|
||||
"/fake/sock",
|
||||
"24h",
|
||||
http_session=mock_session,
|
||||
geo_cache_lookup=geo_service.lookup_cached_only,
|
||||
geo_batch_lookup=mock_batch,
|
||||
)
|
||||
|
||||
# All countries resolved from cache — no background task needed.
|
||||
@@ -674,7 +734,7 @@ class TestBansbyCountryBackground:
|
||||
|
||||
with (
|
||||
patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=mixed_origin_db_path),
|
||||
),
|
||||
patch(
|
||||
@@ -682,8 +742,13 @@ class TestBansbyCountryBackground:
|
||||
) as mock_create_task,
|
||||
):
|
||||
mock_session = AsyncMock()
|
||||
mock_batch = AsyncMock(return_value={})
|
||||
result = await ban_service.bans_by_country(
|
||||
"/fake/sock", "24h", http_session=mock_session
|
||||
"/fake/sock",
|
||||
"24h",
|
||||
http_session=mock_session,
|
||||
geo_cache_lookup=geo_service.lookup_cached_only,
|
||||
geo_batch_lookup=mock_batch,
|
||||
)
|
||||
|
||||
# Background task must have been scheduled for uncached IPs.
|
||||
@@ -701,7 +766,7 @@ class TestBansbyCountryBackground:
|
||||
|
||||
with (
|
||||
patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=mixed_origin_db_path),
|
||||
),
|
||||
patch(
|
||||
@@ -727,7 +792,7 @@ class TestBanTrend:
|
||||
async def test_24h_returns_24_buckets(self, empty_f2b_db_path: str) -> None:
|
||||
"""``range_='24h'`` always yields exactly 24 buckets."""
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=empty_f2b_db_path),
|
||||
):
|
||||
result = await ban_service.ban_trend("/fake/sock", "24h")
|
||||
@@ -738,7 +803,7 @@ class TestBanTrend:
|
||||
async def test_7d_returns_28_buckets(self, empty_f2b_db_path: str) -> None:
|
||||
"""``range_='7d'`` yields 28 six-hour buckets."""
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=empty_f2b_db_path),
|
||||
):
|
||||
result = await ban_service.ban_trend("/fake/sock", "7d")
|
||||
@@ -749,7 +814,7 @@ class TestBanTrend:
|
||||
async def test_30d_returns_30_buckets(self, empty_f2b_db_path: str) -> None:
|
||||
"""``range_='30d'`` yields 30 daily buckets."""
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=empty_f2b_db_path),
|
||||
):
|
||||
result = await ban_service.ban_trend("/fake/sock", "30d")
|
||||
@@ -760,7 +825,7 @@ class TestBanTrend:
|
||||
async def test_365d_bucket_size_label(self, empty_f2b_db_path: str) -> None:
|
||||
"""``range_='365d'`` uses '7d' as the bucket size label."""
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=empty_f2b_db_path),
|
||||
):
|
||||
result = await ban_service.ban_trend("/fake/sock", "365d")
|
||||
@@ -771,7 +836,7 @@ class TestBanTrend:
|
||||
async def test_empty_db_all_buckets_zero(self, empty_f2b_db_path: str) -> None:
|
||||
"""All bucket counts are zero when the database has no bans."""
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=empty_f2b_db_path),
|
||||
):
|
||||
result = await ban_service.ban_trend("/fake/sock", "24h")
|
||||
@@ -781,7 +846,7 @@ class TestBanTrend:
|
||||
async def test_buckets_are_time_ordered(self, empty_f2b_db_path: str) -> None:
|
||||
"""Buckets are ordered chronologically (ascending timestamps)."""
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=empty_f2b_db_path),
|
||||
):
|
||||
result = await ban_service.ban_trend("/fake/sock", "7d")
|
||||
@@ -789,6 +854,19 @@ class TestBanTrend:
|
||||
timestamps = [b.timestamp for b in result.buckets]
|
||||
assert timestamps == sorted(timestamps)
|
||||
|
||||
async def test_ban_trend_source_archive_reads_archive(
|
||||
self, app_db_with_archive: aiosqlite.Connection
|
||||
) -> None:
|
||||
"""``ban_trend`` accepts source='archive' and uses archived rows."""
|
||||
result = await ban_service.ban_trend(
|
||||
"/fake/sock",
|
||||
"24h",
|
||||
source="archive",
|
||||
app_db=app_db_with_archive,
|
||||
)
|
||||
|
||||
assert sum(b.count for b in result.buckets) == 2
|
||||
|
||||
async def test_bans_counted_in_correct_bucket(self, tmp_path: Path) -> None:
|
||||
"""A ban at a known time appears in the expected bucket."""
|
||||
import time as _time
|
||||
@@ -804,7 +882,7 @@ class TestBanTrend:
|
||||
)
|
||||
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=path),
|
||||
):
|
||||
result = await ban_service.ban_trend("/fake/sock", "24h")
|
||||
@@ -828,7 +906,7 @@ class TestBanTrend:
|
||||
)
|
||||
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=path),
|
||||
):
|
||||
result = await ban_service.ban_trend(
|
||||
@@ -854,7 +932,7 @@ class TestBanTrend:
|
||||
)
|
||||
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=path),
|
||||
):
|
||||
result = await ban_service.ban_trend(
|
||||
@@ -868,7 +946,7 @@ class TestBanTrend:
|
||||
from datetime import datetime
|
||||
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=empty_f2b_db_path),
|
||||
):
|
||||
result = await ban_service.ban_trend("/fake/sock", "24h")
|
||||
@@ -904,7 +982,7 @@ class TestBansByJail:
|
||||
)
|
||||
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=path),
|
||||
):
|
||||
result = await ban_service.bans_by_jail("/fake/sock", "24h")
|
||||
@@ -931,7 +1009,7 @@ class TestBansByJail:
|
||||
)
|
||||
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=path),
|
||||
):
|
||||
result = await ban_service.bans_by_jail("/fake/sock", "24h")
|
||||
@@ -942,7 +1020,7 @@ class TestBansByJail:
|
||||
async def test_empty_db_returns_empty_list(self, empty_f2b_db_path: str) -> None:
|
||||
"""An empty database returns an empty jails list with total zero."""
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=empty_f2b_db_path),
|
||||
):
|
||||
result = await ban_service.bans_by_jail("/fake/sock", "24h")
|
||||
@@ -954,7 +1032,7 @@ class TestBansByJail:
|
||||
"""Bans older than the time window are not counted."""
|
||||
# f2b_db_path has one ban from _TWO_DAYS_AGO, which is outside "24h".
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=f2b_db_path),
|
||||
):
|
||||
result = await ban_service.bans_by_jail("/fake/sock", "24h")
|
||||
@@ -965,7 +1043,7 @@ class TestBansByJail:
|
||||
async def test_origin_filter_blocklist(self, mixed_origin_db_path: str) -> None:
|
||||
"""``origin='blocklist'`` returns only the blocklist-import jail."""
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=mixed_origin_db_path),
|
||||
):
|
||||
result = await ban_service.bans_by_jail(
|
||||
@@ -979,7 +1057,7 @@ class TestBansByJail:
|
||||
async def test_origin_filter_selfblock(self, mixed_origin_db_path: str) -> None:
|
||||
"""``origin='selfblock'`` excludes the blocklist-import jail."""
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=mixed_origin_db_path),
|
||||
):
|
||||
result = await ban_service.bans_by_jail(
|
||||
@@ -995,7 +1073,7 @@ class TestBansByJail:
|
||||
) -> None:
|
||||
"""``origin=None`` returns bans from all jails."""
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=mixed_origin_db_path),
|
||||
):
|
||||
result = await ban_service.bans_by_jail(
|
||||
@@ -1005,6 +1083,20 @@ class TestBansByJail:
|
||||
assert result.total == 3
|
||||
assert len(result.jails) == 3
|
||||
|
||||
async def test_bans_by_jail_source_archive_reads_archive(
|
||||
self, app_db_with_archive: aiosqlite.Connection
|
||||
) -> None:
|
||||
"""``bans_by_jail`` accepts source='archive' and aggregates archived rows."""
|
||||
result = await ban_service.bans_by_jail(
|
||||
"/fake/sock",
|
||||
"24h",
|
||||
source="archive",
|
||||
app_db=app_db_with_archive,
|
||||
)
|
||||
|
||||
assert result.total == 2
|
||||
assert any(j.jail == "sshd" for j in result.jails)
|
||||
|
||||
async def test_diagnostic_warning_when_zero_results_despite_data(
|
||||
self, tmp_path: Path
|
||||
) -> None:
|
||||
@@ -1023,7 +1115,7 @@ class TestBansByJail:
|
||||
|
||||
with (
|
||||
patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=path),
|
||||
),
|
||||
patch("app.services.ban_service.log") as mock_log,
|
||||
|
||||
@@ -19,8 +19,8 @@ from unittest.mock import AsyncMock, patch
|
||||
import aiosqlite
|
||||
import pytest
|
||||
|
||||
from app.models.geo import GeoInfo
|
||||
from app.services import ban_service, geo_service
|
||||
from app.services.geo_service import GeoInfo
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Constants
|
||||
@@ -114,13 +114,13 @@ async def _seed_f2b_db(path: str, n: int) -> list[str]:
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def event_loop_policy() -> None: # type: ignore[misc]
|
||||
def event_loop_policy() -> None:
|
||||
"""Use the default event loop policy for module-scoped fixtures."""
|
||||
return None
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
async def perf_db_path(tmp_path_factory: Any) -> str: # type: ignore[misc]
|
||||
async def perf_db_path(tmp_path_factory: Any) -> str:
|
||||
"""Return the path to a fail2ban DB seeded with 10 000 synthetic bans.
|
||||
|
||||
Module-scoped so the database is created only once for all perf tests.
|
||||
@@ -161,7 +161,7 @@ class TestBanServicePerformance:
|
||||
return geo_service._cache.get(ip) # noqa: SLF001
|
||||
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=perf_db_path),
|
||||
):
|
||||
start = time.perf_counter()
|
||||
@@ -191,7 +191,7 @@ class TestBanServicePerformance:
|
||||
return geo_service._cache.get(ip) # noqa: SLF001
|
||||
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=perf_db_path),
|
||||
):
|
||||
start = time.perf_counter()
|
||||
@@ -217,7 +217,7 @@ class TestBanServicePerformance:
|
||||
return geo_service._cache.get(ip) # noqa: SLF001
|
||||
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=perf_db_path),
|
||||
):
|
||||
result = await ban_service.list_bans(
|
||||
@@ -241,7 +241,7 @@ class TestBanServicePerformance:
|
||||
return geo_service._cache.get(ip) # noqa: SLF001
|
||||
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=perf_db_path),
|
||||
):
|
||||
result = await ban_service.bans_by_country(
|
||||
|
||||
@@ -203,9 +203,15 @@ class TestImport:
|
||||
call_count += 1
|
||||
raise JailNotFoundError(jail)
|
||||
|
||||
with patch("app.services.jail_service.ban_ip", side_effect=_raise_jail_not_found):
|
||||
with patch("app.services.jail_service.ban_ip", side_effect=_raise_jail_not_found) as mocked_ban_ip:
|
||||
from app.services import jail_service
|
||||
|
||||
result = await blocklist_service.import_source(
|
||||
source, session, "/tmp/fake.sock", db
|
||||
source,
|
||||
session,
|
||||
"/tmp/fake.sock",
|
||||
db,
|
||||
ban_ip=jail_service.ban_ip,
|
||||
)
|
||||
|
||||
# Must abort after the first JailNotFoundError — only one ban attempt.
|
||||
@@ -226,7 +232,14 @@ class TestImport:
|
||||
with patch(
|
||||
"app.services.jail_service.ban_ip", new_callable=AsyncMock
|
||||
):
|
||||
result = await blocklist_service.import_all(db, session, "/tmp/fake.sock")
|
||||
from app.services import jail_service
|
||||
|
||||
result = await blocklist_service.import_all(
|
||||
db,
|
||||
session,
|
||||
"/tmp/fake.sock",
|
||||
ban_ip=jail_service.ban_ip,
|
||||
)
|
||||
|
||||
# Only S1 is enabled, S2 is disabled.
|
||||
assert len(result.results) == 1
|
||||
@@ -315,20 +328,15 @@ class TestGeoPrewarmCacheFilter:
|
||||
def _mock_is_cached(ip: str) -> bool:
|
||||
return ip == "1.2.3.4"
|
||||
|
||||
with (
|
||||
patch("app.services.jail_service.ban_ip", new_callable=AsyncMock),
|
||||
patch(
|
||||
"app.services.geo_service.is_cached",
|
||||
side_effect=_mock_is_cached,
|
||||
),
|
||||
patch(
|
||||
"app.services.geo_service.lookup_batch",
|
||||
new_callable=AsyncMock,
|
||||
return_value={},
|
||||
) as mock_batch,
|
||||
):
|
||||
mock_batch = AsyncMock(return_value={})
|
||||
with patch("app.services.jail_service.ban_ip", new_callable=AsyncMock):
|
||||
result = await blocklist_service.import_source(
|
||||
source, session, "/tmp/fake.sock", db
|
||||
source,
|
||||
session,
|
||||
"/tmp/fake.sock",
|
||||
db,
|
||||
geo_is_cached=_mock_is_cached,
|
||||
geo_batch_lookup=mock_batch,
|
||||
)
|
||||
|
||||
assert result.ips_imported == 3
|
||||
@@ -337,3 +345,40 @@ class TestGeoPrewarmCacheFilter:
|
||||
call_ips = mock_batch.call_args[0][0]
|
||||
assert "1.2.3.4" not in call_ips
|
||||
assert set(call_ips) == {"5.6.7.8", "9.10.11.12"}
|
||||
|
||||
|
||||
class TestImportLogPagination:
|
||||
async def test_list_import_logs_empty(self, db: aiosqlite.Connection) -> None:
|
||||
"""list_import_logs returns an empty page when no logs exist."""
|
||||
resp = await blocklist_service.list_import_logs(
|
||||
db, source_id=None, page=1, page_size=10
|
||||
)
|
||||
assert resp.items == []
|
||||
assert resp.total == 0
|
||||
assert resp.page == 1
|
||||
assert resp.page_size == 10
|
||||
assert resp.total_pages == 1
|
||||
|
||||
async def test_list_import_logs_paginates(self, db: aiosqlite.Connection) -> None:
|
||||
"""list_import_logs computes total pages and returns the correct subset."""
|
||||
from app.repositories import import_log_repo
|
||||
|
||||
for i in range(3):
|
||||
await import_log_repo.add_log(
|
||||
db,
|
||||
source_id=None,
|
||||
source_url=f"https://example{i}.test/ips.txt",
|
||||
ips_imported=1,
|
||||
ips_skipped=0,
|
||||
errors=None,
|
||||
)
|
||||
|
||||
resp = await blocklist_service.list_import_logs(
|
||||
db, source_id=None, page=2, page_size=2
|
||||
)
|
||||
assert resp.total == 3
|
||||
assert resp.total_pages == 2
|
||||
assert resp.page == 2
|
||||
assert resp.page_size == 2
|
||||
assert len(resp.items) == 1
|
||||
assert resp.items[0].source_url == "https://example0.test/ips.txt"
|
||||
|
||||
@@ -6,7 +6,7 @@ from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.conffile_parser import (
|
||||
from app.utils.conffile_parser import (
|
||||
merge_action_update,
|
||||
merge_filter_update,
|
||||
parse_action_file,
|
||||
@@ -451,7 +451,7 @@ class TestParseJailFile:
|
||||
"""Unit tests for parse_jail_file."""
|
||||
|
||||
def test_minimal_parses_correctly(self) -> None:
|
||||
from app.services.conffile_parser import parse_jail_file
|
||||
from app.utils.conffile_parser import parse_jail_file
|
||||
|
||||
cfg = parse_jail_file(MINIMAL_JAIL, filename="sshd.conf")
|
||||
assert cfg.filename == "sshd.conf"
|
||||
@@ -463,7 +463,7 @@ class TestParseJailFile:
|
||||
assert jail.logpath == ["/var/log/auth.log"]
|
||||
|
||||
def test_full_parses_multiple_jails(self) -> None:
|
||||
from app.services.conffile_parser import parse_jail_file
|
||||
from app.utils.conffile_parser import parse_jail_file
|
||||
|
||||
cfg = parse_jail_file(FULL_JAIL)
|
||||
assert len(cfg.jails) == 2
|
||||
@@ -471,7 +471,7 @@ class TestParseJailFile:
|
||||
assert "nginx-botsearch" in cfg.jails
|
||||
|
||||
def test_full_jail_numeric_fields(self) -> None:
|
||||
from app.services.conffile_parser import parse_jail_file
|
||||
from app.utils.conffile_parser import parse_jail_file
|
||||
|
||||
jail = parse_jail_file(FULL_JAIL).jails["sshd"]
|
||||
assert jail.maxretry == 3
|
||||
@@ -479,7 +479,7 @@ class TestParseJailFile:
|
||||
assert jail.bantime == 3600
|
||||
|
||||
def test_full_jail_multiline_logpath(self) -> None:
|
||||
from app.services.conffile_parser import parse_jail_file
|
||||
from app.utils.conffile_parser import parse_jail_file
|
||||
|
||||
jail = parse_jail_file(FULL_JAIL).jails["sshd"]
|
||||
assert len(jail.logpath) == 2
|
||||
@@ -487,53 +487,53 @@ class TestParseJailFile:
|
||||
assert "/var/log/syslog" in jail.logpath
|
||||
|
||||
def test_full_jail_multiline_action(self) -> None:
|
||||
from app.services.conffile_parser import parse_jail_file
|
||||
from app.utils.conffile_parser import parse_jail_file
|
||||
|
||||
jail = parse_jail_file(FULL_JAIL).jails["nginx-botsearch"]
|
||||
assert len(jail.action) == 2
|
||||
assert "sendmail-whois" in jail.action
|
||||
|
||||
def test_enabled_true(self) -> None:
|
||||
from app.services.conffile_parser import parse_jail_file
|
||||
from app.utils.conffile_parser import parse_jail_file
|
||||
|
||||
jail = parse_jail_file(FULL_JAIL).jails["sshd"]
|
||||
assert jail.enabled is True
|
||||
|
||||
def test_enabled_false(self) -> None:
|
||||
from app.services.conffile_parser import parse_jail_file
|
||||
from app.utils.conffile_parser import parse_jail_file
|
||||
|
||||
jail = parse_jail_file(FULL_JAIL).jails["nginx-botsearch"]
|
||||
assert jail.enabled is False
|
||||
|
||||
def test_extra_keys_captured(self) -> None:
|
||||
from app.services.conffile_parser import parse_jail_file
|
||||
from app.utils.conffile_parser import parse_jail_file
|
||||
|
||||
jail = parse_jail_file(JAIL_WITH_EXTRA).jails["sshd"]
|
||||
assert jail.extra["custom_key"] == "custom_value"
|
||||
assert jail.extra["another_key"] == "42"
|
||||
|
||||
def test_extra_keys_not_in_named_fields(self) -> None:
|
||||
from app.services.conffile_parser import parse_jail_file
|
||||
from app.utils.conffile_parser import parse_jail_file
|
||||
|
||||
jail = parse_jail_file(JAIL_WITH_EXTRA).jails["sshd"]
|
||||
assert "enabled" not in jail.extra
|
||||
assert "logpath" not in jail.extra
|
||||
|
||||
def test_empty_file_yields_no_jails(self) -> None:
|
||||
from app.services.conffile_parser import parse_jail_file
|
||||
from app.utils.conffile_parser import parse_jail_file
|
||||
|
||||
cfg = parse_jail_file("")
|
||||
assert cfg.jails == {}
|
||||
|
||||
def test_invalid_ini_does_not_raise(self) -> None:
|
||||
from app.services.conffile_parser import parse_jail_file
|
||||
from app.utils.conffile_parser import parse_jail_file
|
||||
|
||||
# Should not raise; just parse what it can.
|
||||
cfg = parse_jail_file("@@@ not valid ini @@@", filename="bad.conf")
|
||||
assert isinstance(cfg.jails, dict)
|
||||
|
||||
def test_default_section_ignored(self) -> None:
|
||||
from app.services.conffile_parser import parse_jail_file
|
||||
from app.utils.conffile_parser import parse_jail_file
|
||||
|
||||
content = "[DEFAULT]\nignoreip = 127.0.0.1\n\n[sshd]\nenabled = true\n"
|
||||
cfg = parse_jail_file(content)
|
||||
@@ -550,7 +550,7 @@ class TestJailFileRoundTrip:
|
||||
"""Tests that parse → serialize → parse preserves values."""
|
||||
|
||||
def test_minimal_round_trip(self) -> None:
|
||||
from app.services.conffile_parser import parse_jail_file, serialize_jail_file_config
|
||||
from app.utils.conffile_parser import parse_jail_file, serialize_jail_file_config
|
||||
|
||||
original = parse_jail_file(MINIMAL_JAIL, filename="sshd.conf")
|
||||
serialized = serialize_jail_file_config(original)
|
||||
@@ -560,7 +560,7 @@ class TestJailFileRoundTrip:
|
||||
assert restored.jails["sshd"].logpath == original.jails["sshd"].logpath
|
||||
|
||||
def test_full_round_trip(self) -> None:
|
||||
from app.services.conffile_parser import parse_jail_file, serialize_jail_file_config
|
||||
from app.utils.conffile_parser import parse_jail_file, serialize_jail_file_config
|
||||
|
||||
original = parse_jail_file(FULL_JAIL)
|
||||
serialized = serialize_jail_file_config(original)
|
||||
@@ -573,7 +573,7 @@ class TestJailFileRoundTrip:
|
||||
assert sorted(restored_jail.action) == sorted(jail.action)
|
||||
|
||||
def test_extra_keys_round_trip(self) -> None:
|
||||
from app.services.conffile_parser import parse_jail_file, serialize_jail_file_config
|
||||
from app.utils.conffile_parser import parse_jail_file, serialize_jail_file_config
|
||||
|
||||
original = parse_jail_file(JAIL_WITH_EXTRA)
|
||||
serialized = serialize_jail_file_config(original)
|
||||
@@ -591,7 +591,7 @@ class TestMergeJailFileUpdate:
|
||||
|
||||
def test_none_update_returns_original(self) -> None:
|
||||
from app.models.config import JailFileConfigUpdate
|
||||
from app.services.conffile_parser import merge_jail_file_update, parse_jail_file
|
||||
from app.utils.conffile_parser import merge_jail_file_update, parse_jail_file
|
||||
|
||||
cfg = parse_jail_file(FULL_JAIL)
|
||||
update = JailFileConfigUpdate()
|
||||
@@ -600,7 +600,7 @@ class TestMergeJailFileUpdate:
|
||||
|
||||
def test_update_replaces_jail(self) -> None:
|
||||
from app.models.config import JailFileConfigUpdate, JailSectionConfig
|
||||
from app.services.conffile_parser import merge_jail_file_update, parse_jail_file
|
||||
from app.utils.conffile_parser import merge_jail_file_update, parse_jail_file
|
||||
|
||||
cfg = parse_jail_file(FULL_JAIL)
|
||||
new_sshd = JailSectionConfig(enabled=False, port="2222")
|
||||
@@ -613,7 +613,7 @@ class TestMergeJailFileUpdate:
|
||||
|
||||
def test_update_adds_new_jail(self) -> None:
|
||||
from app.models.config import JailFileConfigUpdate, JailSectionConfig
|
||||
from app.services.conffile_parser import merge_jail_file_update, parse_jail_file
|
||||
from app.utils.conffile_parser import merge_jail_file_update, parse_jail_file
|
||||
|
||||
cfg = parse_jail_file(MINIMAL_JAIL)
|
||||
new_jail = JailSectionConfig(enabled=True, port="443")
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -2,6 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
@@ -256,6 +257,27 @@ class TestUpdateJailConfig:
|
||||
assert "bantime" in keys
|
||||
assert "maxretry" in keys
|
||||
|
||||
async def test_ignores_backend_field(self) -> None:
|
||||
"""update_jail_config does not send a set command for backend."""
|
||||
sent_commands: list[list[Any]] = []
|
||||
|
||||
async def _send(command: list[Any]) -> Any:
|
||||
sent_commands.append(command)
|
||||
return (0, "OK")
|
||||
|
||||
class _FakeClient:
|
||||
def __init__(self, **_kw: Any) -> None:
|
||||
self.send = AsyncMock(side_effect=_send)
|
||||
|
||||
from app.models.config import JailConfigUpdate
|
||||
|
||||
update = JailConfigUpdate(backend="polling")
|
||||
with patch("app.services.config_service.Fail2BanClient", _FakeClient):
|
||||
await config_service.update_jail_config(_SOCKET, "sshd", update)
|
||||
|
||||
keys = [cmd[2] for cmd in sent_commands if len(cmd) >= 3 and cmd[0] == "set"]
|
||||
assert "backend" not in keys
|
||||
|
||||
async def test_raises_validation_error_on_bad_regex(self) -> None:
|
||||
"""update_jail_config raises ConfigValidationError for invalid regex."""
|
||||
from app.models.config import JailConfigUpdate
|
||||
@@ -604,3 +626,210 @@ class TestPreviewLog:
|
||||
result = await config_service.preview_log(req)
|
||||
|
||||
assert result.total_lines <= 50
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# read_fail2ban_log
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestReadFail2BanLog:
|
||||
"""Tests for :func:`config_service.read_fail2ban_log`."""
|
||||
|
||||
def _patch_client(self, log_level: str = "INFO", log_target: str = "/var/log/fail2ban.log") -> Any:
|
||||
"""Build a patched Fail2BanClient that returns *log_level* and *log_target*."""
|
||||
async def _send(command: list[Any]) -> Any:
|
||||
key = "|".join(str(c) for c in command)
|
||||
if key == "get|loglevel":
|
||||
return (0, log_level)
|
||||
if key == "get|logtarget":
|
||||
return (0, log_target)
|
||||
return (0, None)
|
||||
|
||||
class _FakeClient:
|
||||
def __init__(self, **_kw: Any) -> None:
|
||||
self.send = AsyncMock(side_effect=_send)
|
||||
|
||||
return patch("app.services.config_service.Fail2BanClient", _FakeClient)
|
||||
|
||||
async def test_returns_log_lines_from_file(self, tmp_path: Any) -> None:
|
||||
"""read_fail2ban_log returns lines from the file and counts totals."""
|
||||
log_file = tmp_path / "fail2ban.log"
|
||||
log_file.write_text("line1\nline2\nline3\n")
|
||||
log_dir = str(tmp_path)
|
||||
|
||||
# Patch _SAFE_LOG_PREFIXES to allow tmp_path
|
||||
with self._patch_client(log_target=str(log_file)), \
|
||||
patch("app.services.config_service._SAFE_LOG_PREFIXES", (log_dir,)):
|
||||
result = await config_service.read_fail2ban_log(_SOCKET, 200)
|
||||
|
||||
assert result.log_path == str(log_file.resolve())
|
||||
assert result.total_lines >= 3
|
||||
assert any("line1" in ln for ln in result.lines)
|
||||
assert result.log_level == "INFO"
|
||||
|
||||
async def test_filter_narrows_returned_lines(self, tmp_path: Any) -> None:
|
||||
"""read_fail2ban_log filters lines by substring."""
|
||||
log_file = tmp_path / "fail2ban.log"
|
||||
log_file.write_text("INFO sshd Found 1.2.3.4\nERROR something else\nINFO sshd Found 5.6.7.8\n")
|
||||
log_dir = str(tmp_path)
|
||||
|
||||
with self._patch_client(log_target=str(log_file)), \
|
||||
patch("app.services.config_service._SAFE_LOG_PREFIXES", (log_dir,)):
|
||||
result = await config_service.read_fail2ban_log(_SOCKET, 200, "Found")
|
||||
|
||||
assert all("Found" in ln for ln in result.lines)
|
||||
assert result.total_lines >= 3 # total is unfiltered
|
||||
|
||||
async def test_non_file_target_raises_operation_error(self) -> None:
|
||||
"""read_fail2ban_log raises ConfigOperationError for STDOUT target."""
|
||||
with self._patch_client(log_target="STDOUT"), \
|
||||
pytest.raises(config_service.ConfigOperationError, match="STDOUT"):
|
||||
await config_service.read_fail2ban_log(_SOCKET, 200)
|
||||
|
||||
async def test_syslog_target_raises_operation_error(self) -> None:
|
||||
"""read_fail2ban_log raises ConfigOperationError for SYSLOG target."""
|
||||
with self._patch_client(log_target="SYSLOG"), \
|
||||
pytest.raises(config_service.ConfigOperationError, match="SYSLOG"):
|
||||
await config_service.read_fail2ban_log(_SOCKET, 200)
|
||||
|
||||
async def test_path_outside_safe_dir_raises_operation_error(self, tmp_path: Any) -> None:
|
||||
"""read_fail2ban_log rejects a log_target outside allowed directories."""
|
||||
log_file = tmp_path / "secret.log"
|
||||
log_file.write_text("secret data\n")
|
||||
|
||||
# Allow only /var/log — tmp_path is deliberately not in the safe list.
|
||||
with self._patch_client(log_target=str(log_file)), \
|
||||
patch("app.services.config_service._SAFE_LOG_PREFIXES", ("/var/log",)), \
|
||||
pytest.raises(config_service.ConfigOperationError, match="outside the allowed"):
|
||||
await config_service.read_fail2ban_log(_SOCKET, 200)
|
||||
|
||||
async def test_missing_log_file_raises_operation_error(self, tmp_path: Any) -> None:
|
||||
"""read_fail2ban_log raises ConfigOperationError when the file does not exist."""
|
||||
missing = str(tmp_path / "nonexistent.log")
|
||||
log_dir = str(tmp_path)
|
||||
|
||||
with self._patch_client(log_target=missing), \
|
||||
patch("app.services.config_service._SAFE_LOG_PREFIXES", (log_dir,)), \
|
||||
pytest.raises(config_service.ConfigOperationError, match="not found"):
|
||||
await config_service.read_fail2ban_log(_SOCKET, 200)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# get_service_status
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestGetServiceStatus:
|
||||
"""Tests for :func:`config_service.get_service_status`."""
|
||||
|
||||
async def test_online_status_includes_log_config(self) -> None:
|
||||
"""get_service_status returns correct fields when fail2ban is online."""
|
||||
from app.models.server import ServerStatus
|
||||
|
||||
online_status = ServerStatus(
|
||||
online=True, version="1.0.0", active_jails=2, total_bans=5, total_failures=3
|
||||
)
|
||||
|
||||
async def _send(command: list[Any]) -> Any:
|
||||
key = "|".join(str(c) for c in command)
|
||||
if key == "get|loglevel":
|
||||
return (0, "DEBUG")
|
||||
if key == "get|logtarget":
|
||||
return (0, "/var/log/fail2ban.log")
|
||||
return (0, None)
|
||||
|
||||
class _FakeClient:
|
||||
def __init__(self, **_kw: Any) -> None:
|
||||
self.send = AsyncMock(side_effect=_send)
|
||||
|
||||
with patch("app.services.config_service.Fail2BanClient", _FakeClient):
|
||||
result = await config_service.get_service_status(
|
||||
_SOCKET,
|
||||
probe_fn=AsyncMock(return_value=online_status),
|
||||
)
|
||||
|
||||
from app import __version__
|
||||
|
||||
assert result.online is True
|
||||
assert result.version == __version__
|
||||
assert result.jail_count == 2
|
||||
assert result.total_bans == 5
|
||||
assert result.total_failures == 3
|
||||
assert result.log_level == "DEBUG"
|
||||
assert result.log_target == "/var/log/fail2ban.log"
|
||||
|
||||
async def test_offline_status_returns_unknown_log_fields(self) -> None:
|
||||
"""get_service_status returns 'UNKNOWN' log fields when fail2ban is offline."""
|
||||
from app.models.server import ServerStatus
|
||||
|
||||
offline_status = ServerStatus(online=False)
|
||||
|
||||
result = await config_service.get_service_status(
|
||||
_SOCKET,
|
||||
probe_fn=AsyncMock(return_value=offline_status),
|
||||
)
|
||||
|
||||
assert result.online is False
|
||||
assert result.jail_count == 0
|
||||
assert result.log_level == "UNKNOWN"
|
||||
assert result.log_target == "UNKNOWN"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestConfigModuleIntegration:
|
||||
async def test_jail_config_service_list_inactive_jails_uses_imports(self, tmp_path: Any) -> None:
|
||||
from app.services.jail_config_service import list_inactive_jails
|
||||
|
||||
# Arrange: fake parse_jails output with one active and one inactive
|
||||
def fake_parse_jails_sync(path: Path) -> tuple[dict[str, dict[str, str]], dict[str, str]]:
|
||||
return (
|
||||
{
|
||||
"sshd": {
|
||||
"enabled": "true",
|
||||
"filter": "sshd",
|
||||
"logpath": "/var/log/auth.log",
|
||||
},
|
||||
"apache-auth": {
|
||||
"enabled": "false",
|
||||
"filter": "apache-auth",
|
||||
"logpath": "/var/log/apache2/error.log",
|
||||
},
|
||||
},
|
||||
{
|
||||
"sshd": str(path / "jail.conf"),
|
||||
"apache-auth": str(path / "jail.conf"),
|
||||
},
|
||||
)
|
||||
|
||||
with patch(
|
||||
"app.services.jail_config_service._parse_jails_sync",
|
||||
new=fake_parse_jails_sync,
|
||||
), patch(
|
||||
"app.services.jail_config_service._get_active_jail_names",
|
||||
new=AsyncMock(return_value={"sshd"}),
|
||||
):
|
||||
result = await list_inactive_jails(str(tmp_path), "/fake.sock")
|
||||
|
||||
names = {j.name for j in result.jails}
|
||||
assert "apache-auth" in names
|
||||
assert "sshd" not in names
|
||||
|
||||
async def test_filter_config_service_list_filters_uses_imports(self, tmp_path: Any) -> None:
|
||||
from app.services.filter_config_service import list_filters
|
||||
|
||||
# Arrange minimal filter and jail config files
|
||||
filter_d = tmp_path / "filter.d"
|
||||
filter_d.mkdir(parents=True)
|
||||
(filter_d / "sshd.conf").write_text("[Definition]\nfailregex = ^%(__prefix_line)s.*$\n")
|
||||
(tmp_path / "jail.conf").write_text("[sshd]\nfilter = sshd\nenabled = true\n")
|
||||
|
||||
with patch(
|
||||
"app.services.filter_config_service._get_active_jail_names",
|
||||
new=AsyncMock(return_value={"sshd"}),
|
||||
):
|
||||
result = await list_filters(str(tmp_path), "/fake.sock")
|
||||
|
||||
assert result.total == 1
|
||||
assert result.filters[0].name == "sshd"
|
||||
assert result.filters[0].active is True
|
||||
|
||||
@@ -8,7 +8,7 @@ import pytest
|
||||
|
||||
from app.models.config import ActionConfigUpdate, FilterConfigUpdate, JailFileConfigUpdate
|
||||
from app.models.file_config import ConfFileCreateRequest, ConfFileUpdateRequest
|
||||
from app.services.file_config_service import (
|
||||
from app.services.raw_config_io_service import (
|
||||
ConfigDirError,
|
||||
ConfigFileExistsError,
|
||||
ConfigFileNameError,
|
||||
|
||||
@@ -2,12 +2,13 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Mapping, Sequence
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.models.geo import GeoInfo
|
||||
from app.services import geo_service
|
||||
from app.services.geo_service import GeoInfo
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
@@ -44,7 +45,7 @@ def _make_session(response_json: dict[str, object], status: int = 200) -> MagicM
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def clear_geo_cache() -> None: # type: ignore[misc]
|
||||
def clear_geo_cache() -> None:
|
||||
"""Flush the module-level geo cache before every test."""
|
||||
geo_service.clear_cache()
|
||||
|
||||
@@ -68,7 +69,7 @@ class TestLookupSuccess:
|
||||
"org": "AS3320 Deutsche Telekom AG",
|
||||
}
|
||||
)
|
||||
result = await geo_service.lookup("1.2.3.4", session) # type: ignore[arg-type]
|
||||
result = await geo_service.lookup("1.2.3.4", session)
|
||||
|
||||
assert result is not None
|
||||
assert result.country_code == "DE"
|
||||
@@ -84,7 +85,7 @@ class TestLookupSuccess:
|
||||
"org": "Google LLC",
|
||||
}
|
||||
)
|
||||
result = await geo_service.lookup("8.8.8.8", session) # type: ignore[arg-type]
|
||||
result = await geo_service.lookup("8.8.8.8", session)
|
||||
|
||||
assert result is not None
|
||||
assert result.country_name == "United States"
|
||||
@@ -100,7 +101,7 @@ class TestLookupSuccess:
|
||||
"org": "Deutsche Telekom",
|
||||
}
|
||||
)
|
||||
result = await geo_service.lookup("1.2.3.4", session) # type: ignore[arg-type]
|
||||
result = await geo_service.lookup("1.2.3.4", session)
|
||||
|
||||
assert result is not None
|
||||
assert result.asn == "AS3320"
|
||||
@@ -116,7 +117,7 @@ class TestLookupSuccess:
|
||||
"org": "Google LLC",
|
||||
}
|
||||
)
|
||||
result = await geo_service.lookup("8.8.8.8", session) # type: ignore[arg-type]
|
||||
result = await geo_service.lookup("8.8.8.8", session)
|
||||
|
||||
assert result is not None
|
||||
assert result.org == "Google LLC"
|
||||
@@ -142,8 +143,8 @@ class TestLookupCaching:
|
||||
}
|
||||
)
|
||||
|
||||
await geo_service.lookup("1.2.3.4", session) # type: ignore[arg-type]
|
||||
await geo_service.lookup("1.2.3.4", session) # type: ignore[arg-type]
|
||||
await geo_service.lookup("1.2.3.4", session)
|
||||
await geo_service.lookup("1.2.3.4", session)
|
||||
|
||||
# The session.get() should only have been called once.
|
||||
assert session.get.call_count == 1
|
||||
@@ -160,9 +161,9 @@ class TestLookupCaching:
|
||||
}
|
||||
)
|
||||
|
||||
await geo_service.lookup("2.3.4.5", session) # type: ignore[arg-type]
|
||||
await geo_service.lookup("2.3.4.5", session)
|
||||
geo_service.clear_cache()
|
||||
await geo_service.lookup("2.3.4.5", session) # type: ignore[arg-type]
|
||||
await geo_service.lookup("2.3.4.5", session)
|
||||
|
||||
assert session.get.call_count == 2
|
||||
|
||||
@@ -172,8 +173,8 @@ class TestLookupCaching:
|
||||
{"status": "fail", "message": "reserved range"}
|
||||
)
|
||||
|
||||
await geo_service.lookup("192.168.1.1", session) # type: ignore[arg-type]
|
||||
await geo_service.lookup("192.168.1.1", session) # type: ignore[arg-type]
|
||||
await geo_service.lookup("192.168.1.1", session)
|
||||
await geo_service.lookup("192.168.1.1", session)
|
||||
|
||||
# Second call is blocked by the negative cache — only one API hit.
|
||||
assert session.get.call_count == 1
|
||||
@@ -190,7 +191,7 @@ class TestLookupFailures:
|
||||
async def test_non_200_response_returns_null_geo_info(self) -> None:
|
||||
"""A 429 or 500 status returns GeoInfo with null fields (not None)."""
|
||||
session = _make_session({}, status=429)
|
||||
result = await geo_service.lookup("1.2.3.4", session) # type: ignore[arg-type]
|
||||
result = await geo_service.lookup("1.2.3.4", session)
|
||||
assert result is not None
|
||||
assert isinstance(result, GeoInfo)
|
||||
assert result.country_code is None
|
||||
@@ -203,7 +204,7 @@ class TestLookupFailures:
|
||||
mock_ctx.__aexit__ = AsyncMock(return_value=False)
|
||||
session.get = MagicMock(return_value=mock_ctx)
|
||||
|
||||
result = await geo_service.lookup("10.0.0.1", session) # type: ignore[arg-type]
|
||||
result = await geo_service.lookup("10.0.0.1", session)
|
||||
assert result is not None
|
||||
assert isinstance(result, GeoInfo)
|
||||
assert result.country_code is None
|
||||
@@ -211,7 +212,7 @@ class TestLookupFailures:
|
||||
async def test_failed_status_returns_geo_info_with_nulls(self) -> None:
|
||||
"""When ip-api returns ``status=fail`` a GeoInfo with null fields is returned (but not cached)."""
|
||||
session = _make_session({"status": "fail", "message": "private range"})
|
||||
result = await geo_service.lookup("10.0.0.1", session) # type: ignore[arg-type]
|
||||
result = await geo_service.lookup("10.0.0.1", session)
|
||||
|
||||
assert result is not None
|
||||
assert isinstance(result, GeoInfo)
|
||||
@@ -231,8 +232,8 @@ class TestNegativeCache:
|
||||
"""After a failed lookup the second call is served from the neg cache."""
|
||||
session = _make_session({"status": "fail", "message": "private range"})
|
||||
|
||||
r1 = await geo_service.lookup("192.0.2.1", session) # type: ignore[arg-type]
|
||||
r2 = await geo_service.lookup("192.0.2.1", session) # type: ignore[arg-type]
|
||||
r1 = await geo_service.lookup("192.0.2.1", session)
|
||||
r2 = await geo_service.lookup("192.0.2.1", session)
|
||||
|
||||
# Only one HTTP call should have been made; second served from neg cache.
|
||||
assert session.get.call_count == 1
|
||||
@@ -243,12 +244,12 @@ class TestNegativeCache:
|
||||
"""When the neg-cache entry is older than the TTL a new API call is made."""
|
||||
session = _make_session({"status": "fail", "message": "private range"})
|
||||
|
||||
await geo_service.lookup("192.0.2.2", session) # type: ignore[arg-type]
|
||||
await geo_service.lookup("192.0.2.2", session)
|
||||
|
||||
# Manually expire the neg-cache entry.
|
||||
geo_service._neg_cache["192.0.2.2"] -= geo_service._NEG_CACHE_TTL + 1 # type: ignore[attr-defined]
|
||||
geo_service._neg_cache["192.0.2.2"] -= geo_service._NEG_CACHE_TTL + 1
|
||||
|
||||
await geo_service.lookup("192.0.2.2", session) # type: ignore[arg-type]
|
||||
await geo_service.lookup("192.0.2.2", session)
|
||||
|
||||
# Both calls should have hit the API.
|
||||
assert session.get.call_count == 2
|
||||
@@ -257,9 +258,9 @@ class TestNegativeCache:
|
||||
"""After clearing the neg cache the IP is eligible for a new API call."""
|
||||
session = _make_session({"status": "fail", "message": "private range"})
|
||||
|
||||
await geo_service.lookup("192.0.2.3", session) # type: ignore[arg-type]
|
||||
await geo_service.lookup("192.0.2.3", session)
|
||||
geo_service.clear_neg_cache()
|
||||
await geo_service.lookup("192.0.2.3", session) # type: ignore[arg-type]
|
||||
await geo_service.lookup("192.0.2.3", session)
|
||||
|
||||
assert session.get.call_count == 2
|
||||
|
||||
@@ -275,9 +276,9 @@ class TestNegativeCache:
|
||||
}
|
||||
)
|
||||
|
||||
await geo_service.lookup("1.2.3.4", session) # type: ignore[arg-type]
|
||||
await geo_service.lookup("1.2.3.4", session)
|
||||
|
||||
assert "1.2.3.4" not in geo_service._neg_cache # type: ignore[attr-defined]
|
||||
assert "1.2.3.4" not in geo_service._neg_cache
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -307,7 +308,7 @@ class TestGeoipFallback:
|
||||
mock_reader = self._make_geoip_reader("DE", "Germany")
|
||||
|
||||
with patch.object(geo_service, "_geoip_reader", mock_reader):
|
||||
result = await geo_service.lookup("1.2.3.4", session) # type: ignore[arg-type]
|
||||
result = await geo_service.lookup("1.2.3.4", session)
|
||||
|
||||
mock_reader.country.assert_called_once_with("1.2.3.4")
|
||||
assert result is not None
|
||||
@@ -320,12 +321,12 @@ class TestGeoipFallback:
|
||||
mock_reader = self._make_geoip_reader("US", "United States")
|
||||
|
||||
with patch.object(geo_service, "_geoip_reader", mock_reader):
|
||||
await geo_service.lookup("8.8.8.8", session) # type: ignore[arg-type]
|
||||
await geo_service.lookup("8.8.8.8", session)
|
||||
# Second call must be served from positive cache without hitting API.
|
||||
await geo_service.lookup("8.8.8.8", session) # type: ignore[arg-type]
|
||||
await geo_service.lookup("8.8.8.8", session)
|
||||
|
||||
assert session.get.call_count == 1
|
||||
assert "8.8.8.8" in geo_service._cache # type: ignore[attr-defined]
|
||||
assert "8.8.8.8" in geo_service._cache
|
||||
|
||||
async def test_geoip_fallback_not_called_on_api_success(self) -> None:
|
||||
"""When ip-api succeeds, the geoip2 reader must not be consulted."""
|
||||
@@ -341,7 +342,7 @@ class TestGeoipFallback:
|
||||
mock_reader = self._make_geoip_reader("XX", "Nowhere")
|
||||
|
||||
with patch.object(geo_service, "_geoip_reader", mock_reader):
|
||||
result = await geo_service.lookup("1.2.3.4", session) # type: ignore[arg-type]
|
||||
result = await geo_service.lookup("1.2.3.4", session)
|
||||
|
||||
mock_reader.country.assert_not_called()
|
||||
assert result is not None
|
||||
@@ -352,7 +353,7 @@ class TestGeoipFallback:
|
||||
session = _make_session({"status": "fail", "message": "private range"})
|
||||
|
||||
with patch.object(geo_service, "_geoip_reader", None):
|
||||
result = await geo_service.lookup("10.0.0.1", session) # type: ignore[arg-type]
|
||||
result = await geo_service.lookup("10.0.0.1", session)
|
||||
|
||||
assert result is not None
|
||||
assert result.country_code is None
|
||||
@@ -363,7 +364,7 @@ class TestGeoipFallback:
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_batch_session(batch_response: list[dict[str, object]]) -> MagicMock:
|
||||
def _make_batch_session(batch_response: Sequence[Mapping[str, object]]) -> MagicMock:
|
||||
"""Build a mock aiohttp.ClientSession for batch POST calls.
|
||||
|
||||
Args:
|
||||
@@ -412,7 +413,7 @@ class TestLookupBatchSingleCommit:
|
||||
session = _make_batch_session(batch_response)
|
||||
db = _make_async_db()
|
||||
|
||||
await geo_service.lookup_batch(ips, session, db=db) # type: ignore[arg-type]
|
||||
await geo_service.lookup_batch(ips, session, db=db)
|
||||
|
||||
db.commit.assert_awaited_once()
|
||||
|
||||
@@ -426,7 +427,7 @@ class TestLookupBatchSingleCommit:
|
||||
session = _make_batch_session(batch_response)
|
||||
db = _make_async_db()
|
||||
|
||||
await geo_service.lookup_batch(ips, session, db=db) # type: ignore[arg-type]
|
||||
await geo_service.lookup_batch(ips, session, db=db)
|
||||
|
||||
db.commit.assert_awaited_once()
|
||||
|
||||
@@ -452,13 +453,13 @@ class TestLookupBatchSingleCommit:
|
||||
|
||||
async def test_no_commit_for_all_cached_ips(self) -> None:
|
||||
"""When all IPs are already cached, no HTTP call and no commit occur."""
|
||||
geo_service._cache["5.5.5.5"] = GeoInfo( # type: ignore[attr-defined]
|
||||
geo_service._cache["5.5.5.5"] = GeoInfo(
|
||||
country_code="FR", country_name="France", asn="AS1", org="ISP"
|
||||
)
|
||||
db = _make_async_db()
|
||||
session = _make_batch_session([])
|
||||
|
||||
result = await geo_service.lookup_batch(["5.5.5.5"], session, db=db) # type: ignore[arg-type]
|
||||
result = await geo_service.lookup_batch(["5.5.5.5"], session, db=db)
|
||||
|
||||
assert result["5.5.5.5"].country_code == "FR"
|
||||
db.commit.assert_not_awaited()
|
||||
@@ -476,26 +477,26 @@ class TestDirtySetTracking:
|
||||
def test_successful_resolution_adds_to_dirty(self) -> None:
|
||||
"""Storing a GeoInfo with a country_code adds the IP to _dirty."""
|
||||
info = GeoInfo(country_code="DE", country_name="Germany", asn="AS1", org="ISP")
|
||||
geo_service._store("1.2.3.4", info) # type: ignore[attr-defined]
|
||||
geo_service._store("1.2.3.4", info)
|
||||
|
||||
assert "1.2.3.4" in geo_service._dirty # type: ignore[attr-defined]
|
||||
assert "1.2.3.4" in geo_service._dirty
|
||||
|
||||
def test_null_country_does_not_add_to_dirty(self) -> None:
|
||||
"""Storing a GeoInfo with country_code=None must not pollute _dirty."""
|
||||
info = GeoInfo(country_code=None, country_name=None, asn=None, org=None)
|
||||
geo_service._store("10.0.0.1", info) # type: ignore[attr-defined]
|
||||
geo_service._store("10.0.0.1", info)
|
||||
|
||||
assert "10.0.0.1" not in geo_service._dirty # type: ignore[attr-defined]
|
||||
assert "10.0.0.1" not in geo_service._dirty
|
||||
|
||||
def test_clear_cache_also_clears_dirty(self) -> None:
|
||||
"""clear_cache() must discard any pending dirty entries."""
|
||||
info = GeoInfo(country_code="US", country_name="United States", asn="AS1", org="ISP")
|
||||
geo_service._store("8.8.8.8", info) # type: ignore[attr-defined]
|
||||
assert geo_service._dirty # type: ignore[attr-defined]
|
||||
geo_service._store("8.8.8.8", info)
|
||||
assert geo_service._dirty
|
||||
|
||||
geo_service.clear_cache()
|
||||
|
||||
assert not geo_service._dirty # type: ignore[attr-defined]
|
||||
assert not geo_service._dirty
|
||||
|
||||
async def test_lookup_batch_populates_dirty(self) -> None:
|
||||
"""After lookup_batch() with db=None, resolved IPs appear in _dirty."""
|
||||
@@ -509,7 +510,7 @@ class TestDirtySetTracking:
|
||||
await geo_service.lookup_batch(ips, session, db=None)
|
||||
|
||||
for ip in ips:
|
||||
assert ip in geo_service._dirty # type: ignore[attr-defined]
|
||||
assert ip in geo_service._dirty
|
||||
|
||||
|
||||
class TestFlushDirty:
|
||||
@@ -518,8 +519,8 @@ class TestFlushDirty:
|
||||
async def test_flush_writes_and_clears_dirty(self) -> None:
|
||||
"""flush_dirty() inserts all dirty IPs and clears _dirty afterwards."""
|
||||
info = GeoInfo(country_code="GB", country_name="United Kingdom", asn="AS2856", org="BT")
|
||||
geo_service._store("100.0.0.1", info) # type: ignore[attr-defined]
|
||||
assert "100.0.0.1" in geo_service._dirty # type: ignore[attr-defined]
|
||||
geo_service._store("100.0.0.1", info)
|
||||
assert "100.0.0.1" in geo_service._dirty
|
||||
|
||||
db = _make_async_db()
|
||||
count = await geo_service.flush_dirty(db)
|
||||
@@ -527,7 +528,7 @@ class TestFlushDirty:
|
||||
assert count == 1
|
||||
db.executemany.assert_awaited_once()
|
||||
db.commit.assert_awaited_once()
|
||||
assert "100.0.0.1" not in geo_service._dirty # type: ignore[attr-defined]
|
||||
assert "100.0.0.1" not in geo_service._dirty
|
||||
|
||||
async def test_flush_returns_zero_when_nothing_dirty(self) -> None:
|
||||
"""flush_dirty() returns 0 and makes no DB calls when _dirty is empty."""
|
||||
@@ -541,7 +542,7 @@ class TestFlushDirty:
|
||||
async def test_flush_re_adds_to_dirty_on_db_error(self) -> None:
|
||||
"""When the DB write fails, entries are re-added to _dirty for retry."""
|
||||
info = GeoInfo(country_code="AU", country_name="Australia", asn="AS1", org="ISP")
|
||||
geo_service._store("200.0.0.1", info) # type: ignore[attr-defined]
|
||||
geo_service._store("200.0.0.1", info)
|
||||
|
||||
db = _make_async_db()
|
||||
db.executemany = AsyncMock(side_effect=OSError("disk full"))
|
||||
@@ -549,7 +550,7 @@ class TestFlushDirty:
|
||||
count = await geo_service.flush_dirty(db)
|
||||
|
||||
assert count == 0
|
||||
assert "200.0.0.1" in geo_service._dirty # type: ignore[attr-defined]
|
||||
assert "200.0.0.1" in geo_service._dirty
|
||||
|
||||
async def test_flush_batch_and_lookup_batch_integration(self) -> None:
|
||||
"""lookup_batch() populates _dirty; flush_dirty() then persists them."""
|
||||
@@ -562,14 +563,14 @@ class TestFlushDirty:
|
||||
|
||||
# Resolve without DB to populate only in-memory cache and _dirty.
|
||||
await geo_service.lookup_batch(ips, session, db=None)
|
||||
assert geo_service._dirty == set(ips) # type: ignore[attr-defined]
|
||||
assert geo_service._dirty == set(ips)
|
||||
|
||||
# Now flush to the DB.
|
||||
db = _make_async_db()
|
||||
count = await geo_service.flush_dirty(db)
|
||||
|
||||
assert count == 2
|
||||
assert not geo_service._dirty # type: ignore[attr-defined]
|
||||
assert not geo_service._dirty
|
||||
db.commit.assert_awaited_once()
|
||||
|
||||
|
||||
@@ -585,7 +586,7 @@ class TestLookupBatchThrottling:
|
||||
"""When more than _BATCH_SIZE IPs are sent, asyncio.sleep is called
|
||||
between consecutive batch HTTP calls with at least _BATCH_DELAY."""
|
||||
# Generate _BATCH_SIZE + 1 IPs so we get exactly 2 batch calls.
|
||||
batch_size: int = geo_service._BATCH_SIZE # type: ignore[attr-defined]
|
||||
batch_size: int = geo_service._BATCH_SIZE
|
||||
ips = [f"10.0.{i // 256}.{i % 256}" for i in range(batch_size + 1)]
|
||||
|
||||
def _make_result(chunk: list[str], _session: object) -> dict[str, GeoInfo]:
|
||||
@@ -608,7 +609,7 @@ class TestLookupBatchThrottling:
|
||||
assert mock_batch.call_count == 2
|
||||
mock_sleep.assert_awaited_once()
|
||||
delay_arg: float = mock_sleep.call_args[0][0]
|
||||
assert delay_arg >= geo_service._BATCH_DELAY # type: ignore[attr-defined]
|
||||
assert delay_arg >= geo_service._BATCH_DELAY
|
||||
|
||||
async def test_lookup_batch_retries_on_full_chunk_failure(self) -> None:
|
||||
"""When a chunk returns all-None on first try, it retries and succeeds."""
|
||||
@@ -650,7 +651,7 @@ class TestLookupBatchThrottling:
|
||||
_empty = GeoInfo(country_code=None, country_name=None, asn=None, org=None)
|
||||
_failure: dict[str, GeoInfo] = dict.fromkeys(ips, _empty)
|
||||
|
||||
max_retries: int = geo_service._BATCH_MAX_RETRIES # type: ignore[attr-defined]
|
||||
max_retries: int = geo_service._BATCH_MAX_RETRIES
|
||||
|
||||
with (
|
||||
patch(
|
||||
@@ -667,11 +668,11 @@ class TestLookupBatchThrottling:
|
||||
# IP should have no country.
|
||||
assert result["9.9.9.9"].country_code is None
|
||||
# Negative cache should contain the IP.
|
||||
assert "9.9.9.9" in geo_service._neg_cache # type: ignore[attr-defined]
|
||||
assert "9.9.9.9" in geo_service._neg_cache
|
||||
# Sleep called for each retry with exponential backoff.
|
||||
assert mock_sleep.call_count == max_retries
|
||||
backoff_values = [call.args[0] for call in mock_sleep.call_args_list]
|
||||
batch_delay: float = geo_service._BATCH_DELAY # type: ignore[attr-defined]
|
||||
batch_delay: float = geo_service._BATCH_DELAY
|
||||
for i, val in enumerate(backoff_values):
|
||||
expected = batch_delay * (2 ** (i + 1))
|
||||
assert val == pytest.approx(expected)
|
||||
@@ -709,7 +710,7 @@ class TestErrorLogging:
|
||||
import structlog.testing
|
||||
|
||||
with structlog.testing.capture_logs() as captured:
|
||||
result = await geo_service.lookup("197.221.98.153", session) # type: ignore[arg-type]
|
||||
result = await geo_service.lookup("197.221.98.153", session)
|
||||
|
||||
assert result is not None
|
||||
assert result.country_code is None
|
||||
@@ -733,7 +734,7 @@ class TestErrorLogging:
|
||||
import structlog.testing
|
||||
|
||||
with structlog.testing.capture_logs() as captured:
|
||||
await geo_service.lookup("10.0.0.1", session) # type: ignore[arg-type]
|
||||
await geo_service.lookup("10.0.0.1", session)
|
||||
|
||||
request_failed = [e for e in captured if e.get("event") == "geo_lookup_request_failed"]
|
||||
assert len(request_failed) == 1
|
||||
@@ -757,7 +758,7 @@ class TestErrorLogging:
|
||||
import structlog.testing
|
||||
|
||||
with structlog.testing.capture_logs() as captured:
|
||||
result = await geo_service._batch_api_call(["1.2.3.4"], session) # type: ignore[attr-defined]
|
||||
result = await geo_service._batch_api_call(["1.2.3.4"], session)
|
||||
|
||||
assert result["1.2.3.4"].country_code is None
|
||||
|
||||
@@ -778,7 +779,7 @@ class TestLookupCachedOnly:
|
||||
|
||||
def test_returns_cached_ips(self) -> None:
|
||||
"""IPs already in the cache are returned in the geo_map."""
|
||||
geo_service._cache["1.1.1.1"] = GeoInfo( # type: ignore[attr-defined]
|
||||
geo_service._cache["1.1.1.1"] = GeoInfo(
|
||||
country_code="AU", country_name="Australia", asn="AS13335", org="Cloudflare"
|
||||
)
|
||||
geo_map, uncached = geo_service.lookup_cached_only(["1.1.1.1"])
|
||||
@@ -798,7 +799,7 @@ class TestLookupCachedOnly:
|
||||
"""IPs in the negative cache within TTL are not re-queued as uncached."""
|
||||
import time
|
||||
|
||||
geo_service._neg_cache["10.0.0.1"] = time.monotonic() # type: ignore[attr-defined]
|
||||
geo_service._neg_cache["10.0.0.1"] = time.monotonic()
|
||||
|
||||
geo_map, uncached = geo_service.lookup_cached_only(["10.0.0.1"])
|
||||
|
||||
@@ -807,7 +808,7 @@ class TestLookupCachedOnly:
|
||||
|
||||
def test_expired_neg_cache_requeued(self) -> None:
|
||||
"""IPs whose neg-cache entry has expired are listed as uncached."""
|
||||
geo_service._neg_cache["10.0.0.2"] = 0.0 # epoch 0 → expired # type: ignore[attr-defined]
|
||||
geo_service._neg_cache["10.0.0.2"] = 0.0 # epoch 0 → expired
|
||||
|
||||
_geo_map, uncached = geo_service.lookup_cached_only(["10.0.0.2"])
|
||||
|
||||
@@ -815,12 +816,12 @@ class TestLookupCachedOnly:
|
||||
|
||||
def test_mixed_ips(self) -> None:
|
||||
"""A mix of cached, neg-cached, and unknown IPs is split correctly."""
|
||||
geo_service._cache["1.2.3.4"] = GeoInfo( # type: ignore[attr-defined]
|
||||
geo_service._cache["1.2.3.4"] = GeoInfo(
|
||||
country_code="DE", country_name="Germany", asn=None, org=None
|
||||
)
|
||||
import time
|
||||
|
||||
geo_service._neg_cache["5.5.5.5"] = time.monotonic() # type: ignore[attr-defined]
|
||||
geo_service._neg_cache["5.5.5.5"] = time.monotonic()
|
||||
|
||||
geo_map, uncached = geo_service.lookup_cached_only(["1.2.3.4", "5.5.5.5", "9.9.9.9"])
|
||||
|
||||
@@ -829,7 +830,7 @@ class TestLookupCachedOnly:
|
||||
|
||||
def test_deduplication(self) -> None:
|
||||
"""Duplicate IPs in the input appear at most once in the output."""
|
||||
geo_service._cache["1.2.3.4"] = GeoInfo( # type: ignore[attr-defined]
|
||||
geo_service._cache["1.2.3.4"] = GeoInfo(
|
||||
country_code="US", country_name="United States", asn=None, org=None
|
||||
)
|
||||
|
||||
@@ -866,7 +867,7 @@ class TestLookupBatchBulkWrites:
|
||||
session = _make_batch_session(batch_response)
|
||||
db = _make_async_db()
|
||||
|
||||
await geo_service.lookup_batch(ips, session, db=db) # type: ignore[arg-type]
|
||||
await geo_service.lookup_batch(ips, session, db=db)
|
||||
|
||||
# One executemany for the positive rows.
|
||||
assert db.executemany.await_count >= 1
|
||||
@@ -883,7 +884,7 @@ class TestLookupBatchBulkWrites:
|
||||
session = _make_batch_session(batch_response)
|
||||
db = _make_async_db()
|
||||
|
||||
await geo_service.lookup_batch(ips, session, db=db) # type: ignore[arg-type]
|
||||
await geo_service.lookup_batch(ips, session, db=db)
|
||||
|
||||
assert db.executemany.await_count >= 1
|
||||
db.execute.assert_not_awaited()
|
||||
@@ -905,7 +906,7 @@ class TestLookupBatchBulkWrites:
|
||||
session = _make_batch_session(batch_response)
|
||||
db = _make_async_db()
|
||||
|
||||
await geo_service.lookup_batch(ips, session, db=db) # type: ignore[arg-type]
|
||||
await geo_service.lookup_batch(ips, session, db=db)
|
||||
|
||||
# One executemany for positives, one for negatives.
|
||||
assert db.executemany.await_count == 2
|
||||
|
||||
@@ -11,6 +11,7 @@ from unittest.mock import AsyncMock, patch
|
||||
import aiosqlite
|
||||
import pytest
|
||||
|
||||
from app.db import init_db
|
||||
from app.services import history_service
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -64,7 +65,7 @@ async def _create_f2b_db(path: str, rows: list[dict[str, Any]]) -> None:
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def f2b_db_path(tmp_path: Path) -> str: # type: ignore[misc]
|
||||
async def f2b_db_path(tmp_path: Path) -> str:
|
||||
"""Return the path to a test fail2ban SQLite database."""
|
||||
path = str(tmp_path / "fail2ban_test.sqlite3")
|
||||
await _create_f2b_db(
|
||||
@@ -123,7 +124,7 @@ class TestListHistory:
|
||||
) -> None:
|
||||
"""No filter returns every record in the database."""
|
||||
with patch(
|
||||
"app.services.history_service._get_fail2ban_db_path",
|
||||
"app.services.history_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=f2b_db_path),
|
||||
):
|
||||
result = await history_service.list_history("fake_socket")
|
||||
@@ -135,7 +136,7 @@ class TestListHistory:
|
||||
) -> None:
|
||||
"""The ``range_`` filter excludes bans older than the window."""
|
||||
with patch(
|
||||
"app.services.history_service._get_fail2ban_db_path",
|
||||
"app.services.history_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=f2b_db_path),
|
||||
):
|
||||
# "24h" window should include only the two recent bans
|
||||
@@ -147,7 +148,7 @@ class TestListHistory:
|
||||
async def test_jail_filter(self, f2b_db_path: str) -> None:
|
||||
"""Jail filter restricts results to bans from that jail."""
|
||||
with patch(
|
||||
"app.services.history_service._get_fail2ban_db_path",
|
||||
"app.services.history_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=f2b_db_path),
|
||||
):
|
||||
result = await history_service.list_history("fake_socket", jail="nginx")
|
||||
@@ -157,7 +158,7 @@ class TestListHistory:
|
||||
async def test_ip_prefix_filter(self, f2b_db_path: str) -> None:
|
||||
"""IP prefix filter restricts results to matching IPs."""
|
||||
with patch(
|
||||
"app.services.history_service._get_fail2ban_db_path",
|
||||
"app.services.history_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=f2b_db_path),
|
||||
):
|
||||
result = await history_service.list_history(
|
||||
@@ -170,7 +171,7 @@ class TestListHistory:
|
||||
async def test_combined_filters(self, f2b_db_path: str) -> None:
|
||||
"""Jail + IP prefix filters applied together narrow the result set."""
|
||||
with patch(
|
||||
"app.services.history_service._get_fail2ban_db_path",
|
||||
"app.services.history_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=f2b_db_path),
|
||||
):
|
||||
result = await history_service.list_history(
|
||||
@@ -179,10 +180,23 @@ class TestListHistory:
|
||||
# 2 sshd bans for 1.2.3.4
|
||||
assert result.total == 2
|
||||
|
||||
async def test_origin_filter_selfblock(self, f2b_db_path: str) -> None:
|
||||
"""Origin filter should include only selfblock entries."""
|
||||
with patch(
|
||||
"app.services.history_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=f2b_db_path),
|
||||
):
|
||||
result = await history_service.list_history(
|
||||
"fake_socket", origin="selfblock"
|
||||
)
|
||||
|
||||
assert result.total == 4
|
||||
assert all(item.jail != "blocklist-import" for item in result.items)
|
||||
|
||||
async def test_unknown_ip_returns_empty(self, f2b_db_path: str) -> None:
|
||||
"""Filtering by a non-existent IP returns an empty result set."""
|
||||
with patch(
|
||||
"app.services.history_service._get_fail2ban_db_path",
|
||||
"app.services.history_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=f2b_db_path),
|
||||
):
|
||||
result = await history_service.list_history(
|
||||
@@ -196,7 +210,7 @@ class TestListHistory:
|
||||
) -> None:
|
||||
"""``failures`` field is parsed from the JSON ``data`` column."""
|
||||
with patch(
|
||||
"app.services.history_service._get_fail2ban_db_path",
|
||||
"app.services.history_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=f2b_db_path),
|
||||
):
|
||||
result = await history_service.list_history(
|
||||
@@ -210,7 +224,7 @@ class TestListHistory:
|
||||
) -> None:
|
||||
"""``matches`` list is parsed from the JSON ``data`` column."""
|
||||
with patch(
|
||||
"app.services.history_service._get_fail2ban_db_path",
|
||||
"app.services.history_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=f2b_db_path),
|
||||
):
|
||||
result = await history_service.list_history(
|
||||
@@ -226,7 +240,7 @@ class TestListHistory:
|
||||
) -> None:
|
||||
"""Records with ``data=NULL`` produce failures=0 and matches=[]."""
|
||||
with patch(
|
||||
"app.services.history_service._get_fail2ban_db_path",
|
||||
"app.services.history_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=f2b_db_path),
|
||||
):
|
||||
result = await history_service.list_history(
|
||||
@@ -240,7 +254,7 @@ class TestListHistory:
|
||||
async def test_pagination(self, f2b_db_path: str) -> None:
|
||||
"""Pagination returns the correct slice."""
|
||||
with patch(
|
||||
"app.services.history_service._get_fail2ban_db_path",
|
||||
"app.services.history_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=f2b_db_path),
|
||||
):
|
||||
result = await history_service.list_history(
|
||||
@@ -251,6 +265,31 @@ class TestListHistory:
|
||||
assert result.page == 1
|
||||
assert result.page_size == 2
|
||||
|
||||
async def test_source_archive_reads_from_archive(self, f2b_db_path: str, tmp_path: Path) -> None:
|
||||
"""Using source='archive' reads from the BanGUI archive table."""
|
||||
app_db_path = str(tmp_path / "app_archive.db")
|
||||
async with aiosqlite.connect(app_db_path) as db:
|
||||
db.row_factory = aiosqlite.Row
|
||||
await init_db(db)
|
||||
await db.execute(
|
||||
"INSERT INTO history_archive (jail, ip, timeofban, bancount, data, action) VALUES (?, ?, ?, ?, ?, ?)",
|
||||
("sshd", "10.0.0.1", _ONE_HOUR_AGO, 1, '{"matches": [], "failures": 0}', "ban"),
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
with patch(
|
||||
"app.services.history_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=f2b_db_path),
|
||||
):
|
||||
result = await history_service.list_history(
|
||||
"fake_socket",
|
||||
source="archive",
|
||||
db=db,
|
||||
)
|
||||
|
||||
assert result.total == 1
|
||||
assert result.items[0].ip == "10.0.0.1"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# get_ip_detail tests
|
||||
@@ -265,7 +304,7 @@ class TestGetIpDetail:
|
||||
) -> None:
|
||||
"""Returns ``None`` when the IP has no records in the database."""
|
||||
with patch(
|
||||
"app.services.history_service._get_fail2ban_db_path",
|
||||
"app.services.history_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=f2b_db_path),
|
||||
):
|
||||
result = await history_service.get_ip_detail("fake_socket", "99.99.99.99")
|
||||
@@ -276,7 +315,7 @@ class TestGetIpDetail:
|
||||
) -> None:
|
||||
"""Returns an IpDetailResponse with correct totals for a known IP."""
|
||||
with patch(
|
||||
"app.services.history_service._get_fail2ban_db_path",
|
||||
"app.services.history_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=f2b_db_path),
|
||||
):
|
||||
result = await history_service.get_ip_detail("fake_socket", "1.2.3.4")
|
||||
@@ -291,7 +330,7 @@ class TestGetIpDetail:
|
||||
) -> None:
|
||||
"""Timeline events are ordered newest-first."""
|
||||
with patch(
|
||||
"app.services.history_service._get_fail2ban_db_path",
|
||||
"app.services.history_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=f2b_db_path),
|
||||
):
|
||||
result = await history_service.get_ip_detail("fake_socket", "1.2.3.4")
|
||||
@@ -304,7 +343,7 @@ class TestGetIpDetail:
|
||||
async def test_last_ban_at_is_most_recent(self, f2b_db_path: str) -> None:
|
||||
"""``last_ban_at`` matches the banned_at of the first timeline event."""
|
||||
with patch(
|
||||
"app.services.history_service._get_fail2ban_db_path",
|
||||
"app.services.history_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=f2b_db_path),
|
||||
):
|
||||
result = await history_service.get_ip_detail("fake_socket", "1.2.3.4")
|
||||
@@ -316,7 +355,7 @@ class TestGetIpDetail:
|
||||
self, f2b_db_path: str
|
||||
) -> None:
|
||||
"""Geolocation is applied when a geo_enricher is provided."""
|
||||
from app.services.geo_service import GeoInfo
|
||||
from app.models.geo import GeoInfo
|
||||
|
||||
mock_geo = GeoInfo(
|
||||
country_code="US",
|
||||
@@ -327,7 +366,7 @@ class TestGetIpDetail:
|
||||
fake_enricher = AsyncMock(return_value=mock_geo)
|
||||
|
||||
with patch(
|
||||
"app.services.history_service._get_fail2ban_db_path",
|
||||
"app.services.history_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=f2b_db_path),
|
||||
):
|
||||
result = await history_service.get_ip_detail(
|
||||
|
||||
@@ -7,7 +7,7 @@ from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.models.ban import ActiveBanListResponse
|
||||
from app.models.ban import ActiveBanListResponse, JailBannedIpsResponse
|
||||
from app.models.jail import JailDetailResponse, JailListResponse
|
||||
from app.services import jail_service
|
||||
from app.services.jail_service import JailNotFoundError, JailOperationError
|
||||
@@ -184,10 +184,90 @@ class TestListJails:
|
||||
with patch("app.services.jail_service.Fail2BanClient", _FailClient), pytest.raises(Fail2BanConnectionError):
|
||||
await jail_service.list_jails(_SOCKET)
|
||||
|
||||
async def test_backend_idle_commands_unsupported(self) -> None:
|
||||
"""list_jails handles unsupported backend and idle commands gracefully.
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# get_jail
|
||||
# ---------------------------------------------------------------------------
|
||||
When the fail2ban daemon does not support get ... backend/idle commands,
|
||||
list_jails should not send them, avoiding "Invalid command" errors in the
|
||||
fail2ban log.
|
||||
"""
|
||||
# Reset the capability cache to test detection.
|
||||
jail_service._backend_cmd_supported = None
|
||||
|
||||
responses = {
|
||||
"status": _make_global_status("sshd"),
|
||||
"status|sshd|short": _make_short_status(),
|
||||
# Capability probe: get backend fails (command not supported).
|
||||
"get|sshd|backend": (1, Exception("Invalid command (no get action or not yet implemented)")),
|
||||
# Subsequent gets should still work.
|
||||
"get|sshd|bantime": (0, 600),
|
||||
"get|sshd|findtime": (0, 600),
|
||||
"get|sshd|maxretry": (0, 5),
|
||||
}
|
||||
with _patch_client(responses):
|
||||
result = await jail_service.list_jails(_SOCKET)
|
||||
|
||||
# Verify the result uses the default values for backend and idle.
|
||||
jail = result.jails[0]
|
||||
assert jail.backend == "polling" # default
|
||||
assert jail.idle is False # default
|
||||
# Capability should now be cached as False.
|
||||
assert jail_service._backend_cmd_supported is False
|
||||
|
||||
async def test_backend_idle_commands_supported(self) -> None:
|
||||
"""list_jails detects and sends backend/idle commands when supported."""
|
||||
# Reset the capability cache to test detection.
|
||||
jail_service._backend_cmd_supported = None
|
||||
|
||||
responses = {
|
||||
"status": _make_global_status("sshd"),
|
||||
"status|sshd|short": _make_short_status(),
|
||||
# Capability probe: get backend succeeds.
|
||||
"get|sshd|backend": (0, "systemd"),
|
||||
# All other commands.
|
||||
"get|sshd|bantime": (0, 600),
|
||||
"get|sshd|findtime": (0, 600),
|
||||
"get|sshd|maxretry": (0, 5),
|
||||
"get|sshd|idle": (0, True),
|
||||
}
|
||||
with _patch_client(responses):
|
||||
result = await jail_service.list_jails(_SOCKET)
|
||||
|
||||
# Verify real values are returned.
|
||||
jail = result.jails[0]
|
||||
assert jail.backend == "systemd" # real value
|
||||
assert jail.idle is True # real value
|
||||
# Capability should now be cached as True.
|
||||
assert jail_service._backend_cmd_supported is True
|
||||
|
||||
async def test_backend_idle_commands_cached_after_first_probe(self) -> None:
|
||||
"""list_jails caches capability result and reuses it across polling cycles."""
|
||||
# Reset the capability cache.
|
||||
jail_service._backend_cmd_supported = None
|
||||
|
||||
responses = {
|
||||
"status": _make_global_status("sshd, nginx"),
|
||||
# Probes happen once per polling cycle (for the first jail listed).
|
||||
"status|sshd|short": _make_short_status(),
|
||||
"status|nginx|short": _make_short_status(),
|
||||
# Capability probe: backend is unsupported.
|
||||
"get|sshd|backend": (1, Exception("Invalid command")),
|
||||
# Subsequent jails do not trigger another probe; they use cached result.
|
||||
# (The mock doesn't have get|nginx|backend because it shouldn't be called.)
|
||||
"get|sshd|bantime": (0, 600),
|
||||
"get|sshd|findtime": (0, 600),
|
||||
"get|sshd|maxretry": (0, 5),
|
||||
"get|nginx|bantime": (0, 600),
|
||||
"get|nginx|findtime": (0, 600),
|
||||
"get|nginx|maxretry": (0, 5),
|
||||
}
|
||||
with _patch_client(responses):
|
||||
result = await jail_service.list_jails(_SOCKET)
|
||||
|
||||
# Both jails should return default values (cached result is False).
|
||||
for jail in result.jails:
|
||||
assert jail.backend == "polling"
|
||||
assert jail.idle is False
|
||||
|
||||
|
||||
class TestGetJail:
|
||||
@@ -339,6 +419,55 @@ class TestJailControls:
|
||||
_SOCKET, include_jails=["new"], exclude_jails=["old"]
|
||||
)
|
||||
|
||||
async def test_reload_all_unknown_jail_raises_jail_not_found(self) -> None:
|
||||
"""reload_all detects UnknownJailException and raises JailNotFoundError.
|
||||
|
||||
When fail2ban cannot load a jail due to invalid configuration (e.g.,
|
||||
missing logpath), it raises UnknownJailException during reload. This
|
||||
test verifies that reload_all detects this and re-raises as
|
||||
JailNotFoundError instead of the generic JailOperationError.
|
||||
"""
|
||||
with _patch_client(
|
||||
{
|
||||
"status": _make_global_status("sshd"),
|
||||
"reload|--all|[]|[['start', 'airsonic-auth'], ['start', 'sshd']]": (
|
||||
1,
|
||||
Exception("UnknownJailException('airsonic-auth')"),
|
||||
),
|
||||
}
|
||||
), pytest.raises(jail_service.JailNotFoundError) as exc_info:
|
||||
await jail_service.reload_all(
|
||||
_SOCKET, include_jails=["airsonic-auth"]
|
||||
)
|
||||
assert exc_info.value.name == "airsonic-auth"
|
||||
|
||||
async def test_restart_sends_stop_command(self) -> None:
|
||||
"""restart() sends the ['stop'] command to the fail2ban socket."""
|
||||
with _patch_client({"stop": (0, None)}):
|
||||
await jail_service.restart(_SOCKET) # should not raise
|
||||
|
||||
async def test_restart_operation_error_raises(self) -> None:
|
||||
"""restart() raises JailOperationError when fail2ban rejects the stop."""
|
||||
with _patch_client({"stop": (1, Exception("cannot stop"))}), pytest.raises(
|
||||
JailOperationError
|
||||
):
|
||||
await jail_service.restart(_SOCKET)
|
||||
|
||||
async def test_restart_connection_error_propagates(self) -> None:
|
||||
"""restart() propagates Fail2BanConnectionError when socket is unreachable."""
|
||||
|
||||
class _FailClient:
|
||||
def __init__(self, **_kw: Any) -> None:
|
||||
self.send = AsyncMock(
|
||||
side_effect=Fail2BanConnectionError("no socket", _SOCKET)
|
||||
)
|
||||
|
||||
with (
|
||||
patch("app.services.jail_service.Fail2BanClient", _FailClient),
|
||||
pytest.raises(Fail2BanConnectionError),
|
||||
):
|
||||
await jail_service.restart(_SOCKET)
|
||||
|
||||
async def test_start_not_found_raises(self) -> None:
|
||||
"""start_jail raises JailNotFoundError for unknown jail."""
|
||||
with _patch_client({"start|ghost": (1, Exception("Unknown jail: 'ghost'"))}), pytest.raises(JailNotFoundError):
|
||||
@@ -506,7 +635,7 @@ class TestGetActiveBans:
|
||||
|
||||
async def test_http_session_triggers_lookup_batch(self) -> None:
|
||||
"""When http_session is provided, geo_service.lookup_batch is used."""
|
||||
from app.services.geo_service import GeoInfo
|
||||
from app.models.geo import GeoInfo
|
||||
|
||||
responses = {
|
||||
"status": _make_global_status("sshd"),
|
||||
@@ -516,17 +645,14 @@ class TestGetActiveBans:
|
||||
),
|
||||
}
|
||||
mock_geo = {"1.2.3.4": GeoInfo(country_code="DE", country_name="Germany", asn="AS1", org="ISP")}
|
||||
mock_batch = AsyncMock(return_value=mock_geo)
|
||||
|
||||
with (
|
||||
_patch_client(responses),
|
||||
patch(
|
||||
"app.services.geo_service.lookup_batch",
|
||||
new=AsyncMock(return_value=mock_geo),
|
||||
) as mock_batch,
|
||||
):
|
||||
with _patch_client(responses):
|
||||
mock_session = AsyncMock()
|
||||
result = await jail_service.get_active_bans(
|
||||
_SOCKET, http_session=mock_session
|
||||
_SOCKET,
|
||||
http_session=mock_session,
|
||||
geo_batch_lookup=mock_batch,
|
||||
)
|
||||
|
||||
mock_batch.assert_awaited_once()
|
||||
@@ -543,16 +669,14 @@ class TestGetActiveBans:
|
||||
),
|
||||
}
|
||||
|
||||
with (
|
||||
_patch_client(responses),
|
||||
patch(
|
||||
"app.services.geo_service.lookup_batch",
|
||||
new=AsyncMock(side_effect=RuntimeError("geo down")),
|
||||
),
|
||||
):
|
||||
failing_batch = AsyncMock(side_effect=RuntimeError("geo down"))
|
||||
|
||||
with _patch_client(responses):
|
||||
mock_session = AsyncMock()
|
||||
result = await jail_service.get_active_bans(
|
||||
_SOCKET, http_session=mock_session
|
||||
_SOCKET,
|
||||
http_session=mock_session,
|
||||
geo_batch_lookup=failing_batch,
|
||||
)
|
||||
|
||||
assert result.total == 1
|
||||
@@ -560,7 +684,7 @@ class TestGetActiveBans:
|
||||
|
||||
async def test_geo_enricher_still_used_without_http_session(self) -> None:
|
||||
"""Legacy geo_enricher is still called when http_session is not provided."""
|
||||
from app.services.geo_service import GeoInfo
|
||||
from app.models.geo import GeoInfo
|
||||
|
||||
responses = {
|
||||
"status": _make_global_status("sshd"),
|
||||
@@ -700,3 +824,199 @@ class TestUnbanAllIps:
|
||||
pytest.raises(Fail2BanConnectionError),
|
||||
):
|
||||
await jail_service.unban_all_ips(_SOCKET)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# get_jail_banned_ips
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
#: A raw ban entry string in the format produced by fail2ban --with-time.
|
||||
_BAN_ENTRY_1 = "1.2.3.4\t2025-01-01 10:00:00 + 600 = 2025-01-01 10:10:00"
|
||||
_BAN_ENTRY_2 = "5.6.7.8\t2025-01-01 11:00:00 + 600 = 2025-01-01 11:10:00"
|
||||
_BAN_ENTRY_3 = "9.10.11.12\t2025-01-01 12:00:00 + 600 = 2025-01-01 12:10:00"
|
||||
|
||||
|
||||
def _banned_ips_responses(jail: str = "sshd", entries: list[str] | None = None) -> dict[str, Any]:
|
||||
"""Build mock responses for get_jail_banned_ips tests."""
|
||||
if entries is None:
|
||||
entries = [_BAN_ENTRY_1, _BAN_ENTRY_2]
|
||||
return {
|
||||
f"status|{jail}|short": _make_short_status(),
|
||||
f"get|{jail}|banip|--with-time": (0, entries),
|
||||
}
|
||||
|
||||
|
||||
class TestGetJailBannedIps:
|
||||
"""Unit tests for :func:`~app.services.jail_service.get_jail_banned_ips`."""
|
||||
|
||||
async def test_returns_jail_banned_ips_response(self) -> None:
|
||||
"""get_jail_banned_ips returns a JailBannedIpsResponse."""
|
||||
with _patch_client(_banned_ips_responses()):
|
||||
result = await jail_service.get_jail_banned_ips(_SOCKET, "sshd")
|
||||
|
||||
assert isinstance(result, JailBannedIpsResponse)
|
||||
|
||||
async def test_total_reflects_all_entries(self) -> None:
|
||||
"""total equals the number of parsed ban entries."""
|
||||
with _patch_client(_banned_ips_responses(entries=[_BAN_ENTRY_1, _BAN_ENTRY_2, _BAN_ENTRY_3])):
|
||||
result = await jail_service.get_jail_banned_ips(_SOCKET, "sshd")
|
||||
|
||||
assert result.total == 3
|
||||
|
||||
async def test_page_1_returns_first_n_items(self) -> None:
|
||||
"""page=1 with page_size=2 returns the first two entries."""
|
||||
with _patch_client(
|
||||
_banned_ips_responses(entries=[_BAN_ENTRY_1, _BAN_ENTRY_2, _BAN_ENTRY_3])
|
||||
):
|
||||
result = await jail_service.get_jail_banned_ips(
|
||||
_SOCKET, "sshd", page=1, page_size=2
|
||||
)
|
||||
|
||||
assert len(result.items) == 2
|
||||
assert result.items[0].ip == "1.2.3.4"
|
||||
assert result.items[1].ip == "5.6.7.8"
|
||||
assert result.total == 3
|
||||
|
||||
async def test_page_2_returns_remaining_items(self) -> None:
|
||||
"""page=2 with page_size=2 returns the third entry."""
|
||||
with _patch_client(
|
||||
_banned_ips_responses(entries=[_BAN_ENTRY_1, _BAN_ENTRY_2, _BAN_ENTRY_3])
|
||||
):
|
||||
result = await jail_service.get_jail_banned_ips(
|
||||
_SOCKET, "sshd", page=2, page_size=2
|
||||
)
|
||||
|
||||
assert len(result.items) == 1
|
||||
assert result.items[0].ip == "9.10.11.12"
|
||||
|
||||
async def test_page_beyond_last_returns_empty_items(self) -> None:
|
||||
"""Requesting a page past the end returns an empty items list."""
|
||||
with _patch_client(_banned_ips_responses()):
|
||||
result = await jail_service.get_jail_banned_ips(
|
||||
_SOCKET, "sshd", page=99, page_size=25
|
||||
)
|
||||
|
||||
assert result.items == []
|
||||
assert result.total == 2
|
||||
|
||||
async def test_search_filter_narrows_results(self) -> None:
|
||||
"""search parameter filters entries by IP substring."""
|
||||
with _patch_client(_banned_ips_responses()):
|
||||
result = await jail_service.get_jail_banned_ips(
|
||||
_SOCKET, "sshd", search="1.2.3"
|
||||
)
|
||||
|
||||
assert result.total == 1
|
||||
assert result.items[0].ip == "1.2.3.4"
|
||||
|
||||
async def test_search_filter_case_insensitive(self) -> None:
|
||||
"""search filter is case-insensitive."""
|
||||
entries = ["192.168.0.1\t2025-01-01 10:00:00 + 600 = 2025-01-01 10:10:00"]
|
||||
with _patch_client(_banned_ips_responses(entries=entries)):
|
||||
result = await jail_service.get_jail_banned_ips(
|
||||
_SOCKET, "sshd", search="192.168"
|
||||
)
|
||||
|
||||
assert result.total == 1
|
||||
|
||||
async def test_search_no_match_returns_empty(self) -> None:
|
||||
"""search that matches nothing returns empty items and total=0."""
|
||||
with _patch_client(_banned_ips_responses()):
|
||||
result = await jail_service.get_jail_banned_ips(
|
||||
_SOCKET, "sshd", search="999.999"
|
||||
)
|
||||
|
||||
assert result.total == 0
|
||||
assert result.items == []
|
||||
|
||||
async def test_empty_ban_list_returns_total_zero(self) -> None:
|
||||
"""get_jail_banned_ips handles an empty ban list gracefully."""
|
||||
responses = {
|
||||
"status|sshd|short": _make_short_status(),
|
||||
"get|sshd|banip|--with-time": (0, []),
|
||||
}
|
||||
with _patch_client(responses):
|
||||
result = await jail_service.get_jail_banned_ips(_SOCKET, "sshd")
|
||||
|
||||
assert result.total == 0
|
||||
assert result.items == []
|
||||
|
||||
async def test_page_size_clamped_to_max(self) -> None:
|
||||
"""page_size values above 100 are silently clamped to 100."""
|
||||
entries = [f"10.0.0.{i}\t2025-01-01 10:00:00 + 600 = 2025-01-01 10:10:00" for i in range(1, 101)]
|
||||
responses = {
|
||||
"status|sshd|short": _make_short_status(),
|
||||
"get|sshd|banip|--with-time": (0, entries),
|
||||
}
|
||||
with _patch_client(responses):
|
||||
result = await jail_service.get_jail_banned_ips(
|
||||
_SOCKET, "sshd", page=1, page_size=200
|
||||
)
|
||||
|
||||
assert len(result.items) <= 100
|
||||
|
||||
async def test_geo_enrichment_called_for_page_slice_only(self) -> None:
|
||||
"""Geo enrichment is requested only for IPs in the current page."""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from app.services import geo_service
|
||||
|
||||
http_session = MagicMock()
|
||||
geo_enrichment_ips: list[list[str]] = []
|
||||
|
||||
async def _mock_lookup_batch(
|
||||
ips: list[str], _session: Any, **_kw: Any
|
||||
) -> dict[str, Any]:
|
||||
geo_enrichment_ips.append(list(ips))
|
||||
return {}
|
||||
|
||||
with (
|
||||
_patch_client(
|
||||
_banned_ips_responses(entries=[_BAN_ENTRY_1, _BAN_ENTRY_2, _BAN_ENTRY_3])
|
||||
),
|
||||
patch.object(geo_service, "lookup_batch", side_effect=_mock_lookup_batch),
|
||||
):
|
||||
result = await jail_service.get_jail_banned_ips(
|
||||
_SOCKET,
|
||||
"sshd",
|
||||
page=1,
|
||||
page_size=2,
|
||||
http_session=http_session,
|
||||
geo_batch_lookup=geo_service.lookup_batch,
|
||||
)
|
||||
|
||||
# Only the 2-IP page slice should be passed to geo enrichment.
|
||||
assert len(geo_enrichment_ips) == 1
|
||||
assert len(geo_enrichment_ips[0]) == 2
|
||||
assert result.total == 3
|
||||
|
||||
async def test_unknown_jail_raises_jail_not_found_error(self) -> None:
|
||||
"""get_jail_banned_ips raises JailNotFoundError for unknown jail."""
|
||||
# Simulate fail2ban returning an "unknown jail" error.
|
||||
class _FakeClient:
|
||||
def __init__(self, **_kw: Any) -> None:
|
||||
pass
|
||||
|
||||
async def send(self, command: list[Any]) -> Any:
|
||||
raise ValueError("Unknown jail: ghost")
|
||||
|
||||
with (
|
||||
patch("app.services.jail_service.Fail2BanClient", _FakeClient),
|
||||
pytest.raises(JailNotFoundError),
|
||||
):
|
||||
await jail_service.get_jail_banned_ips(_SOCKET, "ghost")
|
||||
|
||||
async def test_connection_error_propagates(self) -> None:
|
||||
"""get_jail_banned_ips propagates Fail2BanConnectionError."""
|
||||
|
||||
class _FailClient:
|
||||
def __init__(self, **_kw: Any) -> None:
|
||||
self.send = AsyncMock(
|
||||
side_effect=Fail2BanConnectionError("no socket", _SOCKET)
|
||||
)
|
||||
|
||||
with (
|
||||
patch("app.services.jail_service.Fail2BanClient", _FailClient),
|
||||
pytest.raises(Fail2BanConnectionError),
|
||||
):
|
||||
await jail_service.get_jail_banned_ips(_SOCKET, "sshd")
|
||||
|
||||
@@ -63,6 +63,16 @@ class TestGetSettings:
|
||||
assert result.settings.log_target == "/var/log/fail2ban.log"
|
||||
assert result.settings.db_purge_age == 86400
|
||||
assert result.settings.db_max_matches == 10
|
||||
assert result.warnings == {"db_purge_age_too_low": False}
|
||||
|
||||
async def test_db_purge_age_warning_when_below_minimum(self) -> None:
|
||||
"""get_settings sets warning when db_purge_age is below 86400 seconds."""
|
||||
responses = {**_DEFAULT_RESPONSES, "get|dbpurgeage": (0, 3600)}
|
||||
with _patch_client(responses):
|
||||
result = await server_service.get_settings(_SOCKET)
|
||||
|
||||
assert result.settings.db_purge_age == 3600
|
||||
assert result.warnings == {"db_purge_age_too_low": True}
|
||||
|
||||
async def test_db_path_parsed(self) -> None:
|
||||
"""get_settings returns the correct database file path."""
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user